@@ -122,7 +122,6 @@ def quantize(  # noqa C901
             Int8DynamicActivationIntxWeightConfig,
             quantize_,
         )
-        from torchao.utils import unwrap_tensor_subclass
 
         with torch.no_grad():
             # Computation dtype is fixed to fp32 in the implementation of quantize_, so
@@ -142,7 +141,6 @@ def quantize(  # noqa C901
                     ),
                 ),
             )
-            model = unwrap_tensor_subclass(model)
         if verbose:
             print("quantized model:", model)
         return model
@@ -156,7 +154,6 @@ def quantize(  # noqa C901
             quantize_,
         )
         from torchao.quantization.granularity import PerGroup
-        from torchao.utils import unwrap_tensor_subclass
 
         def filter_fn(m, fqn):
             is_linear = isinstance(m, nn.Linear)
@@ -181,8 +178,6 @@ def filter_fn(m, fqn):
             filter_fn=filter_fn,
         )
 
-        model = unwrap_tensor_subclass(model)
-
         # TODO: deal with checkpoint / computation dtype decoupling.
 
         if verbose:
@@ -191,7 +186,6 @@ def filter_fn(m, fqn):
     elif qmode == "4w":
         from torchao.quantization.granularity import PerGroup
         from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
-        from torchao.utils import unwrap_tensor_subclass
 
         q_group_size = 256 if group_size is None else group_size
         q_config = IntxWeightOnlyConfig(
@@ -204,7 +198,6 @@ def filter_fn(m, fqn):
             ),
         )
         quantize_(model, q_config)
-        model = unwrap_tensor_subclass(model)
 
         return model
     else:
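The net effect of the change is easiest to see in the "4w" branch: `quantize_` is now the last quantization step, with no `unwrap_tensor_subclass` pass after it. Below is a minimal sketch under that reading; the toy model, the `torch.int4` weight dtype, and the exact `IntxWeightOnlyConfig` keyword names are assumptions based on recent torchao, not code from this repository.

```python
# Hedged sketch of the "4w" flow after this diff (not the repository's code).
# Assumes torchao's IntxWeightOnlyConfig accepts weight_dtype/granularity and
# that quantize_ mutates the module in place, leaving nothing to unwrap.
import torch
import torch.nn as nn

from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

model = nn.Sequential(nn.Linear(256, 256))  # placeholder for the real model
q_group_size = 256  # the branch above defaults to 256 when group_size is None

q_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,  # 4-bit weight-only, per the "4w" qmode
    granularity=PerGroup(q_group_size),
)
quantize_(model, q_config)  # in place; previously followed by unwrap_tensor_subclass
```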