Commit 1a36a0c

Remove unwrap_tensor_subclass
1 parent 149e23d

10 files changed: +3, -37 lines

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-e6f766c7d750d40603eee3f66c5915bac606b3ea
+72df1db744431d24ee1a4c0e42e514426ce0d45f

backends/test/harness/stages/quantize.py

Lines changed: 0 additions & 4 deletions

@@ -16,7 +16,6 @@
 )
 from torchao.quantization.pt2e.quantizer import Quantizer
 from torchao.quantization.quant_api import quantize_
-from torchao.utils import unwrap_tensor_subclass


 class Quantize(Stage):
@@ -111,9 +110,6 @@ def run(
         # Apply quantize_ to the model
         quantize_(artifact, self.config, self.filter_fn)

-        # Unwrap tensor subclasses for export compatibility
-        unwrap_tensor_subclass(artifact)
-
         self.quantized_module = artifact

     @property

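With the updated pins, a quantized module goes straight from quantize_ into torch.export, with no unwrapping step in between. A minimal sketch of the post-change flow, using an illustrative module and config that are not taken from this commit:

import torch
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

# Illustrative module and config -- not from this commit.
model = torch.nn.Sequential(torch.nn.Linear(64, 64))
quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)))

# An unwrap_tensor_subclass(model) call was previously required here.
ep = torch.export.export(model, (torch.randn(1, 64),), strict=True)
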
backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 0 additions & 4 deletions

@@ -43,7 +43,6 @@

 from torchao.quantization.pt2e.quantizer import Quantizer
 from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
-from torchao.utils import unwrap_tensor_subclass

 try:
     ctypes.CDLL("libvulkan.so.1")
@@ -2363,7 +2362,6 @@ def apply_quantization(self):
             granularity=self.quant_granularity,
         )
         quantize_(self, q_config)
-        unwrap_tensor_subclass(self)
         return self

     # Test with GEMV pattern (batch_size=1, seq_len=1)
@@ -2686,15 +2684,13 @@ def apply_8da4w_quantization(self):
             quantize_,
         )
         from torchao.quantization.granularity import PerGroup
-        from torchao.utils import unwrap_tensor_subclass

         quantize_(
             self,
             Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=torch.int4, granularity=PerGroup(self.group_size)
             ),
         )
-        unwrap_tensor_subclass(self)
         return self

     # Test with GEMV pattern (batch_size=1, seq_len=1)

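The same two-line change repeats in the 8da4w helper above: quantize_ is now the whole story. As a standalone sketch, assuming an illustrative module and the config's default weight granularity:

import torch
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    quantize_,
)

# Illustrative module; the config's default weight granularity is assumed.
module = torch.nn.Linear(128, 128)
quantize_(module, Int8DynamicActivationIntxWeightConfig(weight_dtype=torch.int4))
# No unwrap_tensor_subclass(module) follow-up is needed anymore.
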
backends/xnnpack/test/ops/test_linear.py

Lines changed: 0 additions & 2 deletions

@@ -39,7 +39,6 @@
         Int8DynamicActivationIntxWeightConfig,
         quantize_,
     )
-    from torchao.utils import unwrap_tensor_subclass

     torchao_installed = True
 except:
@@ -400,7 +399,6 @@ def _test_groupwise_dq_linear(
                 weight_granularity=PerGroup(group_size),
             ),
         )
-        unwrap_tensor_subclass(mod)
         DynamicallyQuantizedPartitioner = XnnpackPartitioner(
             config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
             per_op_mode=True,

examples/apple/coreml/llama/export.py

Lines changed: 0 additions & 3 deletions

@@ -28,7 +28,6 @@

 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
-from torchao.utils import unwrap_tensor_subclass


 def main() -> None:
@@ -193,8 +192,6 @@ def main() -> None:
     )
     example_inputs = input_manager.get_inputs(tokens=[0])

-    model = unwrap_tensor_subclass(model)
-
     ep = torch.export.export(model, example_inputs, strict=True)
     print("Exported program")
     print(ep)

examples/models/llama/source_transformation/quantize.py

Lines changed: 0 additions & 7 deletions

@@ -122,7 +122,6 @@ def quantize(  # noqa C901
             Int8DynamicActivationIntxWeightConfig,
             quantize_,
         )
-        from torchao.utils import unwrap_tensor_subclass

         with torch.no_grad():
             # Computation dtype is fixed to fp32 in the implementation of quantize_, so
@@ -142,7 +141,6 @@ def quantize(  # noqa C901
                 ),
             ),
         )
-        model = unwrap_tensor_subclass(model)
         if verbose:
             print("quantized model:", model)
         return model
@@ -156,7 +154,6 @@ def quantize(  # noqa C901
             quantize_,
         )
         from torchao.quantization.granularity import PerGroup
-        from torchao.utils import unwrap_tensor_subclass

         def filter_fn(m, fqn):
             is_linear = isinstance(m, nn.Linear)
@@ -181,8 +178,6 @@ def filter_fn(m, fqn):
             filter_fn=filter_fn,
         )

-        model = unwrap_tensor_subclass(model)
-
         # TODO: deal with checkpoint / computation dtype decoupling.

         if verbose:
@@ -191,7 +186,6 @@ def filter_fn(m, fqn):
     elif qmode == "4w":
         from torchao.quantization.granularity import PerGroup
         from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
-        from torchao.utils import unwrap_tensor_subclass

         q_group_size = 256 if group_size is None else group_size
         q_config = IntxWeightOnlyConfig(
@@ -204,7 +198,6 @@ def filter_fn(m, fqn):
             ),
         )
         quantize_(model, q_config)
-        model = unwrap_tensor_subclass(model)

         return model
     else:

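The 8da4w branch above scopes quantize_ with a filter_fn predicate. A standalone sketch of that mechanism, with a hypothetical filter targeting only nn.Linear layers whose in_features divide evenly by the group size (the weight_granularity kwarg mirrors the xnnpack test above, not this file):

import torch
import torch.nn as nn
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    quantize_,
)

group_size = 32  # illustrative value

def filter_fn(m: nn.Module, fqn: str) -> bool:
    # Quantize only linear layers that are compatible with the group size.
    return isinstance(m, nn.Linear) and m.in_features % group_size == 0

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4, weight_granularity=PerGroup(group_size)
    ),
    filter_fn=filter_fn,
)
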
export/stages.py

Lines changed: 0 additions & 2 deletions

@@ -26,7 +26,6 @@
     ComposableQuantizer,
     Quantizer as TorchAOPT2EQuantizer,
 )
-from torchao.utils import unwrap_tensor_subclass


 class PipelineArtifact:
@@ -344,7 +343,6 @@ def run(self, artifact: PipelineArtifact) -> None:

         ao_config = self._quantization_recipe.ao_quantization_configs[0]
         quantize_(model, ao_config.ao_base_config, ao_config.filter_fn)
-        unwrap_tensor_subclass(model)

         self._artifact = artifact.copy_with_new_data(self._transformed_models)

export/tests/test_export_stages.py

Lines changed: 1 addition & 7 deletions

@@ -280,10 +280,7 @@ def test_source_transform_stage_no_quantization(self) -> None:
         self.assertEqual(result_artifact.data, self.models_dict)

     @patch("executorch.export.stages.quantize_")
-    @patch("executorch.export.stages.unwrap_tensor_subclass")
-    def test_run_with_ao_quantization_configs(
-        self, mock_unwrap: Mock, mock_quantize: Mock
-    ) -> None:
+    def test_run_with_ao_quantization_configs(self, mock_quantize: Mock) -> None:
         from torchao.core.config import AOBaseConfig

         mock_config = Mock(spec=AOBaseConfig)
@@ -308,9 +305,6 @@ def test_run_with_ao_quantization_configs(
         self.assertEqual(call_args[1], mock_config)
         self.assertEqual(call_args[2], mock_filter_fn)

-        # Verify unwrap_tensor_subclass was called once (with the copied model)
-        self.assertEqual(mock_unwrap.call_count, 1)
-
         # Verify that the original models_dict is unchanged
         self.assertEqual(models_dict, {"forward": self.model})

extension/llm/export/builder.py

Lines changed: 0 additions & 6 deletions

@@ -38,7 +38,6 @@
 from torch.nn.attention import SDPBackend
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer
-from torchao.utils import unwrap_tensor_subclass

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -203,11 +202,6 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         return edge_config

     def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
-        if module is not None:
-            unwrap_tensor_subclass(module)
-        else:
-            unwrap_tensor_subclass(self.model)
-
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)

torch_pin.py

Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
 TORCH_VERSION = "2.10.0"
-NIGHTLY_VERSION = "dev20251015"
+NIGHTLY_VERSION = "dev20251104"
