1111
1212from __future__ import annotations
1313
14+ from typing import Optional
15+
1416from onnxscript .function_libs .torch_lib .ops import common
1517from onnxscript .function_libs .torch_lib .registration import torch_op
1618from onnxscript .onnx_opset import opset18 as op
1719from onnxscript .onnx_opset import opset23 as op23
1820from onnxscript .onnx_types import TensorType
19- from typing import Optional
2021
2122
2223@torch_op (
@@ -84,7 +85,7 @@ def quantized_decomposed_quantize_per_channel(
8485) -> TensorType :
8586 """Affine per channel quantization for the Tensor using the same quantization
8687 parameters for each channel/axis to map from floating point to quantized values.
87-
88+
8889 Uses ONNX QuantizeLinear with per-axis quantization support.
8990 """
9091 # Use opset23 for per-axis quantization support
@@ -111,7 +112,7 @@ def quantized_decomposed_dequantize_per_channel(
111112) -> TensorType :
112113 """Affine per channel dequantization for the Tensor using the same quantization
113114 parameters for each channel/axis to map from quantized values to floating point values.
114-
115+
115116 Uses ONNX DequantizeLinear with per-axis quantization support.
116117 """
117118 # Use opset23 for per-axis quantization support with optional output_dtype
@@ -120,4 +121,6 @@ def quantized_decomposed_dequantize_per_channel(
120121 return op23 .DequantizeLinear (input , scales , zero_points , axis = axis )
121122 else :
122123 assert out_dtype > 0 , f"out_dtype must be -1 or > 0 not { out_dtype } "
123- return op23 .DequantizeLinear (input , scales , zero_points , axis = axis , output_dtype = out_dtype )
124+ return op23 .DequantizeLinear (
125+ input , scales , zero_points , axis = axis , output_dtype = out_dtype
126+ )
0 commit comments