diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 6650efd0883..324c45fcde9 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -123,6 +123,9 @@ from .match_arg_dtype_pass import MatchArgDtypePass # noqa from .match_arg_ranks_pass import MatchArgRanksPass # noqa from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa +from .normalize_index_put_none_indices_pass import ( # noqa + NormalizeIndexPutNoneIndicesPass, +) from .normalize_while_initial_args_pass import NormalizeWhileInitialArgsPass # noqa from .promote_bool_operands_pass import PromoteBoolOperandsPass # noqa from .remove_getitem_pass import RemoveGetItemPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index e4aa821fa7c..2492d1031f0 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -110,6 +110,7 @@ InsertTableOpsPass, MatchArgDtypePass, MatchArgRanksPass, + NormalizeIndexPutNoneIndicesPass, NormalizeWhileInitialArgsPass, PromoteBoolOperandsPass, QuantizeClampArgumentsPass, @@ -337,6 +338,7 @@ def _tosa_pipeline( # Node transformation passes (post scalar-removal) self.add_passes( [ + NormalizeIndexPutNoneIndicesPass(), RewriteIndexPutPass(), RewriteBoolBitwiseToLogicalPass(), DecomposeRemainderPass(), diff --git a/backends/arm/_passes/normalize_index_put_none_indices_pass.py b/backends/arm/_passes/normalize_index_put_none_indices_pass.py new file mode 100644 index 00000000000..7aaace641b0 --- /dev/null +++ b/backends/arm/_passes/normalize_index_put_none_indices_pass.py @@ -0,0 +1,143 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
from typing import Set, Type

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.rewrite_index_put_pass import RewriteIndexPutPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class NormalizeIndexPutNoneIndicesPass(ArmPass):
    """Normalize index_put with None entries in the indices_tensor list by
    moving None-indexed dims to the channel dimensions (*C_j in
    RewriteIndexPutPass terminology) by permuting the destination and data
    tensors. A None index corresponds to selecting the entire dim, which is
    equivalent to it being a channel dimension.

    Example:
        out = index_put(destination, [None, idx1, None, idx2], data)
    becomes
        destination_permuted = permute(destination, destination_dim_order)
        data_front_padded = reshape(data, front_padded_data_shape)
        data_permuted = permute(data_front_padded, data_dim_order)
        out_permuted = index_put(destination_permuted, [idx1, idx2], data_permuted)
        out = permute(out_permuted, inverse_destination_dim_order)

    Where the permutations of destination and data are decided by how the
    indexes move.

    Note that None tensors are handled differently in pytorch depending on how
    many indices tensors there are, causing the data tensor to require
    different shapes, which will require different data permutations.
    Many: all explicit dims are broadcast to a single dim and put in front of
        the data tensor.
        destination shape (5,3,4,3) with indices (None, [1,0], None, [0,2])
        -> data shape (2, 5, 4)
        Note that this is the behaviour we want! No permutation of data is
        necessary.
    One: the explicit dim is kept in place.
        destination shape (5,3,4,3) with indices (None, [1,0], None, None)
        -> data shape (5, 2, 4, 3)
        dim 1 needs to be moved to the front: dim_order = (1,0,2,3).
        This is the same dim order as for the destination tensor.
    """

    # RewriteIndexPutPass consumes the None-free index_put this pass produces.
    _passes_required_after: Set[Type[ExportPass]] = {RewriteIndexPutPass}

    def __init__(self):
        super().__init__()
        self.permute_op = exir_ops.edge.aten.permute_copy.default
        self.reshape_op = exir_ops.edge.aten.view_copy.default

    def _get_data_dim_order(
        self,
        explicit_dims: list[int],
        destination_dim_order: list[int],
    ) -> list[int]:
        """Return the dim_order of the data tensor.

        Args:
            explicit_dims: destination dims that have a non-None index tensor.
            destination_dim_order: permutation applied to the destination.

        Raises:
            RuntimeError: if no index tensor is non-None (nothing to scatter).
        """

        normalized_non_index_dims = destination_dim_order[len(explicit_dims) :]
        data_dim_order = list(range(len(normalized_non_index_dims)))

        if not explicit_dims:
            raise RuntimeError("Expected at least one non-None index tensor.")
        elif len(explicit_dims) > 1:
            # For multiple explicit index tensors, data is already in the order we want.
            return data_dim_order
        else:
            # For a single explicit index tensor, use same dim_order as destination.
            return destination_dim_order

    def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):
        # Only index_put needs normalization; pass everything else through.
        if op not in (exir_ops.edge.aten.index_put.default,):
            return super().call_operator(op, args, kwargs, meta)

        destination, indices_tensor_list, data = args[:3]
        indices_tensor_list = list(indices_tensor_list)
        if not any(indices_tensor is None for indices_tensor in indices_tensor_list):
            # No None indices -> already in the normalized form.
            return super().call_operator(op, args, kwargs, meta)

        destination_shape = destination.data.shape
        explicit_dims = [
            dim_idx
            for dim_idx, index_tensor in enumerate(indices_tensor_list)
            if index_tensor is not None
        ]

        none_dims = [
            dim_idx
            for dim_idx, index_tensor in enumerate(indices_tensor_list)
            if index_tensor is None
        ]
        # Dims beyond the index list are implicitly fully selected as well.
        trailing_dims = list(range(len(indices_tensor_list), len(destination_shape)))

        # Handle None indexing of destination tensor: explicit dims first,
        # then all fully-selected (None + trailing) dims.
        destination_dim_order = explicit_dims + none_dims + trailing_dims
        needs_destination_permute = destination_dim_order != list(
            range(len(destination_shape))
        )
        if needs_destination_permute:
            destination = super().call_operator(
                self.permute_op,
                (destination, destination_dim_order),
                {},
                meta,
                updated=True,
            )

        # Handle None indexing of data tensor.
        data_dim_order = self._get_data_dim_order(
            explicit_dims=explicit_dims,
            destination_dim_order=destination_dim_order,
        )
        needs_data_permute = data_dim_order != list(range(len(data_dim_order)))

        if needs_data_permute:
            data_shape = list(data.data.shape)
            aligned_rank = len(data_dim_order)
            if len(data_shape) < aligned_rank:
                # We add dims to data when we move none dims; front pad data with unit dims to match.
                padded_shape = [1] * (aligned_rank - len(data_shape)) + data_shape
                data = super().call_operator(
                    self.reshape_op, (data, padded_shape), {}, meta, updated=True
                )
            data = super().call_operator(
                self.permute_op, (data, data_dim_order), {}, meta, updated=True
            )

        # Call index_put op with only the non-None index tensors.
        explicit_indices_tensors = [
            indices_tensor_list[dim_idx] for dim_idx in explicit_dims
        ]
        normalized_args = (destination, explicit_indices_tensors, data, *args[3:])
        out = super().call_operator(op, normalized_args, kwargs, meta, updated=True)

        if not needs_destination_permute:
            return out

        # If needed, reverse permutation of destination tensor.
        inv_dim_order = [0] * len(destination_dim_order)
        for new_dim, original_dim in enumerate(destination_dim_order):
            inv_dim_order[original_dim] = new_dim

        return super().call_operator(
            self.permute_op, (out, inv_dim_order), {}, meta, updated=True
        )
