Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 157 additions & 2 deletions backends/arm/_passes/to_tosa_memory_format_pass.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -222,6 +222,21 @@

return (N_old != N_new) or (C_old != C_new)

@staticmethod
def _is_nop_transpose(shape, perm) -> bool:
"""Return ``True`` when a transpose only permutes size-1 dimensions.

A transpose is a NOP (no-operation) when the relative order of
all non-size-1 dimensions is unchanged — permuting size-1 dims
does not alter the physical byte layout.

Example: ``[14, 72, 1, 1]`` with perm ``(0, 1, 3, 2)`` → True
(only the two trailing size-1 dims swap).
"""
old_order = [i for i, s in enumerate(shape) if s != 1]
new_order = [i for i, s in zip(perm, [shape[p] for p in perm]) if s != 1]
return old_order == new_order

@staticmethod
def insert_input_transpose(node, input_node, graph_module):
"""Ensure an input tensor is converted to channels-last ordering by
Expand Down Expand Up @@ -271,7 +286,7 @@
# Guard: mem_format must be a true permutation for the current rank
assert sorted(mem_format) == list(
range(rank)
), f"bad perm {mem_format} for rank {rank} in insert_input_transpose"
), f"bad perm {mem_format} for rank {rank} in insert_output_transpose"

with graph_module.graph.inserting_after(node):
permute_node = create_node(
Expand All @@ -296,6 +311,104 @@
for user in users:
user.replace_input_with(node, permute_node)

@staticmethod
def _get_shape_indices(
src_shape: list[int], tgt_shape: list[int]
) -> list[list[int]] | None:
"""Greedy dimension matching for reshape operations.

For each target dimension, greedily consumes contiguous source
dimensions whose product equals the target size. Size-1 target
dimensions that do not correspond to any source dimension produce
empty index lists (inserted dims).

Returns ``None`` when no valid mapping exists.
"""
src_idx = 0
result: list[list[int]] = []

for tgt_dim in tgt_shape:
if tgt_dim <= 0:
return None

indices: list[int] = []
remaining = tgt_dim

while src_idx < len(src_shape) and remaining % src_shape[src_idx] == 0:
indices.append(src_idx)
remaining //= src_shape[src_idx]
src_idx += 1
if remaining == 1:
break

if remaining != 1:
return None

result.append(indices)

if src_idx != len(src_shape):
return None

return result

@staticmethod
def _is_monotonic(indices: list[list[int]]) -> bool:
"""Return ``True`` when all non-empty index groups are strictly
ordered — i.e. each group's indices follow the previous group's.
"""
last_max = -1
for group in indices:
if not group:
continue
if group[0] <= last_max:
return False
last_max = group[-1]
return True

@staticmethod
def _is_nhwc_safe_reshape(
    input_shape, output_shape, input_sr, output_sr  # noqa: ARG004
) -> bool:
    """Detect whether a 4-D+ reshape can operate directly on NHWC data.

    By the time ``ToTosaMemoryFormatPass`` runs, 4-D tensor shapes in
    ``meta["val"]`` are already in NHWC physical order, so the channel
    dimension is the last axis rather than NCHW position 1. The shape
    indices are therefore computed on the raw input/output shapes —
    no extra permutation is applied first.

    Returns ``True`` when:
    1. The reshape has monotonic shape_indices (each output dim maps
       to a contiguous, in-order group of input dims), AND
    2. The channel dimension is preserved alone (not merged with
       spatial dims).

    ``input_sr``/``output_sr`` are accepted for signature parity but
    unused (hence the noqa).
    """
    # Only 4-D+ tensors are in NHWC order at this point in the pipeline.
    if len(input_shape) < 4 or len(output_shape) < 4:
        return False

    groups = ToTosaMemoryFormatPass._get_shape_indices(
        list(input_shape), list(output_shape)
    )
    if groups is None or not ToTosaMemoryFormatPass._is_monotonic(groups):
        return False

    # NHWC keeps the channel on the last axis; it must land alone in
    # its output group, otherwise the reshape would interleave channel
    # data with spatial data and the optimisation would be invalid.
    channel_idx = len(input_shape) - 1
    for group in groups:
        if channel_idx in group:
            return len(group) == 1

    # Channel dim never consumed by any group — conservative reject.
    return False

@staticmethod
def _insert_view_transpose(
input_shape, output_shape, node, input_node, graph_module
Expand All @@ -317,6 +430,14 @@
output_sr,
)

# When the NHWC-space reshape has monotonic shape_indices the
# view_copy can operate directly on NHWC data — no transposes
# are needed.
if channel_reshape and ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
input_shape, output_shape, input_sr, output_sr
):
return

if (
channel_reshape or nhwc_to_nchw
) and ToTosaMemoryFormatPass.memory_format_differs(input_shape, input_sr):
Expand All @@ -329,7 +450,7 @@
) and ToTosaMemoryFormatPass.memory_format_differs(output_shape, output_sr):
ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)

def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):

Check warning on line 453 in backends/arm/_passes/to_tosa_memory_format_pass.py

View workflow job for this annotation

GitHub Actions / lintrunner

FLAKE8 C901

'ToTosaMemoryFormatPass.insert_tosa_transposes' is too complex (13) See https://www.flake8rules.com/rules/C901.html.
"""Transposes are needed for operators transforming the input to a
different rank, as 4D and 5D-tensors are assumed to be in (N)NHWC-
format, whereas all other are in (N)NCHW format.
Expand All @@ -345,10 +466,44 @@
- 1D/2D tensors

