Commit
Replace torch.empty with torch.zeros
Differential Revision: D64875312

Pull Request resolved: #1157
helunwencser authored Oct 25, 2024
1 parent fec5420 commit 6177bfc
Showing 1 changed file with 5 additions and 5 deletions.
torchao/quantization/GPTQ.py: 5 additions & 5 deletions
@@ -583,12 +583,12 @@ def __init__(
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
         self.register_buffer(
             "weight",
-            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32, device=device)
+            torch.zeros((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32, device=device)
         )
         self.dtype = dtype
         self.register_buffer(
             "scales_and_zeros",
-            torch.empty((in_features // groupsize, out_features, 2), dtype=self.scales_precision, device=device)
+            torch.zeros((in_features // groupsize, out_features, 2), dtype=self.scales_precision, device=device)
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -935,18 +935,18 @@ def __init__(
         # currently storing unpacked int8 weights
         self.register_buffer(
             "weight",
-            torch.empty((out_features, in_features), dtype=torch.int8),
+            torch.zeros((out_features, in_features), dtype=torch.int8),
         )
         self.register_buffer(
             "scales",
-            torch.empty(
+            torch.zeros(
                 (out_features, in_features // groupsize),
                 dtype=scales_precision,
             ),
         )
         self.register_buffer(
             "zeros",
-            torch.empty(
+            torch.zeros(
                 (out_features, in_features // groupsize),
                 dtype=scales_precision,
             ),
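Note on the change: the commit message does not spell out the motivation, but the behavioral difference between the two calls is well defined. torch.empty allocates a tensor without initializing its memory, so a buffer registered from it holds arbitrary values until something overwrites it, while torch.zeros returns a deterministic zero-filled tensor. A minimal sketch of the register_buffer pattern touched by this diff (TinyQuantizedLinear is a hypothetical name for illustration, not part of GPTQ.py):

import torch
import torch.nn as nn

class TinyQuantizedLinear(nn.Module):
    # Hypothetical module mirroring the buffer registration in GPTQ.py.
    def __init__(self, out_features: int, in_features: int):
        super().__init__()
        # torch.zeros gives the buffer deterministic, zero-filled contents;
        # torch.empty would leave whatever bytes were already in memory.
        self.register_buffer(
            "weight",
            torch.zeros((out_features, in_features), dtype=torch.int8),
        )

m = TinyQuantizedLinear(4, 8)
print(torch.unique(m.weight))  # tensor([0], dtype=torch.int8): deterministic initial state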
