
Commit 29fdecb

Fix dtype of unpacked tensor (#1840)

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>

1 parent: ecffc2e

File tree: 1 file changed (+1 −1)

  • neural_compressor/torch/algorithms/weight_only

neural_compressor/torch/algorithms/weight_only/modules.py (1 addition, 1 deletion)
@@ -289,7 +289,7 @@ def pack_tensor_with_torch(self, raw_tensor):
     def unpack_tensor_with_torch(self, packed_tensor):
         target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
         target_len = packed_tensor.shape[1] * self.n_pack
-        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
+        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(self.device)
         mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
         for j in range(packed_tensor.shape[1]):
             for e in range(self.n_pack):
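To illustrate why the output buffer's dtype matters here, the following is a minimal NumPy sketch of the same mask-and-shift unpacking loop. It is a standalone approximation, not the library's code: the real method uses torch tensors, `self.compression_dtype`, and `self.n_pack`, while the `unpack` helper, its signature, and the sample values below are assumptions for illustration only. The point of the fix is that the unpacked buffer should carry the narrow weight dtype (`int8`/`uint8`), not the wide storage dtype used for packing.

```python
import numpy as np

def unpack(packed, bits=4, target_dtype=np.int8):
    """Hypothetical helper: unpack `bits`-wide values stored n_pack-per-int32.

    Mirrors the diff's loop structure: allocate the output in the narrow
    target dtype (the fix), then extract each field with shift-and-mask.
    """
    n_pack = 32 // bits                  # e.g. 8 nibbles per int32 word
    mask = (1 << bits) - 1               # e.g. 0b1111 for 4-bit values
    out = np.zeros((packed.shape[0], packed.shape[1] * n_pack),
                   dtype=target_dtype)   # narrow dtype, not the storage dtype
    for j in range(packed.shape[1]):
        for e in range(n_pack):
            out[:, j * n_pack + e] = (packed[:, j] >> (bits * e)) & mask
    return out

# One int32 word holding the nibbles 0..7 (least-significant nibble first).
packed = np.array([[0x76543210]], dtype=np.int32)
print(unpack(packed))   # → [[0 1 2 3 4 5 6 7]], dtype int8
```

Before the fix, the unpacked tensor inherited the wide compression dtype (e.g. `int32`), so downstream consumers expecting `int8`/`uint8` weights saw the wrong element type even though the extracted values were identical.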

0 commit comments