Commit 0e4893b

avoid cloning on gpu (#886)

Qubitium authored Dec 17, 2024
1 parent 44b1255 commit 0e4893b
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions gptqmodel/quantization/gptq.py
@@ -80,10 +80,13 @@ def quantize(
         actorder=False,
         static_groups=False,
     ):
-        # TODO: waiting for pytorch implementgation of aten ops for MPS
+        # TODO: waiting for pytorch implementation of ops for MPS
         if sys.platform == "darwin" and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "1":
             raise RuntimeError("For MacOS you must set env `PYTORCH_ENABLE_MPS_FALLBACK=1` before running quantization.")

+        # save mem and temp move to cpu
+        self.layer.weight.data = self.layer.weight.data.cpu()
+
         W = self.layer.weight.data.clone()

         if isinstance(self.layer, nn.Conv2d):
@@ -92,7 +95,7 @@
         if isinstance(self.layer, transformers.Conv1D):
             W = W.t()

-        W = W.float()
+        W = W.to(device=self.dev, dtype=torch.float)

         tick = time.time()

@@ -229,9 +232,12 @@ def quantize(
         Q = Q.t()

         if Q.shape != self.layer.weight.shape:
-            self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
+            self.layer.weight.data = Q.cpu().reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
         else:
-            self.layer.weight.data = Q.type_as(self.layer.weight.data)
+            self.layer.weight.data = Q.cpu().type_as(self.layer.weight.data)

+        # move back to self.dev
+        self.layer.weight.data = self.layer.weight.data.to(device=self.dev)
+
         if os.environ.get("DEBUG"):
             logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
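The change trades a little host-device traffic for lower peak GPU memory: previously `W = self.layer.weight.data.clone()` allocated a second full-size copy of the weight on the GPU, so two copies coexisted during quantization. Now the master weight is parked on CPU first, the clone happens in host memory, and only the float32 working copy `W` lives on `self.dev`; the quantized result is likewise written back through CPU before the weight is restored to the device. A minimal self-contained sketch of the same pattern, with hypothetical names (`quantize_offloaded` and `fake_quantize` are illustrations, not GPTQModel's API):

    import torch

    def quantize_offloaded(layer: torch.nn.Linear, dev: torch.device) -> None:
        # Park the master weight on CPU so the clone below does not
        # allocate a second full-size copy in GPU memory.
        layer.weight.data = layer.weight.data.cpu()

        # Work on a float32 copy, moved to the compute device in one step.
        W = layer.weight.data.clone().to(device=dev, dtype=torch.float)

        Q = fake_quantize(W)  # stand-in for the real GPTQ loop

        # Write the result back through CPU, then restore the original device.
        layer.weight.data = Q.cpu().type_as(layer.weight.data)
        layer.weight.data = layer.weight.data.to(device=dev)

    def fake_quantize(W: torch.Tensor) -> torch.Tensor:
        # Placeholder: snap values to a coarse 4-bit-style symmetric grid.
        scale = W.abs().max() / 7
        return (W / scale).round().clamp_(-8, 7) * scale

The cost of this pattern is two extra full-weight transfers over the bus per layer, which is typically small next to the quantization math itself; the gain is that peak GPU usage during quantize() drops by roughly one full weight tensor.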
