Skip to content
Merged

dtype #665

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 29 additions & 16 deletions graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -83,32 +83,44 @@ def rewrite(self, gm: fx.GraphModule) -> fx.GraphModule:
Rewrite the graph to convert dtypes.

Strategy:
1. For each placeholder (input), insert .to(target_dtype) after it
2. For each get_attr (weight), insert .to(target_dtype) if not preserved
3. Update the graph and recompile
1. For placeholder (input) and get_attr (weight) nodes, do NOT insert
.to(dtype) in the graph. Instead, collect their names in
self.converted_tensor_names so that weight_meta.py / input_meta.py
can be updated externally.
2. For AMP-sensitive call_function/call_method nodes, insert .to(dtype)
for float32 arguments that are not already converted via meta file.
3. Non-AMP ops are copied as-is without dtype propagation inference.
4. Update the graph and recompile.
"""
new_graph = fx.Graph()
val_map = {}
self.converted_tensor_names = []

# Track nodes whose dtype is converted via meta file (placeholder/get_attr).
# Only these nodes are known to output target dtype at runtime.
# We do NOT infer dtype propagation for intermediate ops, to avoid
# incorrect assumptions about operator dtype behavior.
nodes_converted_via_meta = set()

def create_placeholder(node: fx.Node) -> fx.Node:
"""Create a placeholder node with dtype conversion if needed."""
"""Create a placeholder node, collecting name if dtype conversion needed."""
new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
if self._is_float32_tensor(node):
attr_name = str(node.target)
if self.should_preserve_weight(attr_name):
return new_node

return new_graph.call_method("to", args=(new_node, self.torch_dtype))
if not self.should_preserve_weight(attr_name):
self.converted_tensor_names.append(attr_name)
nodes_converted_via_meta.add(node)
return new_node

def create_get_attr(node: fx.Node) -> fx.Node:
"""Create a get_attr node with dtype conversion if needed."""
"""Create a get_attr node, collecting name if dtype conversion needed."""
new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
attr_name = str(node.target)
if self._is_float32_tensor(node) and not self.should_preserve_weight(
attr_name
):
return new_graph.call_method("to", args=(new_node, self.torch_dtype))
self.converted_tensor_names.append(attr_name)
nodes_converted_via_meta.add(node)
return new_node

def create_new_args(node: fx.Node) -> list:
Expand All @@ -118,7 +130,10 @@ def create_new_args(node: fx.Node) -> list:
for arg in node.args:
if isinstance(arg, fx.Node):
mapped = val_map[arg]
if self._is_float32_tensor(arg):
if (
self._is_float32_tensor(arg)
and arg not in nodes_converted_via_meta
):
mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
new_args.append(mapped)
else:
Expand All @@ -132,10 +147,9 @@ def create_new_kwargs(node: fx.Node) -> dict:
for k, v in node.kwargs.items():
if isinstance(v, fx.Node):
mapped = val_map[v]
if self._is_float32_tensor(v):
if self._is_float32_tensor(v) and v not in nodes_converted_via_meta:
mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
else:
new_kwargs[k] = mapped
new_kwargs[k] = mapped
else:
new_kwargs[k] = v
return new_kwargs
Expand All @@ -145,8 +159,8 @@ def create_call_function(node: fx.Node) -> fx.Node:
if node.target not in AMP_CALL_FUNCTION:
return new_graph.node_copy(node, lambda x: val_map[x])

# AMP ops: insert .to() for float32 args not converted via meta file
new_args = create_new_args(node)

new_kwargs = create_new_kwargs(node)

return new_graph.call_function(
Expand All @@ -161,7 +175,6 @@ def create_call_method(node: fx.Node) -> fx.Node:
return new_graph.node_copy(node, lambda x: val_map[x])

new_args = create_new_args(node)

new_kwargs = create_new_kwargs(node)

return new_graph.call_method(
Expand Down
48 changes: 48 additions & 0 deletions graph_net/torch/sample_pass/dtype_generalizer.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin

from graph_net.hash_util import get_sha256_hash
from graph_net.tensor_meta import TensorMeta

# Weights that must remain float32 for numerical stability
FLOAT32_PRESERVED_WEIGHTS = {
Expand Down Expand Up @@ -403,6 +404,9 @@ def _apply_pass_and_generate(
gm_copy = copy.deepcopy(traced_model)
gm_modified = dtype_pass.rewrite(gm_copy)

# Get the list of tensor names that need dtype conversion
converted_tensor_names = set(getattr(dtype_pass, "converted_tensor_names", []))

# Generate output directory
output_dir = self._get_output_dir(rel_model_path, dtype)

Expand All @@ -414,6 +418,12 @@ def _apply_pass_and_generate(
templated_model_code = utils.apply_templates(model_code)
(output_dir / "model.py").write_text(templated_model_code)

# Update weight_meta.py and input_meta.py dtypes
target_dtype_str = f"torch.{dtype}"
self._update_tensor_meta_dtypes(
output_dir, converted_tensor_names, target_dtype_str
)

# Update graph_hash.txt
model_hash = get_sha256_hash(model_code)
(output_dir / "graph_hash.txt").write_text(model_hash)
Expand Down Expand Up @@ -445,6 +455,44 @@ def _update_sample_metadata(self, sample_dir: Path, dtype: str) -> None:
update_json(graph_net_json_path, kDtypeGeneralizationPrecision, dtype)
update_json(graph_net_json_path, kDtypeGeneralizationGenerated, True)

def _update_tensor_meta_dtypes(
    self,
    sample_dir: Path,
    converted_tensor_names: set,
    target_dtype_str: str,
) -> None:
    """
    Rewrite the dtype recorded for converted tensors in the sample's meta files.

    The dtype field of matching entries in ``weight_meta.py`` and
    ``input_meta.py`` is overwritten with the target dtype, so the change is
    expressed in the meta files rather than through ``.to(dtype)`` calls in
    model.py. A meta file that is missing, or that contains no matching
    entry, is left untouched on disk.

    Args:
        sample_dir: Path to generated sample directory
        converted_tensor_names: Set of tensor names that were converted
        target_dtype_str: Target dtype string, e.g. "torch.float16"
    """

    def _was_converted(meta) -> bool:
        # The FX graph node target is expected to line up with ``meta.name``
        # (the forward parameter name); ``meta.original_name`` is consulted
        # as a fallback when it is set.
        if meta.name in converted_tensor_names:
            return True
        return bool(
            meta.original_name and meta.original_name in converted_tensor_names
        )

    for file_name in ("weight_meta.py", "input_meta.py"):
        file_path = sample_dir / file_name
        if file_path.exists():
            all_metas = TensorMeta.unserialize_from_py_file_order_preserved(
                str(file_path)
            )
            matched = [meta for meta in all_metas if _was_converted(meta)]
            for meta in matched:
                meta.dtype = target_dtype_str
            # Only write the file back when at least one entry was modified.
            if matched:
                TensorMeta.save_tensor_metas(str(file_path), all_metas)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确认一下,meta保存的顺序跟原来是否一致?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在用 AST 解析保证顺序


def _copy_sample(self, rel_model_path: str, output_dir: str) -> None:
"""
Copy files of sample.
Expand Down