Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
from .fuse_view_copy_transform_pass import FuseViewCopyTransformPass # noqa
from .insert_const_shapes import InsertConstShapesPass # noqa
from .insert_dynamic_padding import InsertDynamicPaddingPass # noqa
from .insert_int32_casts_after_int64_placeholders import ( # noqa
InsertInt32CastsAfterInt64PlaceholdersPass,
)
Expand Down
88 changes: 88 additions & 0 deletions backends/arm/_passes/insert_dynamic_padding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm.tosa.dialect.shape import is_shape_op_node

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, ProxyValue


class InsertDynamicPaddingPass(ArmPass):
    """Rewrite conv operations with dynamic padding to use an explicit pad op.

    Convolutions whose padding is dynamic (symbolic) are rewritten so the
    padding is applied by an explicit PAD operator in front of the conv, and
    the conv itself runs with zero padding::

        conv2d(x, weight, bias, stride, padding, dilation)

    becomes::

        x_padded = pad(x, explicit_padding)
        conv2d(x_padded, weight, bias, stride, (0, 0, 0, 0), dilation)

    where ``explicit_padding`` is calculated based on the original padding
    value.

    To be used with dynamic shapes only.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def _is_dynamic_padding(
        self, padding: ProxyValue | list[int] | tuple[int, ...]
    ) -> bool:
        """Return True if ``padding`` is computed by a shape op or contains SymInts."""
        return (
            isinstance(padding, ProxyValue) and is_shape_op_node(padding.node)
        ) or (
            isinstance(padding, (list, tuple))
            and any(isinstance(p, torch.SymInt) for p in padding)
        )

    def call_operator(self, op, args, kwargs, meta, updated=False) -> ProxyValue:
        """Insert an explicit PAD before conv ops that carry dynamic padding."""
        if op not in (
            exir_ops.backend.tosa.CONV2D.default,
            exir_ops.backend.tosa.DEPTHWISE_CONV2D.default,
        ):
            return super().call_operator(op, args, kwargs, meta, updated)

        padding = args[4]
        if not self._is_dynamic_padding(padding):
            return super().call_operator(op, args, kwargs, meta, updated)

        # Create a pad op before the conv.
        input_tensor = args[0]

        # The N and C dimensions are never padded by a conv, so their explicit
        # padding is a constant zero shape.
        nc_padding = super().call_shape_operator(
            exir_ops.backend.tosa.CONST_SHAPE.default,
            ([0, 0, 0, 0],),
            {},
            meta,
            True,
        )

        # Full padding = constant NC zeros followed by the dynamic HW padding.
        padding_shape = super().call_shape_operator(
            exir_ops.backend.tosa.CONCAT_SHAPE.default,
            ([nc_padding, padding],),
            {},
            meta,
            True,
        )

        pad_res = super().call_operator(
            exir_ops.backend.tosa.PAD.default,
            (
                input_tensor,
                padding_shape,
            ),
            {
                "value": 0,
            },
            meta,
            True,
        )

        # The padding is now explicit, so the conv itself becomes unpadded.
        # Use a fresh list literal rather than reusing the CONST_SHAPE operand
        # list, so the two nodes never alias the same mutable object.
        new_conv2d_args = list(args)
        new_conv2d_args[0] = pad_res
        new_conv2d_args[4] = [0, 0, 0, 0]
        return super().call_operator(op, tuple(new_conv2d_args), kwargs, meta, updated)
69 changes: 69 additions & 0 deletions backends/arm/test/passes/test_insert_dynamic_padding_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes.insert_dynamic_padding import (
InsertDynamicPaddingPass,
)
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
from executorch.backends.arm.tosa.specification import (
TosaLoweringContext,
TosaSpecification,
)
from executorch.exir import to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import Dim, export


class ConvModule(torch.nn.Module):
    """Minimal single-convolution fixture for exercising the padding pass."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=2,
            stride=3,
            padding=2,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the convolution over ``x``."""
        return self.conv(x)


def test_insert_dynamic_padding_no_target():
    """InsertDynamicPaddingPass moves dynamic conv padding into an explicit PAD op.

    Exports a conv with dynamic H/W, rewrites it to the TOSA conv op (whose
    padding then contains SymInts), runs the pass, and checks that the conv's
    padding is zeroed while an 8-element PAD operand carries the original
    HW padding after four leading NC zeros.
    """
    model = ConvModule()
    example_inputs = (torch.randn(1, 3, 8, 8),)
    ep = export(
        model,
        example_inputs,
        dynamic_shapes={
            "x": {2: Dim("height", min=4, max=10), 3: Dim("width", min=4, max=10)}
        },
    )
    edge_model = to_edge(ep)
    with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")):
        edge_model = edge_model.transform(
            [RewriteConvPass(edge_model.exported_program())]
        )
        nodes = edge_model.exported_program().graph.nodes
        # next(..., None) so a missing node fails the assert below instead of
        # raising StopIteration (which pytest would report as an error).
        conv_node = next(
            (n for n in nodes if n.target == exir_ops.backend.tosa.CONV2D.default),
            None,
        )
        assert conv_node is not None
        initial_padding = conv_node.args[4]
        # Dynamic input height/width must make the conv padding symbolic.
        assert any(isinstance(p, torch.SymInt) for p in initial_padding)

        edge_model = edge_model.transform(
            [
                InsertDynamicPaddingPass(),
            ]
        )
        nodes = edge_model.exported_program().graph.nodes
        conv_node = next(
            (n for n in nodes if n.target == exir_ops.backend.tosa.CONV2D.default),
            None,
        )
        assert conv_node is not None
        # The conv itself must now be unpadded ...
        padding = conv_node.args[4]
        assert padding == [0, 0, 0, 0]
        # ... with the padding applied by an explicit PAD op instead.
        pad_node = next(
            (n for n in nodes if n.target == exir_ops.backend.tosa.PAD.default),
            None,
        )
        assert pad_node is not None
        pad_list = pad_node.args[1].meta["val"]
        assert len(pad_list) == 8
        assert pad_list[:4] == [0, 0, 0, 0]  # NC-padding
        assert pad_list[4:] == initial_padding  # HW-padding
4 changes: 2 additions & 2 deletions backends/arm/tosa/dialect/ops/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def validate_conv2d_args_dtypes(
"Tensor weight, "
"Tensor bias, "
"int[2] stride, "
"int[4] pad, "
"SymInt[4] pad, "
"int[2] dilation) -> Tensor", # schema
TosaSpecification.all_versions_and_profiles(), # target TOSA specifications
)
Expand All @@ -91,7 +91,7 @@ def CONV2D(
weight: torch.Tensor,
bias: torch.Tensor,
stride: list[int],
pad: list[int],
pad: list[int | torch.SymInt],
dilation: list[int],
) -> torch.Tensor:
tosa_spec = get_context_spec()
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/tosa/dialect/ops/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"Tensor weight, "
"Tensor bias, "
"int[2] stride, "
"int[4] pad, "
"SymInt[4] pad, "
"int[2] dialation) -> Tensor", # schema
TosaSpecification.all_versions_and_profiles(),
)
Expand All @@ -29,7 +29,7 @@ def DEPTHWISE_CONV2D(
weight: torch.Tensor,
bias: torch.Tensor,
stride: list[int],
pad: list[int],
pad: list[int | torch.SymInt],
dilation: list[int],
) -> torch.Tensor:
tosa_spec = get_context_spec()
Expand Down
Loading