Commit 65e9611

make mxtensor printing nicer
Summary: Fix printing of a linear weight wrapped with MXTensor.

Test Plan: Quantize a Qwen MoE model with mxfp4 and print it. The old version would print the raw data for each weight; the new version prints this:

```python
(self_attn): Qwen2MoeSdpaAttention(
  (q_proj): Linear(in_features=2048, out_features=2048, weight=MXTensor(self._elem_dtype=torch.float4_e2m1fn_x2, self._block_size=32, torch.bfloat16, MXGemmKernelChoice.EMULATED, self.act_quant_kwargs=QuantizeTensorToMXKwargs(elem_dtype=torch.float4_e2m1fn_x2, block_size=32, scaling_mode=<ScaleCalculationMode.FLOOR: 'floor'>, gemm_kernel_choice=<MXGemmKernelChoice.EMULATED: 'emulated'>, pack_fp6=False)))
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: fc9e992
ghstack-comment-id: 3336009172
Pull Request resolved: #3068
1 parent f109317 commit 65e9611
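
To reproduce the test plan above on a smaller scale, something like the following should work; the import paths match the files touched by this commit, but the toy model, the default `MXFPInferenceConfig()` arguments, and the CUDA requirement are assumptions rather than part of the commit.

```python
import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig

# Toy stand-in for a transformer projection layer (shapes and names are illustrative).
model = nn.Sequential(nn.Linear(2048, 2048, bias=False)).to(torch.bfloat16).cuda()

# Quantize every Linear weight to an MX format; default config arguments assumed here.
quantize_(model, MXFPInferenceConfig())

# With this commit, printing the model shows a compact per-layer summary like
#   Linear(in_features=2048, out_features=2048, weight=MXTensor(self._elem_dtype=..., ...))
# instead of dumping the raw quantized weight data.
print(model)
```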

File tree

2 files changed: +5 −1 lines changed


torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 2 additions & 1 deletion
@@ -24,6 +24,7 @@
     QuantizeTensorToNVFP4Kwargs,
     per_tensor_amax_to_scale,
 )
+from torchao.quantization.quant_api import _quantization_type
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
@@ -89,7 +90,7 @@ def __post_init__(self):


 def _linear_extra_repr(self):
-    return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={repr(self.weight)}"
+    return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"


 @register_quantize_module_handler(MXFPInferenceConfig)
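
For context, `_quantization_type` (imported above from `torchao.quantization.quant_api`) dispatches to the weight's tensor subclass, and `_linear_extra_repr` stands in for `nn.Linear.extra_repr` on quantized modules. A minimal sketch of that pattern follows; the fallback string and the `types.MethodType` binding are illustrative assumptions, not the exact torchao implementation.

```python
import types

import torch
import torch.nn as nn


def _quantization_type(weight: torch.Tensor) -> str:
    # Simplified stand-in: tensor subclasses such as MXTensor expose a
    # _quantization_type() method; anything else falls back to its type name.
    if hasattr(weight, "_quantization_type"):
        return f"{weight.__class__.__name__}({weight._quantization_type()})"
    return weight.__class__.__name__


def _linear_extra_repr(self) -> str:
    # Replacement for nn.Linear.extra_repr: shape info plus a compact weight
    # summary instead of the full quantized tensor data.
    return (
        f"in_features={self.weight.shape[1]}, "
        f"out_features={self.weight.shape[0]}, "
        f"weight={_quantization_type(self.weight)}"
    )


def _attach_extra_repr(module: nn.Linear) -> nn.Linear:
    # Hypothetical helper: bind the custom extra_repr to this instance so
    # print(model) uses it for the quantized layer.
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module
```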

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 3 additions & 0 deletions
@@ -544,6 +544,9 @@ def __repr__(self):
         # TODO better elem dtype print for fp4
         return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}"  # noqa: E501

+    def _quantization_type(self):
+        return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}"
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         # avoid circular dependency
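
One small note on the new `_quantization_type` output: the `=` specifier in an f-string expands to both the expression text and its value, which is why the printed repr in the test plan contains `self._elem_dtype=...` and `self.act_quant_kwargs=...`. A tiny standalone illustration with made-up values:

```python
block_size = 32
elem_dtype = "torch.float4_e2m1fn_x2"

# f"{expr=}" renders as "expr=<repr of value>", mirroring the MXTensor output above.
print(f"{block_size=}, {elem_dtype=}")
# block_size=32, elem_dtype='torch.float4_e2m1fn_x2'
```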
