From b8151cce5c62f8cc99b5b1d2276298cc65bae623 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 3 Apr 2026 13:27:59 -0700 Subject: [PATCH] Remove references to torchao's AffineQuantizedTensor **Summary:** TorchAO recently deprecated AffineQuantizedTensor and related classes (https://github.com/pytorch/ao/issues/2752). These will be removed in the next release. We should remove references of these classes in diffusers before then. **Test Plan:** python -m pytest -s -v tests/quantization/torchao/test_torchao.py --- .../quantizers/torchao/torchao_quantizer.py | 19 +++------ tests/quantization/torchao/test_torchao.py | 42 ++++++++----------- 2 files changed, 23 insertions(+), 38 deletions(-) diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 88b45349daea..3a20dca88ecf 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -133,19 +133,10 @@ def fuzzy_match_size(config_name: str) -> str | None: return None -def _quantization_type(weight): - from torchao.dtypes import AffineQuantizedTensor - from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor - - if isinstance(weight, AffineQuantizedTensor): - return f"{weight.__class__.__name__}({weight._quantization_type()})" - - if isinstance(weight, LinearActivationQuantizedTensor): - return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})" - - def _linear_extra_repr(self): - weight = _quantization_type(self.weight) + from torchao.utils import TorchAOBaseTensor + + weight = self.weight.__class__.__name__ if isinstance(self.weight, TorchAOBaseTensor) else None if weight is None: return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None" else: @@ -283,12 +274,12 @@ def create_quantized_param( if self.pre_quantized: # If we're loading pre-quantized weights, replace the repr of linear layers for pretty printing info - # about AffineQuantizedTensor + # about the quantized tensor type module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device)) if isinstance(module, nn.Linear): module.extra_repr = types.MethodType(_linear_extra_repr, module) else: - # As we perform quantization here, the repr of linear layers is that of AQT, so we don't have to do it ourselves + # As we perform quantization here, the repr of linear layers is set by TorchAO, so we don't have to do it ourselves module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) quantize_(module, self.quantization_config.get_apply_tensor_subclass()) diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 7a05582cbfba..8a811cfc1c73 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -75,17 +75,17 @@ def _is_xpu_or_cuda_capability_atleast_8_9() -> bool: if is_torchao_available(): - from torchao.dtypes import AffineQuantizedTensor from torchao.quantization import ( Float8WeightOnlyConfig, + Int4Tensor, Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig, Int8DynamicActivationIntxWeightConfig, + Int8Tensor, Int8WeightOnlyConfig, IntxWeightOnlyConfig, ) - from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor - from torchao.utils import get_model_size_in_bytes + from torchao.utils import TorchAOBaseTensor, get_model_size_in_bytes @require_torch @@ -260,9 +260,7 @@ def test_int4wo_quant_bfloat16_conversion(self): ) weight = quantized_model.transformer_blocks[0].ff.net[2].weight - self.assertTrue(isinstance(weight, AffineQuantizedTensor)) - self.assertEqual(weight.quant_min, 0) - self.assertEqual(weight.quant_max, 15) + self.assertTrue(isinstance(weight, Int4Tensor)) def test_device_map(self): """ @@ -322,7 +320,7 @@ def test_device_map(self): if "transformer_blocks.0" in device_map: self.assertTrue(isinstance(weight, nn.Parameter)) else: - self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + self.assertTrue(isinstance(weight, Int4Tensor)) output = quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() @@ -343,7 +341,7 @@ def test_device_map(self): if "transformer_blocks.0" in device_map: self.assertTrue(isinstance(weight, nn.Parameter)) else: - self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + self.assertTrue(isinstance(weight, Int4Tensor)) output = quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() @@ -360,11 +358,11 @@ def test_modules_to_not_convert(self): unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2] self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear)) - self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor)) + self.assertFalse(isinstance(unquantized_layer.weight, Int8Tensor)) self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16) quantized_layer = quantized_model_with_not_convert.proj_out - self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor)) + self.assertTrue(isinstance(quantized_layer.weight, Int8Tensor)) quantization_config = TorchAoConfig(Int8WeightOnlyConfig()) quantized_model = FluxTransformer2DModel.from_pretrained( @@ -448,18 +446,18 @@ def test_memory_footprint(self): # Will not quantized all the layers by default due to the model weights shapes not being divisible by group_size=64 for block in transformer_int4wo.transformer_blocks: - self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor)) - self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor)) + self.assertTrue(isinstance(block.ff.net[2].weight, Int4Tensor)) + self.assertTrue(isinstance(block.ff_context.net[2].weight, Int4Tensor)) # Will quantize all the linear layers except x_embedder for name, module in transformer_int4wo_gs32.named_modules(): if isinstance(module, nn.Linear) and name not in ["x_embedder"]: - self.assertTrue(isinstance(module.weight, AffineQuantizedTensor)) + self.assertTrue(isinstance(module.weight, Int4Tensor)) # Will quantize all the linear layers for module in transformer_int8wo.modules(): if isinstance(module, nn.Linear): - self.assertTrue(isinstance(module.weight, AffineQuantizedTensor)) + self.assertTrue(isinstance(module.weight, Int8Tensor)) total_int4wo = get_model_size_in_bytes(transformer_int4wo) total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32) @@ -588,7 +586,7 @@ def _test_original_model_expected_slice(self, quant_type, expected_slice): output = quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() weight = quantized_model.transformer_blocks[0].ff.net[2].weight - self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))) + self.assertTrue(isinstance(weight, TorchAOBaseTensor)) self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) def _check_serialization_expected_slice(self, quant_type, expected_slice, device): @@ -604,11 +602,7 @@ def _check_serialization_expected_slice(self, quant_type, expected_slice, device output = loaded_quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - self.assertTrue( - isinstance( - loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor) - ) - ) + self.assertTrue(isinstance(loaded_quantized_model.proj_out.weight, TorchAOBaseTensor)) self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) def test_int_a8w8_accelerator(self): @@ -756,7 +750,7 @@ def _test_quant_type(self, quantization_config, expected_slice): pipe.enable_model_cpu_offload() weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight - self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))) + self.assertTrue(isinstance(weight, TorchAOBaseTensor)) inputs = self.get_dummy_inputs(torch_device) output = pipe(**inputs)[0].flatten() @@ -790,7 +784,7 @@ def test_serialization_int8wo(self): pipe.enable_model_cpu_offload() weight = pipe.transformer.x_embedder.weight - self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + self.assertTrue(isinstance(weight, Int8Tensor)) inputs = self.get_dummy_inputs(torch_device) output = pipe(**inputs)[0].flatten()[:128] @@ -809,7 +803,7 @@ def test_serialization_int8wo(self): pipe.enable_model_cpu_offload() weight = transformer.x_embedder.weight - self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + self.assertTrue(isinstance(weight, Int8Tensor)) loaded_output = pipe(**inputs)[0].flatten()[:128] # Seems to require higher tolerance depending on which machine it is being run. @@ -897,7 +891,7 @@ def test_transformer_int8wo(self): # Verify that all linear layer weights are quantized for name, module in pipe.transformer.named_modules(): if isinstance(module, nn.Linear): - self.assertTrue(isinstance(module.weight, AffineQuantizedTensor)) + self.assertTrue(isinstance(module.weight, Int8Tensor)) # Verify outputs match expected slice inputs = self.get_dummy_inputs(torch_device)