Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,45 @@ def create_call_method(node: fx.Node) -> fx.Node:

return gm

def remove_redundant_to_calls(self, gm: fx.GraphModule) -> fx.GraphModule:
    """
    Remove redundant ``.to(dtype)`` calls from the graph.

    After running ShapeProp with low-precision inputs, each node's
    ``tensor_meta`` records the real runtime dtype. A ``.to(target_dtype)``
    call is redundant when its input tensor already has ``target_dtype``.

    Must be called after ShapeProp has been run on the rewritten graph
    with low-precision inputs, otherwise ``tensor_meta`` is missing and
    no node is removed.

    Args:
        gm: The FX graph module to prune; mutated in place.

    Returns:
        The same ``gm``, recompiled after pruning.
    """
    graph = gm.graph
    nodes_to_remove = []

    for node in graph.nodes:
        if node.op != "call_method" or node.target != "to":
            continue
        # Only handle the pure positional form ``x.to(dtype)``. Any kwargs
        # (``copy=True``, ``memory_format=...``, ``non_blocking=...``) give
        # the call semantics beyond the dtype cast (e.g. forcing a copy or
        # a layout change), so erasing such a node would change behavior.
        if node.kwargs:
            continue
        # .to(dtype) node: args = (input_tensor, dtype). A non-dtype second
        # arg (e.g. a device) fails the equality check and is skipped.
        if len(node.args) < 2 or node.args[1] != self.torch_dtype:
            continue

        input_node = node.args[0]
        if not isinstance(input_node, fx.Node):
            continue

        # Check if input already has target dtype via ShapeProp metadata.
        # Without metadata we conservatively keep the node.
        if "tensor_meta" not in input_node.meta:
            continue
        input_meta = input_node.meta["tensor_meta"]
        if hasattr(input_meta, "dtype") and input_meta.dtype == self.torch_dtype:
            # Input is already target dtype, this .to() is a no-op cast.
            node.replace_all_uses_with(input_node)
            nodes_to_remove.append(node)

    # Erase after the iteration so we never mutate the node list mid-walk.
    for node in nodes_to_remove:
        graph.erase_node(node)

    gm.recompile()
    return gm

def _is_float32_tensor(self, node: fx.Node) -> bool:
"""
Check if a node represents a float32 tensor.
Expand Down
32 changes: 24 additions & 8 deletions graph_net/torch/sample_pass/dtype_generalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def resume(self, rel_model_path: str) -> List[str]:
for pass_name in dtype_pass_names:
try:
sample_dir = self._apply_pass_and_generate(
rel_model_path, traced_model, pass_name
rel_model_path, traced_model, pass_name, inputs
)
generated_samples.append(sample_dir)
logging.info(f"Generated sample: {sample_dir}")
Expand Down Expand Up @@ -373,7 +373,11 @@ def _read_dtype_pass_names(self, model_path: str) -> List[str]:
return metadata.get(kDataTypeGeneralizationPasses, [])

def _apply_pass_and_generate(
self, rel_model_path: str, traced_model: fx.GraphModule, pass_name: str
self,
rel_model_path: str,
traced_model: fx.GraphModule,
pass_name: str,
inputs: list = None,
) -> str:
"""
Apply a specific pass and generate a new sample.
Expand All @@ -383,6 +387,7 @@ def _apply_pass_and_generate(
traced_model: Original traced model
pass_name: Name of the pass file (without .py extension),
e.g., "dtype_generalization_pass_float16"
inputs: Original model inputs (unused, kept for compatibility)
Returns:
Path to the generated sample directory
Expand Down Expand Up @@ -413,17 +418,28 @@ def _apply_pass_and_generate(
# Copy metadata files of original sample
self._copy_sample(rel_model_path, output_dir)

# Update model.py
model_code = serialize_graph_module_to_str(gm_modified)
templated_model_code = utils.apply_templates(model_code)
(output_dir / "model.py").write_text(templated_model_code)

# Update weight_meta.py and input_meta.py dtypes
# Update weight_meta.py and input_meta.py dtypes FIRST,
# so we can use the updated meta to generate inputs for ShapeProp
target_dtype_str = f"torch.{dtype}"
self._update_tensor_meta_dtypes(
output_dir, converted_tensor_names, target_dtype_str
)

# Remove redundant .to() calls:
# Load inputs from the updated meta files (dtype matches meta exactly),
# run ShapeProp to get real runtime dtypes, then prune redundant .to() nodes.
try:
_, meta_inputs = get_torch_module_and_inputs(str(output_dir))
ShapeProp(gm_modified).propagate(*meta_inputs)
gm_modified = dtype_pass.remove_redundant_to_calls(gm_modified)
except Exception as e:
logging.warning(f"Failed to remove redundant .to() calls: {e}")

# Update model.py (after redundant .to() removal)
model_code = serialize_graph_module_to_str(gm_modified)
templated_model_code = utils.apply_templates(model_code)
(output_dir / "model.py").write_text(templated_model_code)

# Update graph_hash.txt
model_hash = get_sha256_hash(model_code)
(output_dir / "graph_hash.txt").write_text(model_hash)
Expand Down
Loading