From 5809e6798d5dbe973d71b18ddc7f4875b40f1b15 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Tue, 2 Dec 2025 15:35:52 +0100 Subject: [PATCH] Arm backend: Add pass to handle dynamic conv+pad MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds pass that replaces dynamic conv-padding with pad+conv sequence. Co-authored-by: Per Åstrand Signed-off-by: Oscar Andersson Change-Id: I80d014c1f0cb445ccdb021151b350aeebe481920 --- backends/arm/_passes/__init__.py | 1 + .../arm/_passes/insert_dynamic_padding.py | 88 +++++++++++++++++++ .../test_insert_dynamic_padding_pass.py | 69 +++++++++++++++ backends/arm/tosa/dialect/ops/conv2d.py | 4 +- .../arm/tosa/dialect/ops/depthwise_conv2d.py | 4 +- 5 files changed, 162 insertions(+), 4 deletions(-) create mode 100644 backends/arm/_passes/insert_dynamic_padding.py create mode 100644 backends/arm/test/passes/test_insert_dynamic_padding_pass.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 6650efd0883..9bbd24e02b1 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -111,6 +111,7 @@ from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa from .fuse_view_copy_transform_pass import FuseViewCopyTransformPass # noqa from .insert_const_shapes import InsertConstShapesPass # noqa +from .insert_dynamic_padding import InsertDynamicPaddingPass # noqa from .insert_int32_casts_after_int64_placeholders import ( # noqa InsertInt32CastsAfterInt64PlaceholdersPass, ) diff --git a/backends/arm/_passes/insert_dynamic_padding.py b/backends/arm/_passes/insert_dynamic_padding.py new file mode 100644 index 00000000000..ea03e231ae8 --- /dev/null +++ b/backends/arm/_passes/insert_dynamic_padding.py @@ -0,0 +1,88 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Set, Type + +import torch + +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm.tosa.dialect.shape import is_shape_op_node + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, ProxyValue + + +class InsertDynamicPaddingPass(ArmPass): + """This pass rewrites conv operations with padding to use an explicit pad + operator before the conv2d operation and set the padding to zero in the + conv2d operator. E.g. conv2d(x, weight, bias, stride, padding, dilation) + becomes: x_padded = pad(x, explicit_padding) conv2d(x_padded, + weight, bias, stride, (0,0,0,0), dilation) where explicit_padding is + calculated based on the original padding value. + + To be used with dynamic shapes only. + + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def _is_dynamic_padding( + self, padding: ProxyValue | list[int] | tuple[int, ...] + ) -> bool: + return (isinstance(padding, ProxyValue) and is_shape_op_node(padding.node)) or ( + ( + isinstance(padding, (list, tuple)) + and any(isinstance(p, torch.SymInt) for p in padding) + ) + ) + + def call_operator(self, op, args, kwargs, meta, updated=False) -> ProxyValue: + if op not in ( + exir_ops.backend.tosa.CONV2D.default, + exir_ops.backend.tosa.DEPTHWISE_CONV2D.default, + ): + return super().call_operator(op, args, kwargs, meta, updated) + padding = args[4] + if not self._is_dynamic_padding(padding): + return super().call_operator(op, args, kwargs, meta, updated) + + # Create a pad op before conv2d + input_tensor = args[0] + + zero_padding = [0, 0, 0, 0] + NC_padding = super().call_shape_operator( + exir_ops.backend.tosa.CONST_SHAPE.default, + (zero_padding,), + {}, + meta, + True, + ) + + padding_shape_args = [NC_padding, padding] + + padding_shape = super().call_shape_operator( + exir_ops.backend.tosa.CONCAT_SHAPE.default, + (padding_shape_args,), + {}, + meta, + True, + ) + + pad_res = super().call_operator( + 
exir_ops.backend.tosa.PAD.default, + ( + input_tensor, + padding_shape, + ), + { + "value": 0, + }, + meta, + True, + ) + new_conv2d_args = list(args) + new_conv2d_args[0] = pad_res + new_conv2d_args[4] = zero_padding + return super().call_operator(op, tuple(new_conv2d_args), kwargs, meta, updated) diff --git a/backends/arm/test/passes/test_insert_dynamic_padding_pass.py b/backends/arm/test/passes/test_insert_dynamic_padding_pass.py new file mode 100644 index 00000000000..01d2f7fd669 --- /dev/null +++ b/backends/arm/test/passes/test_insert_dynamic_padding_pass.py @@ -0,0 +1,69 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm._passes.insert_dynamic_padding import ( + InsertDynamicPaddingPass, +) +from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir import to_edge +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import Dim, export + + +class ConvModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, kernel_size=2, stride=3, padding=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +def test_insert_dynamic_padding_no_target(): + model = ConvModule() + example_inputs = (torch.randn(1, 3, 8, 8),) + ep = export( + model, + example_inputs, + dynamic_shapes={ + "x": {2: Dim("height", min=4, max=10), 3: Dim("width", min=4, max=10)} + }, + ) + edge_model = to_edge(ep) + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")): + edge_model = edge_model.transform( + [RewriteConvPass(edge_model.exported_program())] + ) + nodes = edge_model.exported_program().graph.nodes + conv_node = next( + n for n in 
nodes if n.target == exir_ops.backend.tosa.CONV2D.default + ) + initial_padding = conv_node.args[4] + assert any(isinstance(p, torch.SymInt) for p in initial_padding) + + edge_model = edge_model.transform( + [ + InsertDynamicPaddingPass(), + ] + ) + nodes = edge_model.exported_program().graph.nodes + conv_node = next( + n for n in nodes if n.target == exir_ops.backend.tosa.CONV2D.default + ) + padding = conv_node.args[4] + assert padding == [0, 0, 0, 0] + padding_node = next( + n for n in nodes if n.target == exir_ops.backend.tosa.PAD.default + ) + assert padding_node is not None + pad_list = padding_node.args[1].meta["val"] + assert len(pad_list) == 8 + assert pad_list[:4] == [0, 0, 0, 0] # NC-padding + assert pad_list[4:] == initial_padding # HW-padding diff --git a/backends/arm/tosa/dialect/ops/conv2d.py b/backends/arm/tosa/dialect/ops/conv2d.py index 4d8edda9db8..2b991600994 100644 --- a/backends/arm/tosa/dialect/ops/conv2d.py +++ b/backends/arm/tosa/dialect/ops/conv2d.py @@ -82,7 +82,7 @@ def validate_conv2d_args_dtypes( "Tensor weight, " "Tensor bias, " "int[2] stride, " - "int[4] pad, " + "SymInt[4] pad, " "int[2] dilation) -> Tensor", # schema TosaSpecification.all_versions_and_profiles(), # target TOSA specifications ) @@ -91,7 +91,7 @@ def CONV2D( weight: torch.Tensor, bias: torch.Tensor, stride: list[int], - pad: list[int], + pad: list[int | torch.SymInt], dilation: list[int], ) -> torch.Tensor: tosa_spec = get_context_spec() diff --git a/backends/arm/tosa/dialect/ops/depthwise_conv2d.py b/backends/arm/tosa/dialect/ops/depthwise_conv2d.py index 64e0e16479c..7d8d5f9edc8 100644 --- a/backends/arm/tosa/dialect/ops/depthwise_conv2d.py +++ b/backends/arm/tosa/dialect/ops/depthwise_conv2d.py @@ -20,7 +20,7 @@ "Tensor weight, " "Tensor bias, " "int[2] stride, " - "int[4] pad, " + "SymInt[4] pad, " "int[2] dialation) -> Tensor", # schema TosaSpecification.all_versions_and_profiles(), ) @@ -29,7 +29,7 @@ def DEPTHWISE_CONV2D( weight: torch.Tensor, bias: 
torch.Tensor, stride: list[int], - pad: list[int], + pad: list[int | torch.SymInt], dilation: list[int], ) -> torch.Tensor: tosa_spec = get_context_spec()