@@ -172,7 +172,7 @@ def test_quantized_add(
172172 torch .tensor (
173173 [1073741824 ], dtype = torch .int32
174174 ), # out_multiplier (0.5 * 2^31)
175- torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
175+ torch .tensor ([0 ], dtype = torch .int32 ), # out_shift
176176 0 , # out_zero_point
177177 torch .tensor ([[0 ]], dtype = dtype ), # expected_output
178178 per_tensor ,
@@ -197,7 +197,7 @@ def test_quantized_add(
197197 torch .tensor (
198198 [1073741824 ], dtype = torch .int32
199199 ), # out_multiplier (0.5 * 2^31)
200- torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
200+ torch .tensor ([0 ], dtype = torch .int32 ), # out_shift
201201 0 , # out_zero_point
202202 torch .tensor ([[- 2 , - 8 ]], dtype = dtype ), # expected_output
203203 per_tensor ,
@@ -220,7 +220,7 @@ def test_quantized_add(
220220 torch .tensor (
221221 [1073741824 ], dtype = torch .int32
222222 ), # out_multiplier (0.5 * 2^31)
223- torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
223+ torch .tensor ([0 ], dtype = torch .int32 ), # out_shift
224224 0 , # out_zero_point
225225 torch .tensor ([[0 , 0 ]], dtype = dtype ), # expected_output
226226 per_tensor ,
@@ -244,7 +244,7 @@ def test_quantized_add(
244244 torch .tensor (
245245 [1073741824 ], dtype = torch .int32
246246 ), # out_multiplier (0.5 * 2^31)
247- torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
247+ torch .tensor ([0 ], dtype = torch .int32 ), # out_shift
248248 0 , # out_zero_point
249249 torch .tensor (
250250 [[[0 , - 2 , - 4 ], [- 2 , - 7 , - 12 ]]], dtype = dtype
@@ -270,7 +270,7 @@ def test_quantized_add(
270270 torch .tensor (
271271 [268435456 ], dtype = torch .int32
272272 ), # out_multiplier (1.0 * 2^31)
273- torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
273+ torch .tensor ([0 ], dtype = torch .int32 ), # out_shift
274274 1 , # out_zero_point
275275 torch .tensor ([[1 , 1 ]], dtype = dtype ), # expected_output
276276 per_tensor ,
@@ -295,7 +295,7 @@ def test_quantized_add(
295295 torch .tensor (
296296 [268435456 ], dtype = torch .int32
297297 ), # out_multiplier (1.0 * 2^31)
298- torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
298+ torch .tensor ([0 ], dtype = torch .int32 ), # out_shift
299299 1 , # out_zero_point
300300 torch .tensor ([[1 , 1 ]], dtype = dtype ), # expected_output
301301 False ,
@@ -317,7 +317,7 @@ def test_quantized_add(
317317 [268435456 ], dtype = torch .int32
318318 ), # out_multiplier (0.125 * 2^31)
319319 torch .tensor (
320- [1 ], dtype = torch .int64
320+ [1 ], dtype = torch .int32
321321 ), # out_shift (shift=1, doubles the scale)
322322 1 , # out_zero_point
323323 torch .tensor ([[1 , 2 ]], dtype = dtype ), # expected_output
@@ -339,7 +339,7 @@ def test_quantized_add(
339339 [268435456 ], dtype = torch .int32
340340 ), # out_multiplier (0.125 * 2^31)
341341 torch .tensor (
342- [1 ], dtype = torch .int64
342+ [1 ], dtype = torch .int32
343343 ), # out_shift (shift=1, doubles the scale)
344344 1 , # out_zero_point
345345 torch .tensor ([[1 , 2 ]], dtype = dtype ), # expected_output
0 commit comments