"""
for node in graph_module.graph.nodes:
for node in list(graph_module.graph.nodes):
if node.op != "call_function":
continue

# Eliminate model-level permute_copy ops that are redundant
# with the tosa_dim_order annotation. When a permute_copy's
# permutation matches the channels-last order (or its
# inverse), the permute does the same NCHW↔NHWC conversion
# that tosa_dim_order already handles — keeping both would
# double-convert. Replace with view_copy (identity reshape).
if node.target in (
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.permute.default,
):
perm = list(node.args[1])
rank = len(perm)
sr = node.meta.get("tosa_spatial_rank", 0)

if rank >= 3 and sr >= 1:
cl_order = list(
self._channels_last_order(rank, sr)
)
cl_inv = list(
self._channels_last_inverse_order(rank, sr)
)
if perm == cl_order or perm == cl_inv:
input_node = node.args[0]
output_shape = list(node.meta["val"].shape)
with graph_module.graph.inserting_before(node):
view_node = graph_module.graph.call_function(
exir_ops.edge.aten.view_copy.default,
(input_node, output_shape),
)
view_node.meta = dict(node.meta)
node.replace_all_uses_with(view_node)
graph_module.graph.erase_node(node)
continue

# Transpose views
elif node.target == exir_ops.edge.aten.view_copy.default:
input_node = node.args[0]
Expand Down
154 changes: 154 additions & 0 deletions backends/arm/test/passes/test_to_tosa_memory_format.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -177,11 +177,76 @@
return (torch.rand(4, 4, 4, 4),)


class NHWCSafeSpatialMerge(torch.nn.Module):
    """Test-module with an NHWC-safe 4D->4D reshape.

    The view merges the spatial dims (H*W = 2*14 -> 28) while keeping
    the trailing dim (72) — the NHWC channel — intact, so
    ``_is_nhwc_safe_reshape`` lets the view_copy run directly on NHWC
    data and no extra transposes are inserted around it.

    Graph: conv2d (forces NHWC, C=2) -> view_copy -> add (stays 4-D,
    so the result remains in NHWC).
    """

    ops_before_pass: Dict[str, int] = {}
    # Only the conv's two I/O transposes; none added for the view_copy.
    ops_after_pass: Dict[str, int] = {
        "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2
    }
    ops_not_after_pass: List[str] = []

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(2, 2, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)  # forces the NHWC path; output [1, 2, 14, 72]
        y = y.view(1, 28, 1, 72)  # spatial merge 2*14 -> 28, last dim 72 kept
        return y + y  # keep the result 4-D so it stays NHWC

    def get_inputs(self) -> input_t:
        return (torch.randn(1, 2, 14, 72),)


class NHWCUnsafeChannelChange(torch.nn.Module):
    """Test-module whose 4D->4D reshape is NOT NHWC-safe.

    The target shape cannot be produced by a monotonic merge of the
    NHWC input dims, so the pass MUST still insert transposes around
    the view_copy.
    """

    ops_before_pass: Dict[str, int] = {}
    # Two transposes for the conv I/O plus two guarding the view_copy.
    ops_after_pass: Dict[str, int] = {
        "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4
    }
    ops_not_after_pass: List[str] = []

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(72, 72, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)  # output [1, 72, 2, 14]
        y = y.view(1, 14, 2, 72)  # not NHWC-safe (channel data shuffled)
        return y + y

    def get_inputs(self) -> input_t:
        return (torch.randn(1, 72, 2, 14),)


# Fixture modules under test, keyed by a descriptive name — presumably
# used as parametrization ids by the pipeline tests; verify against the
# parametrize decorator.
modules: Dict[str, ModuleMetadata] = {
    "no_nhwc": NoNHWC(),
    "parallel_clusters": ParallelClusters(),
    "serial_clusters": SerialClusters(),
    "reshapes": Reshapes(),
    "nhwc_safe_spatial_merge": NHWCSafeSpatialMerge(),
    "nhwc_unsafe_channel_change": NHWCUnsafeChannelChange(),
}


Expand Down Expand Up @@ -209,3 +274,92 @@
module_nn = cast(torch.nn.Module, module)
pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), [])
pipeline.run()


# --- Direct unit tests for NHWC-safe reshape helpers ---


def test_get_shape_indices_spatial_merge():
    """[1,2,14,72] -> [1,28,1,72]: merge H*W, insert size-1 dim, preserve C."""
    result = ToTosaMemoryFormatPass._get_shape_indices(
        [1, 2, 14, 72], [1, 28, 1, 72]
    )
    expected = [[0], [1, 2], [], [3]]
    assert result == expected


def test_get_shape_indices_identity():
    """Same shape => each dim maps to itself."""
    shape = [2, 3, 4]
    assert ToTosaMemoryFormatPass._get_shape_indices(shape, shape) == [[0], [1], [2]]


def test_get_shape_indices_full_merge():
    """[2, 3, 4] -> [24]: merge all dims into one."""
    assert ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [24]) == [[0, 1, 2]]


def test_get_shape_indices_incompatible():
    """Sizes that don't divide => None."""
    result = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 5], [6, 4])
    assert result is None


