From d695731abb7e76f19b621cc931d9e24916272d29 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Mon, 9 Mar 2026 10:37:20 +0100 Subject: [PATCH] Arm backend: Refactor and bug-fix RewriteIndexPutPass The patch should hopefully make the pass easier to understand. Make explicit that we set N=1, handle explicit indexing by folding them in the K dimension, and handle full indexing (select all values) by folding them in the C dimension. Note that TOSA and torch has switched terminology regarding what the parameter 'values' means, instead, use a new naming: TOSA values_in == torch x/self tensor, call this 'destination'. TOSA input == torch values, call this 'data'. Additionally, the pass earlier didn't account for that 1) There are fully indexed dimensions 2) Index tensors can be broadcast 3) The data tensor can be smaller than (N, W, C), and require broadcasting first. 4) None index tensors were incorrectly handled. Regarding 1-3): Given destination of shape (N, K, C), TOSA.SCATTER semantics require the shape (N, W) of the index tensor, including possibly an implicit C dimension, to match the data shape (N, W_d, C_d). Torch can however broadcast both these inputs. We need to expand/reshape the data tensor correctly. Example (ignoring N, it's always 1): >>> destination = torch.ones(5, 2), K=5, C=2 >>> indices = (torch.tensor([0, 2]),) # Indexes K dim W=2 times, C is implicitly assumed to be C=2. >>> data = torch.tensor([10.0, 20.0]) # W_d = 1 !!, C_d=2 >>> torch.index_put(destination, indices, data) tensor([[10., 20.], [ 1., 1.], [10., 20.], [ 1., 1.], [ 1., 1.]]) Or even >>> [...] >>> data = torch.tensor([10.0]) # W_d = 1, C_d=1 !! >>> torch.index_put(destination, indices, data) tensor([[10., 10.], [ 1., 1.], [10., 10.], [ 1., 1.], [10., 10.]]) The patch generalizes this to multiple dimensions. Refer to docstring in patch for complete explaination. 4) Is handled by adding a normalization pass that rewrites None indice tensors to fully indexed tensors. Signed-off-by: Erik Lundell Change-Id: I272255377a53169c7b3547aabbca55c967fbf1d9 --- backends/arm/_passes/__init__.py | 3 + backends/arm/_passes/arm_pass_manager.py | 2 + .../normalize_index_put_none_indices_pass.py | 143 +++++++++ .../arm/_passes/rewrite_index_put_pass.py | 279 ++++++++---------- .../arm/test/modules/test_static_cache.py | 10 +- backends/arm/test/ops/test_index_put.py | 104 ++++++- 6 files changed, 374 insertions(+), 167 deletions(-) create mode 100644 backends/arm/_passes/normalize_index_put_none_indices_pass.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 36e3fe004d9..55251a40fee 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -122,6 +122,9 @@ from .match_arg_dtype_pass import MatchArgDtypePass # noqa from .match_arg_ranks_pass import MatchArgRanksPass # noqa from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa +from .normalize_index_put_none_indices_pass import ( # noqa + NormalizeIndexPutNoneIndicesPass, +) from .normalize_while_initial_args_pass import NormalizeWhileInitialArgsPass # noqa from .promote_bool_operands_pass import PromoteBoolOperandsPass # noqa from .remove_getitem_pass import RemoveGetItemPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b3f9fd2ef8a..b5d80110902 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -109,6 +109,7 @@ InsertTableOpsPass, MatchArgDtypePass, MatchArgRanksPass, + NormalizeIndexPutNoneIndicesPass, NormalizeWhileInitialArgsPass, PromoteBoolOperandsPass, QuantizeClampArgumentsPass, @@ -336,6 +337,7 @@ def _tosa_pipeline( # Node transformation passes (post scalar-removal) self.add_passes( [ + NormalizeIndexPutNoneIndicesPass(), RewriteIndexPutPass(), RewriteBoolBitwiseToLogicalPass(), DecomposeRemainderPass(), diff --git a/backends/arm/_passes/normalize_index_put_none_indices_pass.py b/backends/arm/_passes/normalize_index_put_none_indices_pass.py new file mode 100644 index 00000000000..7aaace641b0 --- /dev/null +++ b/backends/arm/_passes/normalize_index_put_none_indices_pass.py @@ -0,0 +1,143 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Set, Type + +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.rewrite_index_put_pass import RewriteIndexPutPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +class NormalizeIndexPutNoneIndicesPass(ArmPass): + """Normalize index_put with None:s in the indices_tensor list by moving + None-indexed dims to the channel dimensions (*C_j in RewriteIndexPutPass + teminology) by permutating the destination and data tensors. A None-index + corresponds to selecting the entire dim, which is equivalent with being a + channel dimension. + + Example: + out = index_put(destination, [None, idx1, None, idx2], data) + becomes + destination_permuted = permute(destination, destination_dim_order) + data_front_padded = reshape(data, front_padded_data_shape) + data_permuted = permute(data, data_dim_order) + out_permuted = index_put(destination_permuted, [idx1, idx2], data_permuted) + out = permute(out_permuted, inverse_destination_dim_order) + + Where the permutations of destination and data are decided by how the indexes move. + + Note that None tensors are handled differently in pytorch depending on how many indices tensors there are, + causing the data tensor to require different shapes, which will require different data permutation. + Many: all explicit dims are broadcast to a single dim and put in front of data tensor + destination shape (5,3,4,3) with indices (None, [1,0], None, [0,2]) -> data shape (2, 5, 4) + Note that this is the behaviour we want! No permutation of data is neccessary. + One: The explicit dim is kept in place + destination shape (5,3,4,3) with indices (None, [1,0], None, None) -> data shape (5, 2, 4, 3) + dim 1 needs to be moved to the front: dim_order = (1,0,2,3). + This is the same dim order as for the destination tensor. + + """ + + _passes_required_after: Set[Type[ExportPass]] = {RewriteIndexPutPass} + + def __init__(self): + super().__init__() + self.permute_op = exir_ops.edge.aten.permute_copy.default + self.reshape_op = exir_ops.edge.aten.view_copy.default + + def _get_data_dim_order( + self, + explicit_dims: list[int], + destination_dim_order: list[int], + ) -> list[int]: + """Return dim_order of data tensor.""" + + normalized_non_index_dims = destination_dim_order[len(explicit_dims) :] + data_dim_order = list(range(len(normalized_non_index_dims))) + + if not explicit_dims: + raise RuntimeError("Expected at least one non-None index tensor.") + elif len(explicit_dims) > 1: + # For multiple explicit index tensors, data is already in the order we want. + return data_dim_order + else: + # For single explicit index tensor, use same dim_order as destination + return destination_dim_order + + def call_operator(self, op, args, kwargs, meta, updated: bool | None = False): + if op not in (exir_ops.edge.aten.index_put.default,): + return super().call_operator(op, args, kwargs, meta) + + destination, indices_tensor_list, data = args[:3] + indices_tensor_list = list(indices_tensor_list) + if not any(indices_tensor is None for indices_tensor in indices_tensor_list): + return super().call_operator(op, args, kwargs, meta) + + destination_shape = destination.data.shape + explicit_dims = [ + dim_idx + for dim_idx, index_tensor in enumerate(indices_tensor_list) + if index_tensor is not None + ] + + none_dims = [ + dim_idx + for dim_idx, index_tensor in enumerate(indices_tensor_list) + if index_tensor is None + ] + trailing_dims = list(range(len(indices_tensor_list), len(destination_shape))) + + # Handle None indexing of destination tensor. + destination_dim_order = explicit_dims + none_dims + trailing_dims + needs_destination_permute = destination_dim_order != list( + range(len(destination_shape)) + ) + if needs_destination_permute: + destination = super().call_operator( + self.permute_op, + (destination, destination_dim_order), + {}, + meta, + updated=True, + ) + + # Handle None indexing of data tensor. + data_dim_order = self._get_data_dim_order( + explicit_dims=explicit_dims, + destination_dim_order=destination_dim_order, + ) + needs_data_permute = data_dim_order != list(range(len(data_dim_order))) + + if needs_data_permute: + data_shape = list(data.data.shape) + aligned_rank = len(data_dim_order) + if len(data_shape) < aligned_rank: + # We add dims to data when we move none dims, front pad data with unit dims to match. + padded_shape = [1] * (aligned_rank - len(data_shape)) + data_shape + data = super().call_operator( + self.reshape_op, (data, padded_shape), {}, meta, updated=True + ) + data = super().call_operator( + self.permute_op, (data, data_dim_order), {}, meta, updated=True + ) + + # Call index_put op. + explicit_indices_tensors = [ + indices_tensor_list[dim_idx] for dim_idx in explicit_dims + ] + normalized_args = (destination, explicit_indices_tensors, data, *args[3:]) + out = super().call_operator(op, normalized_args, kwargs, meta, updated=True) + + if not needs_destination_permute: + return out + + # If needed, reverse permutation of destination tensor. + inv_dim_order = [0] * len(destination_dim_order) + for new_dim, original_dim in enumerate(destination_dim_order): + inv_dim_order[original_dim] = new_dim + + return super().call_operator( + self.permute_op, (out, inv_dim_order), {}, meta, updated=True + ) diff --git a/backends/arm/_passes/rewrite_index_put_pass.py b/backends/arm/_passes/rewrite_index_put_pass.py index 24752772129..c0898673fd7 100644 --- a/backends/arm/_passes/rewrite_index_put_pass.py +++ b/backends/arm/_passes/rewrite_index_put_pass.py @@ -2,13 +2,15 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - import math -from typing import Any, Iterable, List, Sequence, Set, Type +from typing import Sequence, Set, Type import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( + ConvertExpandCopyToRepeatPass, +) from executorch.backends.arm._passes.fuse_view_copy_transform_pass import ( FuseViewCopyTransformPass, ) @@ -16,131 +18,92 @@ from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue -def get_index_put_ops(op): - if op == exir_ops.edge.aten.index_put.default: - return ( - exir_ops.edge.aten.view_copy.default, - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.mul.Tensor, - exir_ops.backend.tosa.SCATTER.default, - exir_ops.edge.aten.full.default, - ) - raise RuntimeError(f"Can't get index_put decomposition for op {op}") - - -def calculate_tosa_values( - index_shape: list[int], - index_nodes: list[Any], - source_shape: list[int], -) -> tuple[int, int, int, int]: - # Calculate K, W, C - # N - kept to 1 for generic n-dim implementation - # W - the number of positions being updated - N, K, W, C = 1, 1, 1, 1 - - W = math.prod(index_shape) - - for i, dim in enumerate(source_shape): - if i < len(index_nodes): - K *= dim - - total_vals = math.prod(source_shape) - C = int(total_vals / K) - - return N, K, W, C - - -def calculate_values_stride(source_shape: list[int]) -> list[int]: - """Calculate strides for a flattened view of the source tensor that are +def calculate_data_stride(destination_shape: list[int]) -> list[int]: + """Calculate strides for a flattened view of the destination tensor that are multiplied with the indices to build the [N, W] tensor. """ - values_strides: list[int] = [] + data_strides: list[int] = [] stride = 1 - for dim in reversed(source_shape): - values_strides.insert(0, stride) + for dim in reversed(destination_shape): + data_strides.insert(0, stride) stride *= dim - return values_strides + return data_strides class RewriteIndexPutPass(ArmPass): """ - This pass transforms index_put operations into TOSA-compatible scatter operations by: - 1. Expanding None indices into explicit range tensors - 2. Calculating flattened index positions - 3. Reshaping tensors into a 3D layout [N, K, C] required by TOSA SCATTER - 4. Applying the scatter operation and reshaping back to the original shape - - Example: - For source[i, :, j] = values, this pass: - - Expands ':' to arange(0, dim_size) - - Calculates flat indices: i * stride[0] + expanded_range * stride[1] + j * stride[2] - - Reshapes to 3D, applies scatter, reshapes back + This pass transforms index_put with arguments + - destination, of shape (*K_i, *C_j) + where *K_i means some number of dims >1 explicitly indexed (copy some entries from data to destination). + *C_j means some number of dims >=0 fully indexed (copy entire dim from data to destination) + - indices_tensor_list, a list containing len(*K_i) tensors with shape (W or 1,), indices for each dim K_i. + W is the number of explicit indexes. + Indicies_tensors are required to not be None. + - data, of shape (*C_d_j) + where len(*C_d_j) can be <= len(C_j)+1 and + *C_d_j can be broadcast into (W, *C_j), + - accumulate = False + + The lowering strategy is as follows: + - Flatten *K_i into K = prod(*K_i) + - Flatten *C_j int C = prod(*C_j) + - source_flattened = reshape(source, [N=1, K, C]) + - index_flattened = _calculate_flat_indices() + - data_broadcast = expand(data, [W, *C_j]) + - data_flattened = reshape(data_broadcast, [N=1, W, C]) + - Apply TOSA.SCATTER(source_flattened, index_flattened, data_flattened) + - Reshape output back to original destination shape """ - _passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransformPass} + def __init__(self): + super().__init__() + self.reshape_op = exir_ops.edge.aten.view_copy.default + self.expand_op = exir_ops.edge.aten.expand_copy.default + self.add_op = exir_ops.edge.aten.add.Tensor + self.mul_op = exir_ops.edge.aten.mul.Tensor + self.scatter_op = exir_ops.backend.tosa.SCATTER.default + self.full_op = exir_ops.edge.aten.full.default - def _expand_none_indices( - self, - source_shape: Sequence[int], - indices: Iterable[Any], - meta: NodeMetadata, - ) -> List[ProxyValue]: - """Replace None indices with explicit ranges.""" - expanded: List[ProxyValue] = [] - for dim_idx, idx in enumerate(indices): - if idx is None: - end_index = int(source_shape[dim_idx]) - # Use arange via call to edge operator since full can't create ranges - full_range = super().call_operator( - exir_ops.edge.aten.arange.start_step, - (0, end_index, 1), - {}, - meta, - updated=True, - ) - expanded.append(full_range) - elif not isinstance(idx, ProxyValue): - raise NotImplementedError( - "index_put indices must be tensor ProxyValues or None" - ) - else: - expanded.append(idx) - return expanded + _passes_required_after: Set[Type[ExportPass]] = { + FuseViewCopyTransformPass, + ConvertExpandCopyToRepeatPass, + } def _calculate_flat_indices( self, indices: Sequence[ProxyValue], - source_shape: Sequence[int], - num_channels: int, - ops: Sequence[Any], - full_op: Any, + shape: list[int], meta: NodeMetadata, ) -> ProxyValue: """ - The flat index is computed as: sum(index[i] * (stride[i] / num_channels)) for each dimension i. + The flat index is computed as: + sum(index[i] * (stride[i])) for each dimension i in shape. """ - values_strides = calculate_values_stride(list(source_shape)) - mul_op, add_op = ops + data_strides = calculate_data_stride(shape) new_indices: ProxyValue | None = None + W = 1 for i, index_val in enumerate(indices): # Get the shape of this dimension's indices to create matching constant tensor - scale_val = int(values_strides[i] / num_channels) + scale_val = int(data_strides[i]) # Create constant tensor directly: full([numel], scale_val) mul_const = super().call_operator( - full_op, + self.full_op, ((1,), scale_val), - {}, + { + "dtype": index_val.data.dtype, + "device": index_val.data.device, + }, meta, True, ) # Multiply indices by their stride constant mul_node = super().call_operator( - mul_op, (index_val, mul_const), {}, meta, True + self.mul_op, (index_val, mul_const), {}, meta, True ) # Accumulate contributions from each dimension @@ -148,95 +111,89 @@ def _calculate_flat_indices( new_indices = mul_node else: new_indices = super().call_operator( - add_op, (mul_node, new_indices), {}, meta, True + self.add_op, (mul_node, new_indices), {}, meta, True ) + assert new_indices is not None + W = max(new_indices.data.shape[0], 1) - if new_indices is None: - raise RuntimeError("No indices were provided for index_put") - - return new_indices + return super().call_operator( + self.reshape_op, (new_indices, (1, W)), {}, meta, True + ) - def call_operator(self, op, args, kwargs, meta): + def call_operator(self, op, args, kwargs, meta, updated: bool | None = False): if op not in (exir_ops.edge.aten.index_put.default,): return super().call_operator(op, args, kwargs, meta) - (reshape_op, add_op, mul_op, scatter_op, full_op) = get_index_put_ops(op) + destination, indices_tensor_list, data = args[:3] + accumulate = len(args) > 3 and bool(args[3]) + if accumulate: + raise RuntimeError( + "Encountered index_put with accumulate=True, this is assumed to be handled by an earlier pass." + ) - source, indices, values = args[:3] - if not isinstance(indices, (list, tuple)): - raise NotImplementedError("index_put indices must be provided as a tuple") + indices_tensor_list = list(indices_tensor_list) + num_explicit_indices = len(indices_tensor_list) + if any(index_tensor is None for index_tensor in indices_tensor_list): + raise RuntimeError( + "Encountered None indices in RewriteIndexPutPass. " + "Run NormalizeIndexPutNoneIndicesPass before this pass." + ) + + destination_shape = destination.data.shape - source_tensor = source.data - source_shape = list(source_tensor.shape) + K_i = destination_shape[:num_explicit_indices] + C_j = destination_shape[num_explicit_indices:] + C_j = torch.broadcast_shapes(C_j) + K = math.prod(K_i) + C = math.prod(C_j) + # A shape is a tensor of rank 1 -> rank of shape is shape[0]. + indices_shapes = [tuple(idx.data.shape) for idx in indices_tensor_list] + W = torch.broadcast_shapes(*indices_shapes)[0] + + # CALCULATE FLATTENED DESTINATION [N=1, K, C] + destination_flattened = super().call_operator( + self.reshape_op, (destination, [1, K, C]), {}, meta, True + ) + # CALCULATE FLATTENED INDEX [N=1, W] plain_meta_dict = dict(meta.data) plain_meta_dict["input_qparams"] = {} plain_meta_dict["output_qparams"] = {} plain_meta = NodeMetadata(plain_meta_dict) - - index_dtype = None - for idx in indices: - if idx is not None: - if not isinstance(idx, ProxyValue): - raise NotImplementedError( - "index_put indices must be tensor ProxyValues or None" - ) - index_dtype = idx.data.dtype - break - if index_dtype is None: - raise NotImplementedError( - "index_put with only None indices is not supported" - ) - - processed_indices = self._expand_none_indices(source_shape, indices, plain_meta) - index_shapes = [tuple(idx.data.shape) for idx in processed_indices] - try: - broadcast_shape = torch.broadcast_shapes(*index_shapes) - except Exception as exc: - raise RuntimeError( - "RewriteIndexPutPass: failed to broadcast index shapes %s: %s" - % (index_shapes, exc) - ) from exc - - N, K, W, C = calculate_tosa_values( - list(broadcast_shape), - [idx.node for idx in processed_indices], - source_shape, - ) - - indices_reshaped = self._calculate_flat_indices( - processed_indices, - source_shape, - C, - (mul_op, add_op), - full_op, - plain_meta, + indices_flattened = self._calculate_flat_indices( + indices_tensor_list, K_i, plain_meta ) - idx_shape = list(indices_reshaped.data.shape) - idx_numel = math.prod(idx_shape) - if idx_numel != N * W: - raise RuntimeError( - "RewriteIndexPutPass: flat index numel (%s) does not match expected N*W (%s)" - % (idx_numel, N * W) - ) - # Scatter expects a 3D layout; flatten everything into [N, K, C]. - reshape_indices = super().call_operator( - reshape_op, (indices_reshaped, [N, W]), {}, plain_meta, True - ) - reshape_source = super().call_operator( - reshape_op, (source, [N, K, C]), {}, meta, True + # CALCULATE FLATTENED DATA [N=1, W, C] + data_broadcast = super().call_operator( + self.expand_op, + ( + data, + ( + W, + *C_j, + ), + ), + {}, + meta, + updated=True, ) - reshape_values = super().call_operator( - reshape_op, (values, (N, W, C)), {}, meta, True + data_flattened = super().call_operator( + self.reshape_op, (data_broadcast, (1, W, C)), {}, meta, updated=True ) + + # DO SCATTER scatter_node = super().call_operator( - scatter_op, - (reshape_source, reshape_indices, reshape_values), + self.scatter_op, + (destination_flattened, indices_flattened, data_flattened), {}, meta, True, ) - return super().call_operator( - reshape_op, (scatter_node, source_shape), kwargs, meta, True + + # RESHAPE BACK TO ORIGINAL SHAPE + out = super().call_operator( + self.reshape_op, (scatter_node, destination_shape), kwargs, meta, True ) + + return out diff --git a/backends/arm/test/modules/test_static_cache.py b/backends/arm/test/modules/test_static_cache.py index a0e7d24cdac..8ddf6da1273 100644 --- a/backends/arm/test/modules/test_static_cache.py +++ b/backends/arm/test/modules/test_static_cache.py @@ -147,7 +147,15 @@ def test_static_cache_u55_INT(test_data): @common.XfailIfNoCorstone320 -@common.parametrize("test_data", test_configs) +@common.parametrize( + "test_data", + test_configs, + xfails={ + "multihead_attention": "Incorrect numerical behavior: MLBEDSW-11589", + "grouped_query_attention": "Incorrect numerical behavior: MLBEDSW-11589", + "multi_query_attention": "Incorrect numerical behavior: MLBEDSW-11589", + }, +) def test_static_cache_u85_INT(test_data): module = StaticCacheModule(test_data).eval() pipeline = EthosU85PipelineINT[input_t]( diff --git a/backends/arm/test/ops/test_index_put.py b/backends/arm/test/ops/test_index_put.py index d4cccb547bd..d6bef6bbeb6 100644 --- a/backends/arm/test/ops/test_index_put.py +++ b/backends/arm/test/ops/test_index_put.py @@ -3,8 +3,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple - import torch from executorch.backends.arm._passes import InsertInt32CastsAfterInt64PlaceholdersPass @@ -122,6 +120,96 @@ ), 0, ), + "broadcast_values_scalar": ( + lambda: ( + torch.zeros((3, 4), dtype=torch.float32), + (torch.tensor([0, 2], dtype=torch.int64),), + torch.tensor([5.0], dtype=torch.float32), + False, + ), + 0, + ), + "broadcast_values_scalar_0d": ( + lambda: ( + torch.zeros((3, 4), dtype=torch.float32), + (torch.tensor([0, 2], dtype=torch.int64),), + torch.tensor(5.0, dtype=torch.float32), + False, + ), + 0, + ), + "broadcast_values_vector": ( + lambda: ( + torch.zeros((3, 4), dtype=torch.float32), + (torch.tensor([0, 2], dtype=torch.int64),), + torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32), + False, + ), + 0, + ), + "broadcast_values_with_implicit_w_and_c": ( + lambda: ( + torch.zeros((5, 2), dtype=torch.float32), + (torch.tensor([0, 2], dtype=torch.int64),), + torch.tensor([10.0, 20.0], dtype=torch.float32), + False, + ), + 0, + ), + "none_indices": ( + lambda: ( + torch.ones((5, 3, 2, 2), dtype=torch.float32), + (torch.IntTensor([2, 3, 0]), None), + torch.zeros(1), + False, + ), + 0, + ), + "none_indices_2": ( + lambda: ( + torch.ones((5, 3, 2, 2), dtype=torch.float32), + (None, torch.IntTensor([2, 0]), None), + torch.rand(2, 2, 2), + False, + ), + 0, + ), + "none_indices_3": ( + lambda: ( + torch.ones((5, 3, 2, 2), dtype=torch.float32), + (None, torch.IntTensor([2, 1, 0]), None, None), + torch.zeros(1), + False, + ), + 0, + ), + "none_indices_4": ( + lambda: ( + torch.ones((5, 3, 2, 2), dtype=torch.float32), + ( + None, + torch.IntTensor( + [ + 2, + ] + ), + None, + torch.IntTensor([0]), + ), + torch.zeros(1, 5, 2), + False, + ), + 0, + ), + "none_indices_5": ( + lambda: ( + torch.ones((5, 3, 2, 2), dtype=torch.float32), + (None, torch.IntTensor([2, 0]), None, torch.IntTensor([0])), + torch.zeros(2, 1, 2), + False, + ), + 0, + ), } test_data_int = { "rank3_zeros_int8": ( @@ -190,10 +278,12 @@ def forward( z: torch.Tensor, acc: bool, ): - return torch.index_put(x, indices=y, values=z, accumulate=acc) + # Needs to use aten op directly to allow None indices. + return torch.ops.aten.index_put.default(x, indices=y, values=z, accumulate=acc) -input_t = Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool], int] +input_t = tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool], int] +input_no_indices_t = tuple[torch.Tensor, torch.Tensor] xfails = { "same_index": "MLETORCH-1596: index_put with repeated indices not supported", @@ -243,7 +333,11 @@ def test_index_put_u55_INT(test_module: input_t): @common.XfailIfNoCorstone320 -@common.parametrize("test_module", test_data_suite_fp | test_data_int) +@common.parametrize( + "test_module", + test_data_suite_fp | test_data_int, + xfails={"none_indices_4": "Incorrect numerical behavior: MLBEDSW-11589"}, +) def test_index_put_u85_INT(test_module: input_t): """same_index test case already supported on u85 even though it is not supported by TOSA spec.