From c85d1a200ffac7b577a51dc8c7ed58622ee4d80a Mon Sep 17 00:00:00 2001 From: Honglei-Qiu <1044497581@qq.com> Date: Tue, 10 Mar 2026 08:23:44 +0000 Subject: [PATCH 1/3] fix dtype --- .../dtype_generalization_pass.py | 39 +++++++++++++++++++ .../torch/sample_pass/dtype_generalizer.py | 27 ++++++++++++- 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py index e9f2ff379..56ecdd2e4 100755 --- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py +++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py @@ -201,6 +201,45 @@ def create_call_method(node: fx.Node) -> fx.Node: return gm + def remove_redundant_to_calls(self, gm: fx.GraphModule) -> fx.GraphModule: + """ + Remove redundant .to(dtype) calls from the graph. + + After ShapeProp with low-precision inputs, each node's tensor_meta + contains the real runtime dtype. A .to(target_dtype) call is redundant + if its input already has target_dtype. + + Must be called after ShapeProp on the rewritten graph with low-precision inputs. + """ + graph = gm.graph + nodes_to_remove = [] + + for node in graph.nodes: + if node.op != "call_method" or node.target != "to": + continue + # .to(dtype) node: args = (input_tensor, dtype) + if len(node.args) < 2 or node.args[1] != self.torch_dtype: + continue + + input_node = node.args[0] + if not isinstance(input_node, fx.Node): + continue + + # Check if input already has target dtype via ShapeProp metadata + if "tensor_meta" not in input_node.meta: + continue + input_meta = input_node.meta["tensor_meta"] + if hasattr(input_meta, "dtype") and input_meta.dtype == self.torch_dtype: + # Input is already target dtype, this .to() is redundant + node.replace_all_uses_with(input_node) + nodes_to_remove.append(node) + + for node in nodes_to_remove: + graph.erase_node(node) + + gm.recompile() + return gm + def _is_float32_tensor(self, node: fx.Node) -> bool: """ Check if a node represents a float32 tensor. diff --git a/graph_net/torch/sample_pass/dtype_generalizer.py b/graph_net/torch/sample_pass/dtype_generalizer.py index 711ce4a2d..0fb0ce5c5 100755 --- a/graph_net/torch/sample_pass/dtype_generalizer.py +++ b/graph_net/torch/sample_pass/dtype_generalizer.py @@ -342,7 +342,7 @@ def resume(self, rel_model_path: str) -> List[str]: for pass_name in dtype_pass_names: try: sample_dir = self._apply_pass_and_generate( - rel_model_path, traced_model, pass_name + rel_model_path, traced_model, pass_name, inputs ) generated_samples.append(sample_dir) logging.info(f"Generated sample: {sample_dir}") @@ -373,7 +373,11 @@ def _read_dtype_pass_names(self, model_path: str) -> List[str]: return metadata.get(kDataTypeGeneralizationPasses, []) def _apply_pass_and_generate( - self, rel_model_path: str, traced_model: fx.GraphModule, pass_name: str + self, + rel_model_path: str, + traced_model: fx.GraphModule, + pass_name: str, + inputs: list = None, ) -> str: """ Apply a specific pass and generate a new sample. @@ -383,6 +387,8 @@ def _apply_pass_and_generate( traced_model: Original traced model pass_name: Name of the pass file (without .py extension), e.g., "dtype_generalization_pass_float16" + inputs: Original model inputs, used for ShapeProp to remove + redundant .to() calls Returns: Path to the generated sample directory @@ -404,6 +410,23 @@ def _apply_pass_and_generate( gm_copy = copy.deepcopy(traced_model) gm_modified = dtype_pass.rewrite(gm_copy) + # Remove redundant .to() calls by re-running ShapeProp with + # low-precision inputs to get real runtime dtypes, then pruning + # .to() nodes whose input is already target dtype. + if inputs is not None: + try: + torch_dtype = getattr(torch, dtype) + low_prec_inputs = [ + x.to(torch_dtype) + if isinstance(x, torch.Tensor) and x.is_floating_point() + else x + for x in inputs + ] + ShapeProp(gm_modified).propagate(*low_prec_inputs) + gm_modified = dtype_pass.remove_redundant_to_calls(gm_modified) + except Exception as e: + logging.warning(f"Failed to remove redundant .to() calls: {e}") + # Get the list of tensor names that need dtype conversion converted_tensor_names = set(getattr(dtype_pass, "converted_tensor_names", [])) From 7b432935b7d27d4046c8c26e73b0d376b4dcea64 Mon Sep 17 00:00:00 2001 From: Honglei-Qiu <1044497581@qq.com> Date: Wed, 11 Mar 2026 02:20:50 +0000 Subject: [PATCH 2/3] fix dtype --- .../torch/sample_pass/dtype_generalizer.py | 43 ++++++++----------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/graph_net/torch/sample_pass/dtype_generalizer.py b/graph_net/torch/sample_pass/dtype_generalizer.py index 0fb0ce5c5..9e17788b6 100755 --- a/graph_net/torch/sample_pass/dtype_generalizer.py +++ b/graph_net/torch/sample_pass/dtype_generalizer.py @@ -387,8 +387,7 @@ def _apply_pass_and_generate( traced_model: Original traced model pass_name: Name of the pass file (without .py extension), e.g., "dtype_generalization_pass_float16" - inputs: Original model inputs, used for ShapeProp to remove - redundant .to() calls + inputs: Original model inputs (unused, kept for compatibility) Returns: Path to the generated sample directory @@ -410,23 +409,6 @@ def _apply_pass_and_generate( gm_copy = copy.deepcopy(traced_model) gm_modified = dtype_pass.rewrite(gm_copy) - # Remove redundant .to() calls by re-running ShapeProp with - # low-precision inputs to get real runtime dtypes, then pruning - # .to() nodes whose input is already target dtype. - if inputs is not None: - try: - torch_dtype = getattr(torch, dtype) - low_prec_inputs = [ - x.to(torch_dtype) - if isinstance(x, torch.Tensor) and x.is_floating_point() - else x - for x in inputs - ] - ShapeProp(gm_modified).propagate(*low_prec_inputs) - gm_modified = dtype_pass.remove_redundant_to_calls(gm_modified) - except Exception as e: - logging.warning(f"Failed to remove redundant .to() calls: {e}") - # Get the list of tensor names that need dtype conversion converted_tensor_names = set(getattr(dtype_pass, "converted_tensor_names", [])) @@ -436,17 +418,28 @@ def _apply_pass_and_generate( # Copy metadata files of original sample self._copy_sample(rel_model_path, output_dir) - # Update model.py - model_code = serialize_graph_module_to_str(gm_modified) - templated_model_code = utils.apply_templates(model_code) - (output_dir / "model.py").write_text(templated_model_code) - - # Update weight_meta.py and input_meta.py dtypes + # Update weight_meta.py and input_meta.py dtypes FIRST, + # so we can use the updated meta to generate inputs for ShapeProp target_dtype_str = f"torch.{dtype}" self._update_tensor_meta_dtypes( output_dir, converted_tensor_names, target_dtype_str ) + # Remove redundant .to() calls: + # Load inputs from the updated meta files (dtype matches meta exactly), + # run ShapeProp to get real runtime dtypes, then prune redundant .to() nodes. + try: + _, meta_inputs = get_torch_module_and_inputs(str(output_dir)) + ShapeProp(gm_modified).propagate(*meta_inputs) + gm_modified = dtype_pass.remove_redundant_to_calls(gm_modified) + except Exception as e: + logging.warning(f"Failed to remove redundant .to() calls: {e}") + + # Update model.py (after redundant .to() removal) + model_code = serialize_graph_module_to_str(gm_modified) + templated_model_code = utils.apply_templates(model_code) + (output_dir / "model.py").write_text(templated_model_code) + # Update graph_hash.txt model_hash = get_sha256_hash(model_code) (output_dir / "graph_hash.txt").write_text(model_hash) From b8db1784f68d2a71ea4480f213b3eb84eb9a51c9 Mon Sep 17 00:00:00 2001 From: Honglei-Qiu <1044497581@qq.com> Date: Wed, 11 Mar 2026 06:39:45 +0000 Subject: [PATCH 3/3] fix dtype --- graph_net/torch/sample_pass/dtype_generalizer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/graph_net/torch/sample_pass/dtype_generalizer.py b/graph_net/torch/sample_pass/dtype_generalizer.py index 9e17788b6..7da61ad70 100755 --- a/graph_net/torch/sample_pass/dtype_generalizer.py +++ b/graph_net/torch/sample_pass/dtype_generalizer.py @@ -342,7 +342,7 @@ def resume(self, rel_model_path: str) -> List[str]: for pass_name in dtype_pass_names: try: sample_dir = self._apply_pass_and_generate( - rel_model_path, traced_model, pass_name, inputs + rel_model_path, traced_model, pass_name ) generated_samples.append(sample_dir) logging.info(f"Generated sample: {sample_dir}") @@ -377,7 +377,6 @@ def _apply_pass_and_generate( rel_model_path: str, traced_model: fx.GraphModule, pass_name: str, - inputs: list = None, ) -> str: """ Apply a specific pass and generate a new sample. @@ -387,7 +386,6 @@ def _apply_pass_and_generate( traced_model: Original traced model pass_name: Name of the pass file (without .py extension), e.g., "dtype_generalization_pass_float16" - inputs: Original model inputs (unused, kept for compatibility) Returns: Path to the generated sample directory @@ -429,6 +427,7 @@ def _apply_pass_and_generate( # Load inputs from the updated meta files (dtype matches meta exactly), # run ShapeProp to get real runtime dtypes, then prune redundant .to() nodes. try: + torch.cuda.empty_cache() _, meta_inputs = get_torch_module_and_inputs(str(output_dir)) ShapeProp(gm_modified).propagate(*meta_inputs) gm_modified = dtype_pass.remove_redundant_to_calls(gm_modified)