def _adjust_weight_qspec_for_conv_transpose(node: Node, weight_qspec):
    """Return a weight qspec whose channel axis matches a transpose conv.

    Only per-channel symmetric ``QuantizationSpec`` weights attached to
    ``aten.conv_transpose2d.input`` nodes are adjusted; any other input is
    returned untouched. The expected channel axis is ``1`` when
    ``groups == 1`` and ``0`` otherwise.

    Args:
        node: FX node that consumes the weight.
        weight_qspec: Quantization spec configured for the weight. Objects
            that are not a ``QuantizationSpec`` are passed through.

    Returns:
        ``weight_qspec`` itself when no adjustment is required, otherwise a
        new ``QuantizationSpec`` that copies every field but uses the
        expected ``ch_axis`` (and, where it can be rebuilt, an
        observer/fake-quant constructor carrying the same ``ch_axis``).
    """
    if (
        node.target != torch.ops.aten.conv_transpose2d.input
        or not isinstance(weight_qspec, QuantizationSpec)
        or weight_qspec.qscheme != torch.per_channel_symmetric
    ):
        return weight_qspec

    # For now skip axis adjustment for a8w4 per-channel configs (int4 weights).
    if weight_qspec.quant_min == -7 and weight_qspec.quant_max == 7:
        return weight_qspec

    # ``groups`` is read from positional slot 6; fall back to the keyword
    # form so a keyword-passed value is not silently treated as groups == 1.
    if len(node.args) > 6:
        groups_arg = node.args[6]
    else:
        groups_arg = node.kwargs.get("groups")
    groups = groups_arg if isinstance(groups_arg, int) else 1

    expected_axis = 0 if groups != 1 else 1
    if weight_qspec.ch_axis == expected_axis:
        return weight_qspec

    observer_or_fake_quant_ctr = weight_qspec.observer_or_fake_quant_ctr
    # TorchAO PT2e QAT commonly represents the ctor as PartialWrapper(partial(...)).
    # Rebuild it to update ch_axis while preserving callable_args.
    if isinstance(observer_or_fake_quant_ctr, PartialWrapper):
        base_partial = observer_or_fake_quant_ctr.p
        if isinstance(base_partial, functools.partial):
            original_callable_args = dict(observer_or_fake_quant_ctr.callable_args)
            base_keywords = dict(base_partial.keywords or {})
            base_keywords["ch_axis"] = expected_axis
            # Forward the positional arguments as well; rebuilding from the
            # keywords alone would drop anything bound positionally.
            observer_or_fake_quant_ctr = PartialWrapper(
                functools.partial(
                    base_partial.func, *base_partial.args, **base_keywords
                )
            )
            observer_or_fake_quant_ctr.callable_args = original_callable_args
    # Non-QAT observer/fake-quant constructors can be updated via with_args.
    elif hasattr(observer_or_fake_quant_ctr, "with_args"):
        observer_or_fake_quant_ctr = observer_or_fake_quant_ctr.with_args(
            ch_axis=expected_axis
        )

    return QuantizationSpec(
        dtype=weight_qspec.dtype,
        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
        quant_min=weight_qspec.quant_min,
        quant_max=weight_qspec.quant_max,
        qscheme=weight_qspec.qscheme,
        ch_axis=expected_axis,
        is_dynamic=weight_qspec.is_dynamic,
    )
"executorch_exir_dialects_edge__ops_aten_convolution_default" # No edge transpoe conv @@ -94,6 +99,21 @@ def forward(self, x): for q in [True, False] } +test_data_QAT = { + "qat_basic": lambda: ( + TransposeConv2d( + in_channels=16, + out_channels=4, + kernel_size=4, + stride=2, + padding=1, + groups=1, + ), + True, + True, + ), +} + u55_supported_test_data_INT = { k: v for k, v in test_data_INT.items() @@ -150,6 +170,29 @@ def test_conv_transpose2d_tosa_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_QAT) +def test_conv_transpose2d_tosa_INT_qat_per_channel_quantization_pipeline(test_data): + model, is_per_channel, is_qat = test_data() + inputs = model.get_inputs() + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) + quantizer.set_global( + get_symmetric_quantization_config( + is_per_channel=is_per_channel, + is_qat=is_qat, + ) + ) + pipeline = QuantizationPipeline[input_t](model, inputs, quantizer) + pipeline.change_args( + "quantize", + Quantize( + quantizer, + quantization_config=quantizer.global_config, + is_qat=is_qat, + ), + ) + pipeline.run() + + _a8w4_transpose_conv_xfails = { k: "per-channel int4 weight quantization is not supported for transpose conv yet." for k in test_data_INT