import math
from typing import Sequence, Set, Type

import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
    ConvertExpandCopyToRepeatPass,
)
from executorch.backends.arm._passes.fuse_view_copy_transform_pass import (
    FuseViewCopyTransformPass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue


def calculate_data_stride(destination_shape: list[int]) -> list[int]:
    """Calculate strides for a flattened view of the destination tensor that are
    multiplied with the indices to build the [N, W] tensor.
    """
    data_strides: list[int] = []
    stride = 1
    # Row-major strides: innermost dim has stride 1.
    for dim in reversed(destination_shape):
        data_strides.insert(0, stride)
        stride *= dim
    return data_strides


class RewriteIndexPutPass(ArmPass):
    """
    This pass transforms index_put with arguments
    - destination, of shape (*K_i, *C_j)
      where *K_i means some number of dims >1 explicitly indexed (copy some
      entries from data to destination). *C_j means some number of dims >=0
      fully indexed (copy entire dim from data to destination)
    - indices_tensor_list, a list containing len(*K_i) tensors with shape
      (W or 1,), indices for each dim K_i. W is the number of explicit indexes.
      Index tensors are required to not be None.
    - data, of shape (*C_d_j)
      where len(*C_d_j) can be <= len(C_j)+1 and
      *C_d_j can be broadcast into (W, *C_j),
    - accumulate = False

    The lowering strategy is as follows:
    - Flatten *K_i into K = prod(*K_i)
    - Flatten *C_j into C = prod(*C_j)
    - source_flattened = reshape(source, [N=1, K, C])
    - index_flattened = _calculate_flat_indices()
    - data_broadcast = expand(data, [W, *C_j])
    - data_flattened = reshape(data_broadcast, [N=1, W, C])
    - Apply TOSA.SCATTER(source_flattened, index_flattened, data_flattened)
    - Reshape output back to original destination shape
    """

    def __init__(self):
        super().__init__()
        self.reshape_op = exir_ops.edge.aten.view_copy.default
        self.expand_op = exir_ops.edge.aten.expand_copy.default
        self.add_op = exir_ops.edge.aten.add.Tensor
        self.mul_op = exir_ops.edge.aten.mul.Tensor
        self.scatter_op = exir_ops.backend.tosa.SCATTER.default
        self.full_op = exir_ops.edge.aten.full.default

    # The view_copys and expand_copys emitted here are cleaned up by these passes.
    _passes_required_after: Set[Type[ExportPass]] = {
        FuseViewCopyTransformPass,
        ConvertExpandCopyToRepeatPass,
    }

    def _calculate_flat_indices(
        self,
        indices: Sequence[ProxyValue],
        shape: list[int],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """Build the flat [N=1, W] index tensor for TOSA SCATTER.

        The flat index is computed as:
        sum(index[i] * (stride[i])) for each dimension i in shape.
        """
        data_strides = calculate_data_stride(shape)
        new_indices: ProxyValue | None = None
        for i, index_val in enumerate(indices):
            scale_val = int(data_strides[i])
            # Create constant tensor directly: full([1], scale_val); it
            # broadcasts against the index tensor in the mul below.
            mul_const = super().call_operator(
                self.full_op,
                ((1,), scale_val),
                {
                    "dtype": index_val.data.dtype,
                    "device": index_val.data.device,
                },
                meta,
                True,
            )
            # Multiply indices by their stride constant.
            mul_node = super().call_operator(
                self.mul_op, (index_val, mul_const), {}, meta, True
            )

            # Accumulate contributions from each dimension.
            if i == 0:
                new_indices = mul_node
            else:
                new_indices = super().call_operator(
                    self.add_op, (mul_node, new_indices), {}, meta, True
                )

        # Explicit raise instead of assert (asserts are stripped under -O).
        if new_indices is None:
            raise RuntimeError("No index tensors were provided for index_put.")
        W = max(new_indices.data.shape[0], 1)

        return super().call_operator(
            self.reshape_op, (new_indices, (1, W)), {}, meta, True
        )

    def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):
        if op not in (exir_ops.edge.aten.index_put.default,):
            return super().call_operator(op, args, kwargs, meta)

        destination, indices_tensor_list, data = args[:3]
        accumulate = len(args) > 3 and bool(args[3])
        if accumulate:
            raise RuntimeError(
                "Encountered index_put with accumulate=True, this is assumed to be handled by an earlier pass."
            )

        indices_tensor_list = list(indices_tensor_list)
        num_explicit_indices = len(indices_tensor_list)
        if any(index_tensor is None for index_tensor in indices_tensor_list):
            raise RuntimeError(
                "Encountered None indices in RewriteIndexPutPass. "
                "Run NormalizeIndexPutNoneIndicesPass before this pass."
            )

        destination_shape = destination.data.shape

        K_i = destination_shape[:num_explicit_indices]
        C_j = destination_shape[num_explicit_indices:]
        # NOTE(review): broadcast_shapes with a single argument acts as a
        # normalizing cast to torch.Size — confirm this is the intent.
        C_j = torch.broadcast_shapes(C_j)
        K = math.prod(K_i)
        C = math.prod(C_j)
        # A shape is a tensor of rank 1 -> rank of shape is shape[0].
        indices_shapes = [tuple(idx.data.shape) for idx in indices_tensor_list]
        W = torch.broadcast_shapes(*indices_shapes)[0]

        # Calculate flattened destination [N=1, K, C].
        destination_flattened = super().call_operator(
            self.reshape_op, (destination, [1, K, C]), {}, meta, True
        )

        # Calculate flattened index [N=1, W]; index arithmetic must not carry
        # quantization parameters, so strip them from the metadata.
        plain_meta_dict = dict(meta.data)
        plain_meta_dict["input_qparams"] = {}
        plain_meta_dict["output_qparams"] = {}
        plain_meta = NodeMetadata(plain_meta_dict)
        indices_flattened = self._calculate_flat_indices(
            indices_tensor_list, K_i, plain_meta
        )

        # Calculate flattened data [N=1, W, C].
        data_broadcast = super().call_operator(
            self.expand_op,
            (
                data,
                (
                    W,
                    *C_j,
                ),
            ),
            {},
            meta,
            updated=True,
        )
        data_flattened = super().call_operator(
            self.reshape_op, (data_broadcast, (1, W, C)), {}, meta, updated=True
        )

        # Do scatter.
        scatter_node = super().call_operator(
            self.scatter_op,
            (destination_flattened, indices_flattened, data_flattened),
            {},
            meta,
            True,
        )

        # Reshape back to original shape. NOTE(review): the original index_put
        # kwargs are forwarded to this reshape — confirm that is intentional.
        out = super().call_operator(
            self.reshape_op, (scatter_node, destination_shape), kwargs, meta, True
        )

        return out
100644 --- a/backends/arm/test/ops/test_index_put.py +++ b/backends/arm/test/ops/test_index_put.py @@ -3,8 +3,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple - import torch from executorch.backends.arm._passes import InsertInt32CastsAfterInt64PlaceholdersPass @@ -122,6 +120,96 @@ ), 0, ), + "broadcast_values_scalar": ( + lambda: ( + torch.zeros((3, 4), dtype=torch.float32), + (torch.tensor([0, 2], dtype=torch.int64),), + torch.tensor([5.0], dtype=torch.float32), + False, + ), + 0, + ), + "broadcast_values_scalar_0d": ( + lambda: ( + torch.zeros((3, 4), dtype=torch.float32), + (torch.tensor([0, 2], dtype=torch.int64),), + torch.tensor(5.0, dtype=torch.float32), + False, + ), + 0, + ), + "broadcast_values_vector": ( + lambda: ( + torch.zeros((3, 4), dtype=torch.float32), + (torch.tensor([0, 2], dtype=torch.int64),), + torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32), + False, + ), + 0, + ), + "broadcast_values_with_implicit_w_and_c": ( + lambda: ( + torch.zeros((5, 2), dtype=torch.float32), + (torch.tensor([0, 2], dtype=torch.int64),), + torch.tensor([10.0, 20.0], dtype=torch.float32), + False, + ), + 0, + ), + "none_indices": ( + lambda: ( + torch.ones((5, 3, 2, 2), dtype=torch.float32), + (torch.IntTensor([2, 3, 0]), None), + torch.zeros(1), + False, + ), + 0, + ), + "none_indices_2": ( + lambda: ( + torch.ones((5, 3, 2, 2), dtype=torch.float32), + (None, torch.IntTensor([2, 0]), None), + torch.rand(2, 2, 2), + False, + ), + 0, + ), + "none_indices_3": ( + lambda: ( + torch.ones((5, 3, 2, 2), dtype=torch.float32), + (None, torch.IntTensor([2, 1, 0]), None, None), + torch.zeros(1), + False, + ), + 0, + ), + "none_indices_4": ( + lambda: ( + torch.ones((5, 3, 2, 2), dtype=torch.float32), + ( + None, + torch.IntTensor( + [ + 2, + ] + ), + None, + torch.IntTensor([0]), + ), + torch.zeros(1, 5, 2), + False, + ), + 0, + ), + "none_indices_5": ( + 
lambda: ( + torch.ones((5, 3, 2, 2), dtype=torch.float32), + (None, torch.IntTensor([2, 0]), None, torch.IntTensor([0])), + torch.zeros(2, 1, 2), + False, + ), + 0, + ), } test_data_int = { "rank3_zeros_int8": ( @@ -190,10 +278,12 @@ def forward( z: torch.Tensor, acc: bool, ): - return torch.index_put(x, indices=y, values=z, accumulate=acc) + # Needs to use aten op directly to allow None indices. + return torch.ops.aten.index_put.default(x, indices=y, values=z, accumulate=acc) -input_t = Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool], int] +input_t = tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool], int] +input_no_indices_t = tuple[torch.Tensor, torch.Tensor] xfails = { "same_index": "MLETORCH-1596: index_put with repeated indices not supported", @@ -243,7 +333,11 @@ def test_index_put_u55_INT(test_module: input_t): @common.XfailIfNoCorstone320 -@common.parametrize("test_module", test_data_suite_fp | test_data_int) +@common.parametrize( + "test_module", + test_data_suite_fp | test_data_int, + xfails={"none_indices_4": "Incorrect numerical behavior: MLBEDSW-11589"}, +) def test_index_put_u85_INT(test_module: input_t): """same_index test case already supported on u85 even though it is not supported by TOSA spec.