Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backends/cadence/aot/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,11 @@
- arg_meta: null
kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out

- func: cadence::quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::generic::quantized_max_pool2d_out

- func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
Expand Down
47 changes: 47 additions & 0 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,13 @@ def register_fake(
"quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor (a!)"
)

# Schemas for quantized max pooling: a functional variant and an out-variant
# (the out-variant is the one bound to the generic kernel in functions.yaml).
# Note the op takes no scale/zero_point arguments: max pooling is
# order-preserving, so the output reuses the input's quantization parameters.
lib.define(
    "quantized_max_pool2d(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
)
lib.define(
    "quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
"quantized_conv2d_nhwc(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)"
)
Expand Down Expand Up @@ -2270,6 +2277,46 @@ def quantized_relu_asym8u_asym8u_per_tensor_meta(
return input.new_empty(input.size(), dtype=input.dtype)


@register_fake("cadence::quantized_max_pool2d")
def quantized_max_pool2d_meta(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """
    Shape-propagation (fake/meta) kernel for cadence::quantized_max_pool2d.

    Each spatial output dimension follows the standard max_pool2d formula:
        out = rnd((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    where rnd is floor, or ceil when ``ceil_mode`` is set.

    Args:
        input: NCHW input tensor; only its sizes and dtype are inspected.
        kernel_size / stride / padding / dilation: [h, w] pooling parameters.
        ceil_mode: when True, round the output size up instead of down.

    Returns:
        An empty tensor with the pooled output shape and the input's dtype.
    """

    def _pooled_dim(size: int, k: int, s: int, p: int, d: int) -> int:
        # One output dimension of the pooling formula above.
        span = size + 2 * p - d * (k - 1) - 1
        if ceil_mode:
            # Integer ceil-division (no float round-off for large sizes).
            out = -(-span // s) + 1
            # Match PyTorch's pooling shape rule: in ceil mode the last
            # window must start inside the input or its left padding, never
            # entirely inside the right padding.
            if (out - 1) * s >= size + p:
                out -= 1
        else:
            out = span // s + 1
        return out

    # NOTE(review): assumes a 4D NCHW input — confirm the op never receives
    # an unbatched 3D (CHW) tensor.
    batch = input.size(0)
    channels = input.size(1)
    height_out = _pooled_dim(
        input.size(2), kernel_size[0], stride[0], padding[0], dilation[0]
    )
    width_out = _pooled_dim(
        input.size(3), kernel_size[1], stride[1], padding[1], dilation[1]
    )

    return input.new_empty([batch, channels, height_out, width_out], dtype=input.dtype)


@register_fake("cadence::fully_connected")
def fully_connected_meta(
src: torch.Tensor,
Expand Down
40 changes: 40 additions & 0 deletions backends/cadence/aot/quantizer/fusion_pass.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -6,6 +6,7 @@

# pyre-strict

import operator as op_module
from typing import Any, cast, Dict, List, Tuple

import torch
Expand All @@ -24,6 +25,8 @@
LayerNormPattern,
LinearPattern,
MatmulPattern,
MaxPool2dPattern,
MaxPool2dWithoutIndicesPattern,
MixedW8A32ConvPattern,
MixedW8A32GruPattern,
MixedW8A32LinearPattern,
Expand Down Expand Up @@ -457,6 +460,34 @@
return args, kwargs


def get_args_and_kwargs_max_pool2d(
    inputs_inputs: List[fx.Node],
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Returns the args and kwargs for the max_pool2d replacement op.

    Max pooling is order-preserving, so we can perform the max operation
    directly on quantized values without any requantization.

    Args:
        inputs_inputs: dequant-side input nodes; only the first (the pooled
            tensor) is forwarded to the replacement op.
        op_node: the original aten max_pool2d(_with_indices) node whose
            positional args carry the pooling parameters.
    """
    node_args = op_node.args
    kernel_size = node_args[1] if len(node_args) > 1 else [1, 1]
    stride = node_args[2] if len(node_args) > 2 else kernel_size
    # aten::max_pool2d declares `stride=[]`, meaning "default to kernel_size".
    # A traced graph may carry that empty list explicitly, so normalize it
    # here rather than forwarding [] to the quantized op.
    if not stride:
        stride = kernel_size
    padding = node_args[3] if len(node_args) > 3 else [0, 0]
    dilation = node_args[4] if len(node_args) > 4 else [1, 1]
    ceil_mode = node_args[5] if len(node_args) > 5 else False

    args = (inputs_inputs[0],)
    kwargs = {
        "kernel_size": kernel_size,
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "ceil_mode": ceil_mode,
    }
    return args, kwargs


def get_args_and_kwargs_mixed_w8a32_gru(
graph_module: GraphModule,
other_inputs: List[fx.Node],
Expand Down Expand Up @@ -549,6 +580,10 @@

assert op_node is not None, "op_node is None"
quant_node = list(op_node.users.keys())[0]
# For ops that return tuples (e.g., max_pool2d_with_indices),
# traverse through the getitem to find the actual quant node
if quant_node.target is op_module.getitem:
quant_node = list(quant_node.users.keys())[0]

with graph_module.graph.inserting_after(op_node):
args = tuple(
Expand Down Expand Up @@ -697,6 +732,11 @@
dequants_biases,
op_node,
)
elif isinstance(pattern, (MaxPool2dPattern, MaxPool2dWithoutIndicesPattern)):
args, kwargs = get_args_and_kwargs_max_pool2d(
inputs_inputs,
op_node,
)

fused = graph_module.graph.call_function(
pattern.replacement_op(),
Expand Down
84 changes: 84 additions & 0 deletions backends/cadence/aot/quantizer/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,90 @@ def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_matmul.default


class MaxPool2dPattern(QuantizationPattern):
    """
    Quantization pattern for aten.max_pool2d_with_indices.

    Taking a maximum never reorders values, so pooling quantized integers
    that share a scale/zero_point yields exactly what dequantize ->
    max_pool2d -> quantize would produce. The output therefore shares its
    quantization parameters with the input and no requantization is needed.
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.max_pool2d_with_indices.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]

        # Everything after the input tensor (kernel_size, stride, padding,
        # dilation, ceil_mode) is a literal argument.
        literal_anchors = []
        for arg_idx in range(1, len(node.args)):
            literal_anchors.append((node, arg_idx))

        # Tie the output's quantization parameters to the input's, since max
        # is order-preserving.
        shared_output_spec = SharedQuantizationSpec((node.args[0], node))
        anchors = PartitionAnchors(
            inputs=[(node, 0)],
            weights=[],
            biases=[],
            literals=literal_anchors,
            output=[(node, shared_output_spec)],
        )
        return anchors, node

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_max_pool2d.default


class MaxPool2dWithoutIndicesPattern(QuantizationPattern):
    """
    Quantization pattern for aten.max_pool2d (no-indices variant).

    Identical in spirit to MaxPool2dPattern, except the matched op returns
    a single tensor rather than a (values, indices) tuple.
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.max_pool2d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]

        # All trailing positional args are literal pooling parameters.
        literal_anchors = [(node, arg_idx) for arg_idx in range(1, len(node.args))]

        # Output reuses the input's quantization parameters (max is
        # order-preserving).
        shared_output_spec = SharedQuantizationSpec((node.args[0], node))
        return (
            PartitionAnchors(
                inputs=[(node, 0)],
                weights=[],
                biases=[],
                literals=literal_anchors,
                output=[(node, shared_output_spec)],
            ),
            node,
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_max_pool2d.default


# This is a base class for ReLU, since it can be used with two different aten ops
class ReluBasePattern(QuantizationPattern):
@abstractmethod
Expand Down
4 changes: 4 additions & 0 deletions backends/cadence/aot/quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
LayerNormPattern,
LinearPattern,
MatmulPattern,
MaxPool2dPattern,
MaxPool2dWithoutIndicesPattern,
MixedW8A32ConvPattern,
MixedW8A32GruPattern,
MixedW8A32LinearPattern,
Expand Down Expand Up @@ -227,6 +229,8 @@ def get_cadence_default_quantizers() -> List[Quantizer]:
CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8W8sym),
CadenceAtenQuantizer(LinearPattern(), qconfig_A8W8),
CadenceAtenQuantizer(MatmulPattern(), qconfig_A8W8),
CadenceAtenQuantizer(MaxPool2dPattern(), qconfig_A8W8),
CadenceAtenQuantizer(MaxPool2dWithoutIndicesPattern(), qconfig_A8W8),
CadenceAtenQuantizer(ReluPattern0(), qconfig_A8W8),
CadenceAtenQuantizer(ReluPattern1(), qconfig_A8W8),
]
Expand Down
29 changes: 29 additions & 0 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,6 +1868,35 @@ def rms_norm(
return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)


@impl_tracked(m, "quantized_max_pool2d")
def quantized_max_pool2d(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """
    Reference implementation of quantized max pooling.

    Because max() is monotone, pooling raw quantized integers that share a
    scale/zero_point produces the same codes as dequantize -> max_pool2d ->
    quantize would. The pooling is therefore applied directly to the
    quantized tensor with no dequant/requant round trip.
    """
    pool_params = {
        "kernel_size": kernel_size,
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "ceil_mode": ceil_mode,
    }
    return F.max_pool2d(input, **pool_params)


@impl_tracked(m, "where_Scalar")
def where_Scalar(
condition: torch.Tensor,
Expand Down
Loading
Loading