Skip to content

Commit 752502b

Browse files
authored
make mxtensor printing nicer (#3068)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent 894b46e commit 752502b

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
QuantizeTensorToNVFP4Kwargs,
2525
per_tensor_amax_to_scale,
2626
)
27+
from torchao.quantization.quant_api import _quantization_type
2728
from torchao.quantization.transform_module import (
2829
register_quantize_module_handler,
2930
)
@@ -89,7 +90,7 @@ def __post_init__(self):
8990

9091

9192
def _linear_extra_repr(self):
92-
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={repr(self.weight)}"
93+
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"
9394

9495

9596
@register_quantize_module_handler(MXFPInferenceConfig)

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,9 @@ def __repr__(self):
544544
# TODO better elem dtype print for fp4
545545
return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}" # noqa: E501
546546

547+
def _quantization_type(self):
548+
return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}"
549+
547550
@classmethod
548551
def __torch_dispatch__(cls, func, types, args, kwargs=None):
549552
# avoid circular dependency

0 commit comments

Comments (0)