diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 52cbf924fc4..66cb1bb1636 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -222,6 +222,21 @@ def get_batch_prod_dim(shape, spatial_rank): return (N_old != N_new) or (C_old != C_new) + @staticmethod + def _is_nop_transpose(shape, perm) -> bool: + """Return ``True`` when a transpose only permutes size-1 dimensions. + + A transpose is a NOP (no-operation) when the relative order of + all non-size-1 dimensions is unchanged — permuting size-1 dims + does not alter the physical byte layout. + + Example: ``[14, 72, 1, 1]`` with perm ``(0, 1, 3, 2)`` → True + (only the two trailing size-1 dims swap). + """ + old_order = [i for i, s in enumerate(shape) if s != 1] + new_order = [i for i, s in zip(perm, [shape[p] for p in perm]) if s != 1] + return old_order == new_order + @staticmethod def insert_input_transpose(node, input_node, graph_module): """Ensure an input tensor is converted to channels-last ordering by @@ -271,7 +286,7 @@ def insert_output_transpose(node, graph_module): # Guard: mem_format must be a true permutation for the current rank assert sorted(mem_format) == list( range(rank) - ), f"bad perm {mem_format} for rank {rank} in insert_input_transpose" + ), f"bad perm {mem_format} for rank {rank} in insert_output_transpose" with graph_module.graph.inserting_after(node): permute_node = create_node( @@ -296,6 +311,104 @@ def insert_output_transpose(node, graph_module): for user in users: user.replace_input_with(node, permute_node) + @staticmethod + def _get_shape_indices( + src_shape: list[int], tgt_shape: list[int] + ) -> list[list[int]] | None: + """Greedy dimension matching for reshape operations. + + For each target dimension, greedily consumes contiguous source + dimensions whose product equals the target size. 
Size-1 target + dimensions that do not correspond to any source dimension produce + empty index lists (inserted dims). + + Returns ``None`` when no valid mapping exists. + """ + src_idx = 0 + result: list[list[int]] = [] + + for tgt_dim in tgt_shape: + if tgt_dim <= 0: + return None + + indices: list[int] = [] + remaining = tgt_dim + + while src_idx < len(src_shape) and remaining % src_shape[src_idx] == 0: + indices.append(src_idx) + remaining //= src_shape[src_idx] + src_idx += 1 + if remaining == 1: + break + + if remaining != 1: + return None + + result.append(indices) + + if src_idx != len(src_shape): + return None + + return result + + @staticmethod + def _is_monotonic(indices: list[list[int]]) -> bool: + """Return ``True`` when all non-empty index groups are strictly + ordered — i.e. each group's indices follow the previous group's. + """ + last_max = -1 + for group in indices: + if not group: + continue + if group[0] <= last_max: + return False + last_max = group[-1] + return True + + @staticmethod + def _is_nhwc_safe_reshape( + input_shape, output_shape, input_sr, output_sr # noqa: ARG004 + ) -> bool: + """Detect whether a 4-D+ reshape can operate directly on NHWC data. + + By the time ``ToTosaMemoryFormatPass`` runs, 4-D tensor shapes in + ``meta["val"]`` are already in NHWC physical order (the channel + dimension sits at the last position, ``rank - 1``, not at + position 1 as in NCHW). We therefore check the shape indices on + the **raw** input/output shapes — no extra permutation is needed. + + Returns ``True`` when: + 1. The reshape has monotonic shape_indices (each output dim maps + to a contiguous, in-order group of input dims), AND + 2. The channel dimension is preserved alone (not merged with + spatial dims). 
+ """ + rank_in = len(input_shape) + rank_out = len(output_shape) + if rank_in < 4 or rank_out < 4: + return False + + indices = ToTosaMemoryFormatPass._get_shape_indices( + list(input_shape), list(output_shape) + ) + if indices is None: + return False + + if not ToTosaMemoryFormatPass._is_monotonic(indices): + return False + + # In the TOSA pipeline the physical memory order is NHWC. + # The channel dimension in NHWC is always the **last** axis + # (position ``rank - 1``). It must appear *alone* in its + # output group — if it is merged with spatial dims the reshape + # would reorder channel data and the optimisation is invalid. + channel_idx = rank_in - 1 + for group in indices: + if channel_idx in group: + return len(group) == 1 + # Channel dim not consumed by any group — conservative reject. + return False + @staticmethod def _insert_view_transpose( input_shape, output_shape, node, input_node, graph_module @@ -317,6 +430,14 @@ def _insert_view_transpose( output_sr, ) + # When the NHWC-space reshape has monotonic shape_indices the + # view_copy can operate directly on NHWC data — no transposes + # are needed. + if channel_reshape and ToTosaMemoryFormatPass._is_nhwc_safe_reshape( + input_shape, output_shape, input_sr, output_sr + ): + return + if ( channel_reshape or nhwc_to_nchw ) and ToTosaMemoryFormatPass.memory_format_differs(input_shape, input_sr): @@ -345,10 +466,44 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): - 1D/2D tensors """ - for node in graph_module.graph.nodes: + for node in list(graph_module.graph.nodes): if node.op != "call_function": continue + # Eliminate model-level permute_copy ops that are redundant + # with the tosa_dim_order annotation. When a permute_copy's + # permutation matches the channels-last order (or its + # inverse), the permute does the same NCHW↔NHWC conversion + # that tosa_dim_order already handles — keeping both would + # double-convert. Replace with view_copy (identity reshape). 
+ if node.target in ( + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.permute.default, + ): + perm = list(node.args[1]) + rank = len(perm) + sr = node.meta.get("tosa_spatial_rank", 0) + + if rank >= 3 and sr >= 1: + cl_order = list( + self._channels_last_order(rank, sr) + ) + cl_inv = list( + self._channels_last_inverse_order(rank, sr) + ) + if perm == cl_order or perm == cl_inv: + input_node = node.args[0] + output_shape = list(node.meta["val"].shape) + with graph_module.graph.inserting_before(node): + view_node = graph_module.graph.call_function( + exir_ops.edge.aten.view_copy.default, + (input_node, output_shape), + ) + view_node.meta = dict(node.meta) + node.replace_all_uses_with(view_node) + graph_module.graph.erase_node(node) + continue + # Transpose views elif node.target == exir_ops.edge.aten.view_copy.default: input_node = node.args[0] diff --git a/backends/arm/test/passes/test_to_tosa_memory_format.py b/backends/arm/test/passes/test_to_tosa_memory_format.py index dfd57aa7e61..9c524642ec1 100644 --- a/backends/arm/test/passes/test_to_tosa_memory_format.py +++ b/backends/arm/test/passes/test_to_tosa_memory_format.py @@ -177,11 +177,76 @@ def get_inputs(self) -> input_t: return (torch.rand(4, 4, 4, 4),) +class NHWCSafeSpatialMerge(torch.nn.Module): + """Test-module with a 4D->4D reshape that merges spatial dims H*W while + preserving the last-dim channel. + + For models with view_copy shapes [1,2,14,72]->[1,28,1,72] where C=2 + sits at NCHW position 1 and the last dim (72) is the NHWC channel that gets + preserved. ``_is_nhwc_safe_reshape`` detects that shape_indices on the raw + shapes are monotonic with the last dim alone, so no transposes are inserted + around the view_copy. + + Setup: conv2d (forces NHWC, C=2) -> view_copy -> add (keeps in NHWC). 
+ """ + + ops_before_pass: Dict[str, int] = {} + # Only the 2 I/O transposes for the conv, NO extra transposes from view_copy + ops_after_pass: Dict[str, int] = { + "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2 + } + ops_not_after_pass: List[str] = [] + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=2, out_channels=2, kernel_size=1, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) # forces NHWC path; output [1, 2, 14, 72] + x = x.view(1, 28, 1, 72) # spatial merge: H*W=2*14->28, last dim 72 preserved + return x + x # keep result 4-D in NHWC + + def get_inputs(self) -> input_t: + return (torch.randn(1, 2, 14, 72),) + + +class NHWCUnsafeChannelChange(torch.nn.Module): + """Test-module with a 4D->4D reshape that is NOT NHWC-safe because the + target shape cannot be produced by a monotonic merge of NHWC input dims. + The pass MUST still insert transposes around the view_copy. + """ + + ops_before_pass: Dict[str, int] = {} + # conv I/O transposes (2) + view_copy transposes (2) = 4 + ops_after_pass: Dict[str, int] = { + "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4 + } + ops_not_after_pass: List[str] = [] + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=72, out_channels=72, kernel_size=1, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) # output [1, 72, 2, 14] + x = x.view(1, 14, 2, 72) # not NHWC-safe (channels shuffled) + return x + x + + def get_inputs(self) -> input_t: + return (torch.randn(1, 72, 2, 14),) + + modules: Dict[str, ModuleMetadata] = { "no_nhwc": NoNHWC(), "parallel_clusters": ParallelClusters(), "serial_clusters": SerialClusters(), "reshapes": Reshapes(), + "nhwc_safe_spatial_merge": NHWCSafeSpatialMerge(), + "nhwc_unsafe_channel_change": NHWCUnsafeChannelChange(), } @@ -209,3 +274,92 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: 
ModuleMetadata) -> No module_nn = cast(torch.nn.Module, module) pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), []) pipeline.run() + + +# --- Direct unit tests for NHWC-safe reshape helpers --- + + +def test_get_shape_indices_spatial_merge(): + """[1,2,14,72] -> [1,28,1,72]: merge H*W, insert size-1 dim, preserve C.""" + indices = ToTosaMemoryFormatPass._get_shape_indices( + [1, 2, 14, 72], [1, 28, 1, 72] + ) + assert indices == [[0], [1, 2], [], [3]] + + +def test_get_shape_indices_identity(): + """Same shape => each dim maps to itself.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [2, 3, 4]) + assert indices == [[0], [1], [2]] + + +def test_get_shape_indices_full_merge(): + """[2, 3, 4] -> [24]: merge all dims into one.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [24]) + assert indices == [[0, 1, 2]] + + +def test_get_shape_indices_incompatible(): + """Sizes that don't divide => None.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 5], [6, 4]) + assert indices is None + + +def test_get_shape_indices_size_one_insert(): + """[6, 4] -> [6, 1, 4]: inserted size-1 dim in the middle.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([6, 4], [6, 1, 4]) + assert indices is not None + assert indices == [[0], [], [1]] + + +def test_is_monotonic_true(): + assert ToTosaMemoryFormatPass._is_monotonic([[0], [1, 2], [], [3]]) + assert ToTosaMemoryFormatPass._is_monotonic([[0], [], [1], [2, 3]]) + assert ToTosaMemoryFormatPass._is_monotonic([[], [0, 1, 2]]) + + +def test_is_monotonic_false(): + assert not ToTosaMemoryFormatPass._is_monotonic([[1], [0]]) + assert not ToTosaMemoryFormatPass._is_monotonic([[0, 2], [1]]) + + +def test_is_nhwc_safe_forward(): + """Shapes already NHWC by the time the pass runs. + [1,2,14,72] -> [1,28,1,72], sr=2 -> NHWC-safe (spatial merge, C=72 preserved). 
+ """ + assert ToTosaMemoryFormatPass._is_nhwc_safe_reshape( + [1, 2, 14, 72], [1, 28, 1, 72], input_sr=2, output_sr=2 + ) + + +def test_is_nhwc_safe_non_4d(): + """Reshapes below rank 4 are never NHWC-safe.""" + assert not ToTosaMemoryFormatPass._is_nhwc_safe_reshape( + [6, 4], [24], input_sr=0, output_sr=0 + ) + + +def test_is_nop_transpose_size1_swap(): + """[14, 72, 1, 1] with perm (0, 1, 3, 2) only swaps trailing size-1 dims.""" + assert ToTosaMemoryFormatPass._is_nop_transpose([14, 72, 1, 1], (0, 1, 3, 2)) + + +def test_is_nop_transpose_real_reorder(): + """[14, 72, 1, 1] with perm (1, 0, 2, 3) swaps non-size-1 dims.""" + assert not ToTosaMemoryFormatPass._is_nop_transpose([14, 72, 1, 1], (1, 0, 2, 3)) + + +def test_is_nop_transpose_all_size1(): + """[1, 1, 1, 1] with any perm is always a NOP.""" + assert ToTosaMemoryFormatPass._is_nop_transpose([1, 1, 1, 1], (3, 2, 1, 0)) + + +def test_is_nop_transpose_identity(): + """Identity permutation is always a NOP.""" + assert ToTosaMemoryFormatPass._is_nop_transpose([2, 3, 4], (0, 1, 2)) + + +def test_is_nop_transpose_nhwc_on_size1_spatial(): + """[1, 28, 1, 72] with channels_last (0,2,3,1): non-size-1 dims 28,72 + change relative order (28→pos3, 72→pos2) → NOT a NOP.""" + assert not ToTosaMemoryFormatPass._is_nop_transpose([1, 28, 1, 72], (0, 2, 3, 1))