def remove_redundant_to_calls(self, gm: fx.GraphModule) -> fx.GraphModule:
    """
    Remove redundant ``.to(dtype)`` calls from the graph.

    After ShapeProp has run with low-precision inputs, each node's
    ``tensor_meta`` carries the real runtime dtype.  A ``.to(target_dtype)``
    call is redundant when its input already has ``target_dtype``.

    The target dtype is recognized both positionally (``x.to(dtype)``) and
    as a keyword argument (``x.to(dtype=dtype)``).  Calls whose arguments do
    not name ``self.torch_dtype`` — e.g. ``.to(device)`` or ``.to(other)`` —
    are conservatively left in place.

    Must be called after ShapeProp on the rewritten graph with
    low-precision inputs, otherwise ``tensor_meta`` is absent and nothing
    is pruned.

    Args:
        gm: Graph module to prune; mutated in place.

    Returns:
        The same ``gm``, recompiled after node removal.
    """
    graph = gm.graph
    nodes_to_remove = []

    for node in graph.nodes:
        if node.op != "call_method" or node.target != "to":
            continue

        # Locate the requested dtype: positional .to(dtype) / .to(device, dtype)
        # puts it in args, keyword .to(dtype=...) puts it in kwargs.  Anything
        # that is not exactly self.torch_dtype is skipped (safe no-op).
        if len(node.args) >= 2:
            requested = node.args[1]
        else:
            requested = node.kwargs.get("dtype")
        if requested != self.torch_dtype:
            continue

        input_node = node.args[0] if node.args else None
        if not isinstance(input_node, fx.Node):
            continue

        # Only trust dtypes established by ShapeProp metadata; a node
        # without tensor_meta gives us no evidence the cast is redundant.
        input_meta = input_node.meta.get("tensor_meta")
        if input_meta is None:
            continue
        if getattr(input_meta, "dtype", None) == self.torch_dtype:
            # Input already has the target dtype, this .to() is a no-op.
            node.replace_all_uses_with(input_node)
            nodes_to_remove.append(node)

    # Erase after the scan: mutating the node list while iterating it
    # would invalidate the traversal.
    for node in nodes_to_remove:
        graph.erase_node(node)

    gm.recompile()
    return gm
diff --git a/graph_net/torch/sample_pass/dtype_generalizer.py b/graph_net/torch/sample_pass/dtype_generalizer.py index 711ce4a2d..7da61ad70 100755 --- a/graph_net/torch/sample_pass/dtype_generalizer.py +++ b/graph_net/torch/sample_pass/dtype_generalizer.py @@ -373,7 +373,10 @@ def _read_dtype_pass_names(self, model_path: str) -> List[str]: return metadata.get(kDataTypeGeneralizationPasses, []) def _apply_pass_and_generate( - self, rel_model_path: str, traced_model: fx.GraphModule, pass_name: str + self, + rel_model_path: str, + traced_model: fx.GraphModule, + pass_name: str, ) -> str: """ Apply a specific pass and generate a new sample. @@ -413,17 +416,29 @@ def _apply_pass_and_generate( # Copy metadata files of original sample self._copy_sample(rel_model_path, output_dir) - # Update model.py - model_code = serialize_graph_module_to_str(gm_modified) - templated_model_code = utils.apply_templates(model_code) - (output_dir / "model.py").write_text(templated_model_code) - - # Update weight_meta.py and input_meta.py dtypes + # Update weight_meta.py and input_meta.py dtypes FIRST, + # so we can use the updated meta to generate inputs for ShapeProp target_dtype_str = f"torch.{dtype}" self._update_tensor_meta_dtypes( output_dir, converted_tensor_names, target_dtype_str ) + # Remove redundant .to() calls: + # Load inputs from the updated meta files (dtype matches meta exactly), + # run ShapeProp to get real runtime dtypes, then prune redundant .to() nodes. 
+ try: + torch.cuda.empty_cache() + _, meta_inputs = get_torch_module_and_inputs(str(output_dir)) + ShapeProp(gm_modified).propagate(*meta_inputs) + gm_modified = dtype_pass.remove_redundant_to_calls(gm_modified) + except Exception as e: + logging.warning(f"Failed to remove redundant .to() calls: {e}") + + # Update model.py (after redundant .to() removal) + model_code = serialize_graph_module_to_str(gm_modified) + templated_model_code = utils.apply_templates(model_code) + (output_dir / "model.py").write_text(templated_model_code) + # Update graph_hash.txt model_hash = get_sha256_hash(model_code) (output_dir / "graph_hash.txt").write_text(model_hash)