diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 528ceadaf19..80de190fedf 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -309,6 +309,11 @@ - arg_meta: null kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out +- func: cadence::quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::generic::quantized_max_pool2d_out + - func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index cbc179e05d2..0f9effeac49 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -213,6 +213,13 @@ def register_fake( "quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor (a!)" ) +lib.define( + "quantized_max_pool2d(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor" +) +lib.define( + "quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) 
@register_fake("cadence::quantized_max_pool2d")
def quantized_max_pool2d_meta(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """Fake (shape-only) kernel for ``cadence::quantized_max_pool2d``.

    Dims 0-3 of ``input`` are read as batch, channels, height and width.
    Index 0 of every pooling parameter applies to height, index 1 to width.

    Per spatial dim the pooled extent is
    ``floor((in + 2*pad - dilation*(kernel-1) - 1) / stride) + 1``.
    With ``ceil_mode`` the division rounds up instead, and a trailing window
    that would start at or beyond ``in + pad`` (i.e. entirely inside the
    right padding) is dropped. Without that correction the reported shape can
    be one larger than what ``torch.nn.functional.max_pool2d`` produces
    (e.g. in=4, kernel=1, stride=2 -> 2 windows, not 3).

    Args:
        input: Tensor whose first four dims are used for the shape.
        kernel_size: Window size as ``[h, w]``.
        stride: Window step as ``[h, w]``.
        padding: Implicit padding on both sides as ``[h, w]``.
        dilation: Spacing between window elements as ``[h, w]``.
        ceil_mode: Round the extent up instead of down.

    Returns:
        An uninitialised tensor of shape ``[batch, channels, h_out, w_out]``
        with ``input``'s dtype.
    """

    def pooled_extent(size: int, kernel: int, step: int, pad: int, dil: int) -> int:
        # Integer-only arithmetic: no float rounding.
        span = size + 2 * pad - dil * (kernel - 1) - 1
        if not ceil_mode:
            return span // step + 1
        extent = (span + step - 1) // step + 1
        # The last window must start inside the input or the left padding.
        if (extent - 1) * step >= size + pad:
            extent -= 1
        return extent

    batch = input.size(0)
    channels = input.size(1)
    height_out = pooled_extent(
        input.size(2), kernel_size[0], stride[0], padding[0], dilation[0]
    )
    width_out = pooled_extent(
        input.size(3), kernel_size[1], stride[1], padding[1], dilation[1]
    )

    return input.new_empty([batch, channels, height_out, width_out], dtype=input.dtype)
def get_args_and_kwargs_max_pool2d(
    inputs_inputs: List[fx.Node],
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Returns the args and kwargs for the max_pool2d replacement op.

    Max pooling is order-preserving, so we can perform the max operation
    directly on quantized values without any requantization.

    The pooling parameters are read from ``op_node``: positionally from
    ``op_node.args[1:6]`` when present, otherwise from ``op_node.kwargs``,
    otherwise a default (kernel ``[1, 1]``, stride = kernel size, padding
    ``[0, 0]``, dilation ``[1, 1]``, ceil_mode ``False``).

    Two forms the ATen schema allows are normalised so that the replacement
    op always receives two-element values it can index with ``[0]``/``[1]``:
    an empty/missing stride means "same as kernel_size", and a bare int or a
    one-element sequence is expanded to a pair.

    Args:
        inputs_inputs: Quantized producers of the op inputs; only element 0
            is used, as the single positional argument.
        op_node: The original max-pool node.

    Returns:
        ``(args, kwargs)`` for the replacement call.
    """

    def lookup(position, name, default):
        # Positional value wins; fall back to a keyword, then the default.
        if len(op_node.args) > position:
            return op_node.args[position]
        return op_node.kwargs.get(name, default)

    def as_pair(value):
        # Expand the scalar / single-element spellings of an ``int[2]`` arg.
        if isinstance(value, int) and not isinstance(value, bool):
            return [value, value]
        if isinstance(value, (list, tuple)) and len(value) == 1:
            return [value[0], value[0]]
        return value

    kernel_size = as_pair(lookup(1, "kernel_size", [1, 1]))

    stride = lookup(2, "stride", None)
    if stride is None or (isinstance(stride, (list, tuple)) and len(stride) == 0):
        # An empty stride is the schema's way of saying "use kernel_size".
        stride = kernel_size
    else:
        stride = as_pair(stride)

    padding = as_pair(lookup(3, "padding", [0, 0]))
    dilation = as_pair(lookup(4, "dilation", [1, 1]))
    ceil_mode = lookup(5, "ceil_mode", False)

    args = (inputs_inputs[0],)
    kwargs = {
        "kernel_size": kernel_size,
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "ceil_mode": ceil_mode,
    }
    return args, kwargs
class MaxPool2dPattern(QuantizationPattern):
    """
    Quantization pattern for ``aten.max_pool2d_with_indices``.

    Max is order-preserving, so taking the max of quantized values that share
    one scale/zero_point gives the same result as quantizing the max of the
    dequantized values. The output is therefore annotated to share the
    quantization parameters of the input and no requantization is needed.
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.max_pool2d_with_indices.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        pool_node = fused_partition[0].nodes[-1]

        # Every argument after the input tensor (kernel_size, stride, padding,
        # dilation, ceil_mode) is a literal.
        literal_slots = [(pool_node, idx) for idx in range(1, len(pool_node.args))]
        # Tie the output's quantization parameters to the input edge.
        shared_with_input = SharedQuantizationSpec((pool_node.args[0], pool_node))

        anchors = PartitionAnchors(
            inputs=[(pool_node, 0)],
            weights=[],
            biases=[],
            literals=literal_slots,
            output=[(pool_node, shared_with_input)],
        )
        return anchors, pool_node

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_max_pool2d.default


class MaxPool2dWithoutIndicesPattern(QuantizationPattern):
    """
    Quantization pattern for ``aten.max_pool2d`` (the single-tensor variant).

    Anchors and replacement op are the same as for the with-indices pattern;
    only the matched aten op differs.
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.max_pool2d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        pool_node = fused_partition[0].nodes[-1]

        literal_slots = [(pool_node, idx) for idx in range(1, len(pool_node.args))]
        shared_with_input = SharedQuantizationSpec((pool_node.args[0], pool_node))

        anchors = PartitionAnchors(
            inputs=[(pool_node, 0)],
            weights=[],
            biases=[],
            literals=literal_slots,
            output=[(pool_node, shared_with_input)],
        )
        return anchors, pool_node

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_max_pool2d.default
@impl_tracked(m, "quantized_max_pool2d")
def quantized_max_pool2d(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """
    Reference implementation of quantized max pooling.

    The pooling runs on the quantized values as they are: max is
    order-preserving, so with a shared scale/zero_point the max of the
    quantized values equals the quantized max of the dequantized values and
    no dequantize/requantize step is required.

    All pooling parameters are forwarded unchanged to ``F.max_pool2d``.
    """
    pooling_options = {
        "kernel_size": kernel_size,
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "ceil_mode": ceil_mode,
    }
    pooled = F.max_pool2d(input, **pooling_options)
    return pooled
def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
    """Build a graph holding a single max_pool2d_with_indices call.

    Returns:
        A tuple of (graph_module, max_pool_node).
    """
    pool_op = torch.ops.aten.max_pool2d_with_indices.default

    graph_builder = GraphBuilder()
    # (batch, channels, height, width)
    inp = graph_builder.placeholder("x", torch.randn(1, 3, 8, 8))
    pool_meta = NodeMetadata(
        {"source_fn_stack": [("max_pool2d_with_indices", pool_op)]}
    )
    # args: (input, kernel_size, stride, padding, dilation, ceil_mode)
    pooled = graph_builder.call_operator(
        op=pool_op,
        args=(inp, [2, 2], [2, 2], [0, 0], [1, 1], False),
        meta=pool_meta,
    )
    graph_builder.output([pooled])
    gm = graph_builder.get_graph_module()

    found = gm.graph.find_nodes(op="call_function", target=pool_op)
    self.assertEqual(
        len(found), 1, "Should find exactly one max_pool2d_with_indices node"
    )
    return gm, found[0]


def _build_conv1d_relu_graph(
    self,
) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
    """Build a graph with a conv1d whose result feeds a relu (fused pattern).

    Returns:
        A tuple of (graph_module, relu_node, conv_node).
        The relu_node is the target node where the annotation is placed.
        The conv_node is the input source node whose args contain the
        quantized inputs.
    """
    conv_op = torch.ops.aten.conv1d.default
    relu_op = torch.ops.aten.relu.default

    graph_builder = GraphBuilder()
    # (batch, in_channels, length)
    inp = graph_builder.placeholder("x", torch.randn(1, 3, 10))
    # (out_channels, in_channels, kernel_size)
    kernel = graph_builder.placeholder("weight", torch.randn(6, 3, 3))

    conv_meta = NodeMetadata({"source_fn_stack": [("conv1d", conv_op)]})
    conv_out = graph_builder.call_operator(
        op=conv_op, args=(inp, kernel), meta=conv_meta
    )
    relu_meta = NodeMetadata({"source_fn_stack": [("relu", relu_op)]})
    relu_out = graph_builder.call_operator(
        op=relu_op, args=(conv_out,), meta=relu_meta
    )
    graph_builder.output([relu_out])
    gm = graph_builder.get_graph_module()

    found_relu = gm.graph.find_nodes(op="call_function", target=relu_op)
    self.assertEqual(len(found_relu), 1, "Should find exactly one relu node")

    found_conv = gm.graph.find_nodes(op="call_function", target=conv_op)
    self.assertEqual(len(found_conv), 1, "Should find exactly one conv1d node")

    return gm, found_relu[0], found_conv[0]
/dev/null +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d.cpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +namespace { + +template +void quantized_max_pool2d_impl( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + ET_UNUSED bool ceil_mode, + Tensor& output) { + const T* __restrict__ in_data = input.const_data_ptr(); + T* __restrict__ out_data = output.mutable_data_ptr(); + + // Input dimensions: [N, C, H, W] + const int64_t batch_size = input.size(0); + const int64_t channels = input.size(1); + const int64_t in_height = input.size(2); + const int64_t in_width = input.size(3); + + // Output dimensions + const int64_t out_height = output.size(2); + const int64_t out_width = output.size(3); + + // Pooling parameters + const int64_t kernel_h = kernel_size[0]; + const int64_t kernel_w = kernel_size[1]; + const int64_t stride_h = stride[0]; + const int64_t stride_w = stride[1]; + const int64_t pad_h = padding[0]; + const int64_t pad_w = padding[1]; + const int64_t dilation_h = dilation[0]; + const int64_t dilation_w = dilation[1]; + + // Iterate over batch and channels + for (int64_t n = 0; n < batch_size; ++n) { + for (int64_t c = 0; c < channels; ++c) { + // Iterate over output spatial dimensions + for (int64_t oh = 0; oh < out_height; ++oh) { + for (int64_t ow = 0; ow < out_width; ++ow) { + // Compute the input region for this output pixel + const int64_t ih_start = oh * stride_h - pad_h; + 
const int64_t iw_start = ow * stride_w - pad_w; + + // Initialize with minimum value for the type + T max_val = std::numeric_limits::lowest(); + + // Iterate over the kernel + for (int64_t kh = 0; kh < kernel_h; ++kh) { + for (int64_t kw = 0; kw < kernel_w; ++kw) { + const int64_t ih = ih_start + kh * dilation_h; + const int64_t iw = iw_start + kw * dilation_w; + + // Check bounds + if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) { + const int64_t in_idx = + ((n * channels + c) * in_height + ih) * in_width + iw; + max_val = std::max(max_val, in_data[in_idx]); + } + } + } + + // Write output + const int64_t out_idx = + ((n * channels + c) * out_height + oh) * out_width + ow; + out_data[out_idx] = max_val; + } + } + } + } +} + +} // namespace + +Tensor& quantized_max_pool2d_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + Tensor& output) { +#define typed_quantized_max_pool2d(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_max_pool2d_impl( \ + input, kernel_size, stride, padding, dilation, ceil_mode, output); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_max_pool2d + return output; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d.h b/backends/cadence/generic/operators/op_quantized_max_pool2d.h new file mode 100644 index 00000000000..07f406a37a7 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_max_pool2d_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + ::executorch::aten::IntArrayRef kernel_size, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef dilation, + bool ceil_mode, + ::executorch::aten::Tensor& output); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/targets.bzl b/backends/cadence/generic/operators/targets.bzl index faa63e4f46f..bf1de9e009a 100644 --- a/backends/cadence/generic/operators/targets.bzl +++ b/backends/cadence/generic/operators/targets.bzl @@ -213,6 +213,18 @@ def define_common_targets(): visibility = ["PUBLIC"], ) + runtime.cxx_library( + name = "op_quantized_max_pool2d", + srcs = ["op_quantized_max_pool2d.cpp"], + exported_headers = ["op_quantized_max_pool2d.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ":cadence_type_util", + ], + visibility = ["PUBLIC"], + ) + runtime.cxx_library( name = "op_quantized_matmul", srcs = ["op_quantized_matmul.cpp"],