def test_get_shape_indices_size_one_insert():
    """[6, 4] -> [6, 1, 4]: inserted size-1 dim in the middle."""
    result = ToTosaMemoryFormatPass._get_shape_indices([6, 4], [6, 1, 4])
    assert result == [[0], [], [1]]


def test_is_monotonic_true():
    """Groups whose non-empty runs are strictly ordered are monotonic."""
    for groups in (
        [[0], [1, 2], [], [3]],
        [[0], [], [1], [2, 3]],
        [[], [0, 1, 2]],
    ):
        assert ToTosaMemoryFormatPass._is_monotonic(groups)


def test_is_monotonic_false():
    """Out-of-order or overlapping groups are not monotonic."""
    for groups in ([[1], [0]], [[0, 2], [1]]):
        assert not ToTosaMemoryFormatPass._is_monotonic(groups)


def test_is_nhwc_safe_forward():
    """Shapes already NHWC by the time the pass runs.
    [1,2,14,72] -> [1,28,1,72], sr=2 -> NHWC-safe (spatial merge, C=72 preserved).
    """
    safe = ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
        [1, 2, 14, 72], [1, 28, 1, 72], input_sr=2, output_sr=2
    )
    assert safe


def test_is_nhwc_safe_non_4d():
    """Reshapes below rank 4 are never NHWC-safe."""
    safe = ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
        [6, 4], [24], input_sr=0, output_sr=0
    )
    assert not safe


def test_is_nop_transpose_size1_swap():
    """[14, 72, 1, 1] with perm (0, 1, 3, 2) only swaps trailing size-1 dims."""
    result = ToTosaMemoryFormatPass._is_nop_transpose([14, 72, 1, 1], (0, 1, 3, 2))
    assert result


def test_is_nop_transpose_real_reorder():
    """[14, 72, 1, 1] with perm (1, 0, 2, 3) swaps non-size-1 dims."""
    result = ToTosaMemoryFormatPass._is_nop_transpose([14, 72, 1, 1], (1, 0, 2, 3))
    assert not result


def test_is_nop_transpose_all_size1():
    """[1, 1, 1, 1] with any perm is always a NOP."""
    result = ToTosaMemoryFormatPass._is_nop_transpose([1, 1, 1, 1], (3, 2, 1, 0))
    assert result


def test_is_nop_transpose_identity():
    """Identity permutation is always a NOP."""
    result = ToTosaMemoryFormatPass._is_nop_transpose([2, 3, 4], (0, 1, 2))
    assert result


def test_is_nop_transpose_nhwc_on_size1_spatial():
    """[1, 28, 1, 72] with channels_last (0,2,3,1): non-size-1 dims 28,72
    change relative order (28→pos3, 72→pos2) → NOT a NOP."""
    result = ToTosaMemoryFormatPass._is_nop_transpose([1, 28, 1, 72], (0, 2, 3, 1))
    assert not result
Loading