@@ -186,103 +186,60 @@ class MyTensor(TorchAOBaseTensor):
186186 tensor_data_names = ["qdata" ]
187187 tensor_attribute_names = ["attr" , "device" ]
188188
189- def __new__ (cls , qdata , attr , device ):
189+ def __new__ (cls , qdata , attr , device = None ):
190190 shape = qdata .shape
191191 if device is None :
192192 device = qdata .device
193193 kwargs = {"device" : device }
194194 return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
195195
196- def __init__ (self , qdata , attr , device ):
196+ def __init__ (self , qdata , attr , device = None ):
197197 self .qdata = qdata
198198 self .attr = attr
199199
200200 l = torch .nn .Linear (2 , 3 )
201- l .weight = torch .nn .Parameter (MyTensor (l .weight , "attr" , None ))
201+ l .weight = torch .nn .Parameter (MyTensor (l .weight , "attr" ))
202202 lp_tensor = l .weight
203203
204204 another_tensor = torch .nn .Linear (2 , 3 ).weight
205205 # attribute has to be the same
206- lp_tensor_for_copy = MyTensor (another_tensor , "attr" , None )
206+ lp_tensor_for_copy = MyTensor (another_tensor , "attr" )
207207 self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
208208
209209 @skip_if_no_cuda ()
210210 def test_default_impls_with_optional_data (self ):
211211 class MyTensorWithOptionalData (TorchAOBaseTensor ):
212212 tensor_data_names = ["qdata" ]
213- tensor_attribute_names = ["attr" , "device" ]
214213 optional_tensor_data_names = ["zero_point" ]
215-
216- def __new__ (cls , qdata , attr , device , zero_point = None ):
217- shape = qdata .shape
218- if device is None :
219- device = qdata .device
220- kwargs = {"device" : device }
221- return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
222-
223- def __init__ (self , qdata , attr , device , zero_point = None ):
224- self .qdata = qdata
225- self .attr = attr
226- self .zero_point = zero_point
227-
228- # test both the optional Tensor is None
229- # and not None
230- l = torch .nn .Linear (2 , 3 )
231- lp_tensor = MyTensorWithOptionalData (l .weight , "attr" , None , None )
232- l = torch .nn .Linear (2 , 3 )
233- lp_tensor_for_copy = MyTensorWithOptionalData (l .weight , "attr" , None , None )
234- self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
235-
236- l = torch .nn .Linear (2 , 3 )
237- lp_tensor = MyTensorWithOptionalData (
238- l .weight , "attr" , None , torch .zeros_like (l .weight )
239- )
240- l = torch .nn .Linear (2 , 3 )
241- lp_tensor_for_copy = MyTensorWithOptionalData (
242- l .weight , "attr" , None , torch .zeros_like (l .weight )
243- )
244- self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
245-
246- @skip_if_no_cuda ()
247- def test_default_impls_with_optional_attr (self ):
248- class MyTensorWithOptionalData (TorchAOBaseTensor ):
249- tensor_data_names = ["qdata" ]
250214 tensor_attribute_names = ["attr" , "device" ]
251- optional_tensor_data_names = ["zero_point" ]
252- optional_tensor_attribute_names = ["optional_attr" ]
253215
254- def __new__ (cls , qdata , attr , device , zero_point = None , optional_attr = None ):
216+ def __new__ (cls , qdata , zero_point = None , attr = 1.0 , device = None ):
255217 shape = qdata .shape
256218 if device is None :
257219 device = qdata .device
258220 kwargs = {"device" : device }
259221 return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
260222
261- def __init__ (
262- self , qdata , attr , device , zero_point = None , optional_attr = None
263- ):
223+ def __init__ (self , qdata , zero_point = None , attr = 1.0 , device = None ):
264224 self .qdata = qdata
265- self .attr = attr
266225 self .zero_point = zero_point
267- self .optional_attr = optional_attr
226+ self .attr = attr
268227
269228 # test both the optional Tensor is None
270229 # and not None
271230 l = torch .nn .Linear (2 , 3 )
272- lp_tensor = MyTensorWithOptionalData (l .weight , "attr" , None , zero_point = None )
231+ lp_tensor = MyTensorWithOptionalData (l .weight , None , "attr" )
273232 l = torch .nn .Linear (2 , 3 )
274- lp_tensor_for_copy = MyTensorWithOptionalData (
275- l .weight , "attr" , None , zero_point = None
276- )
233+ lp_tensor_for_copy = MyTensorWithOptionalData (l .weight , None , "attr" )
277234 self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
278235
279236 l = torch .nn .Linear (2 , 3 )
280237 lp_tensor = MyTensorWithOptionalData (
281- l .weight , "attr" , None , zero_point = None , optional_attr = "value "
238+ l .weight , torch . zeros_like ( l . weight ), "attr "
282239 )
283240 l = torch .nn .Linear (2 , 3 )
284241 lp_tensor_for_copy = MyTensorWithOptionalData (
285- l .weight , "attr" , None , zero_point = None , optional_attr = "value "
242+ l .weight , torch . zeros_like ( l . weight ), "attr "
286243 )
287244 self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
288245
0 commit comments