Commit b34c103

Remove unused attributes in Float8Tensor (#2935)
Summary: hp_value_lb and hp_value_ub for the weight are only used when calculating the scale for the float8 tensor, so they do not have to be stored in the tensor itself; this PR removes them. We also have BC testing to make sure the change does not break backward compatibility.

Test Plan: regression tests: python test/integration/test_load_and_run_checkpoint.py
1 parent aff141e commit b34c103
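To make the summary concrete: the bounds only participate in computing the scale, before the tensor is constructed. The following is a minimal sketch of that flow, not the actual torchao implementation; the function name compute_float8_scale and the scale convention (a multiplier mapping the clamped amax onto the float8 dtype's max representable value) are assumptions for illustration.

from typing import Optional

import torch


def compute_float8_scale(
    hp_tensor: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
    hp_value_lb: Optional[float] = None,
    hp_value_ub: Optional[float] = None,
) -> torch.Tensor:
    # Per-tensor absolute max, optionally clamped to [hp_value_lb, hp_value_ub].
    amax = hp_tensor.abs().max().to(torch.float32)
    if hp_value_lb is not None:
        amax = torch.clamp(amax, min=hp_value_lb)
    if hp_value_ub is not None:
        amax = torch.clamp(amax, max=hp_value_ub)
    # The scale maps the (clamped) observed range onto the float8 dtype's
    # largest representable value.
    return torch.finfo(float8_dtype).max / amax


hp_weight = torch.randn(128, 256)
scale = compute_float8_scale(hp_weight, hp_value_ub=1e4)
qdata = (hp_weight * scale).to(torch.float8_e4m3fn)
# From here on only qdata and scale matter (dequant is qdata.float() / scale);
# the bounds have served their purpose and need not be serialized with the tensor.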

File tree

1 file changed (+0, -24 lines)

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 0 additions & 24 deletions
@@ -85,8 +85,6 @@ class Float8Tensor(TorchAOBaseTensor):
            sharing the same set of quantization parameters (scale), have the same rank as qdata or
            is an empty list (representing per tensor quantization)
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
-        hp_value_lb (Optional[float]): the lower bound for high precision floating point value for calculating scale
-        hp_value_ub (Optional[float]): the upper bound for high precision floating point value for calculating scale
         act_quant_kwargs (QuantizeTensorToFloat8Kwargs): the kwargs for Float8Tensor.from_hp
         kernel_preference (KernelPreference): the preference for quantize, mm etc. kernel to use,
            by default, this will be chosen for user based on hardware, library availabilities etc.
@@ -98,8 +96,6 @@ class Float8Tensor(TorchAOBaseTensor):
     optional_tensor_attribute_names = [
         "block_size",
         "mm_config",
-        "hp_value_lb",
-        "hp_value_ub",
         "act_quant_kwargs",
         "kernel_preference",
         "dtype",
@@ -111,8 +107,6 @@ def __new__(
         scale: torch.Tensor,
         block_size: Optional[List[int]] = None,
         mm_config: Optional[Float8MMConfig] = None,
-        hp_value_lb: Optional[float] = None,
-        hp_value_ub: Optional[float] = None,
         act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
         kernel_preference: KernelPreference = KernelPreference.AUTO,
         dtype: Optional[torch.dtype] = None,
@@ -130,8 +124,6 @@ def __init__(
         scale: torch.Tensor,
         block_size: Optional[List[int]] = None,
         mm_config: Optional[Float8MMConfig] = None,
-        hp_value_lb: Optional[float] = None,
-        hp_value_ub: Optional[float] = None,
         act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
         kernel_preference: KernelPreference = KernelPreference.AUTO,
         dtype: Optional[torch.dtype] = None,
@@ -141,8 +133,6 @@ def __init__(
         self.scale = scale
         self.block_size = block_size
         self.mm_config = mm_config
-        self.hp_value_lb = hp_value_lb
-        self.hp_value_ub = hp_value_ub
         self.act_quant_kwargs = act_quant_kwargs
         self.kernel_preference = kernel_preference
 
@@ -248,8 +238,6 @@ def from_hp(
             scale,
             block_size=block_size,
             mm_config=mm_config,
-            hp_value_lb=hp_value_lb,
-            hp_value_ub=hp_value_ub,
             act_quant_kwargs=act_quant_kwargs,
             kernel_preference=kernel_preference,
             dtype=hp_dtype,
@@ -472,8 +460,6 @@ def _(func, types, args, kwargs):
         sliced_scale,
         block_size,
         self.mm_config,
-        self.hp_value_lb,
-        self.hp_value_ub,
         self.act_quant_kwargs,
         self.kernel_preference,
         dtype=self.dtype,
@@ -503,8 +489,6 @@ def _(func, types, args, kwargs):
         assert tensor_0.scale.ndim == tensors[i].scale.ndim
         assert tensor_0.block_size == tensors[i].block_size
         assert tensor_0.mm_config == tensors[i].mm_config
-        assert tensor_0.hp_value_lb == tensors[i].hp_value_lb
-        assert tensor_0.hp_value_ub == tensors[i].hp_value_ub
         assert tensor_0.act_quant_kwargs == tensors[i].act_quant_kwargs
         assert tensor_0.kernel_preference == tensors[i].kernel_preference
 
@@ -528,8 +512,6 @@ def _(func, types, args, kwargs):
         cat_scale,
         block_size,
         tensor_0.mm_config,
-        tensor_0.hp_value_lb,
-        tensor_0.hp_value_ub,
         tensor_0.act_quant_kwargs,
         tensor_0.kernel_preference,
         tensor_0.dtype,
@@ -551,8 +533,6 @@ def _(func, types, args, kwargs):
         scale,
         block_size,
         self.mm_config,
-        self.hp_value_lb,
-        self.hp_value_ub,
         self.act_quant_kwargs,
         self.kernel_preference,
         self.dtype,
@@ -603,8 +583,6 @@ def _(func, types, args, kwargs):
         scale,
         block_size,
         self.mm_config,
-        self.hp_value_lb,
-        self.hp_value_ub,
         self.act_quant_kwargs,
         self.kernel_preference,
         self.dtype,
@@ -627,8 +605,6 @@ def _(func, types, args, kwargs):
         scale,
         block_size,
         self.mm_config,
-        self.hp_value_lb,
-        self.hp_value_ub,
         self.act_quant_kwargs,
         self.kernel_preference,
         self.dtype,
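The test plan's load-and-run checkpoint test is what guards a change like this: a checkpoint saved before the attribute removal must still load and run. Below is a hedged sketch of the shape such a regression check takes; the helper name and arguments are illustrative, not the actual code in test/integration/test_load_and_run_checkpoint.py.

import torch


def check_old_checkpoint_still_loads(model, old_ckpt_path, example_input, expected):
    # State dict written by the pre-change code, when Float8Tensor still
    # carried hp_value_lb / hp_value_ub.
    state_dict = torch.load(old_ckpt_path, weights_only=False)
    model.load_state_dict(state_dict, assign=True)
    # The removed attributes were only inputs to the scale computation, so the
    # stored qdata/scale -- and therefore the model's numerics -- are unchanged.
    with torch.no_grad():
        out = model(example_input)
    torch.testing.assert_close(out, expected)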
