diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 0854d580f..58ce7b033 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -80,10 +80,13 @@ def quantize( actorder=False, static_groups=False, ): - # TODO: waiting for pytorch implementgation of aten ops for MPS + # TODO: waiting for pytorch implementation of ops for MPS if sys.platform == "darwin" and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "1": raise RuntimeError("For MacOS you must set env `PYTORCH_ENABLE_MPS_FALLBACK=1` before running quantization.") + # save mem and temp move to cpu + self.layer.weight.data = self.layer.weight.data.cpu() + W = self.layer.weight.data.clone() if isinstance(self.layer, nn.Conv2d): @@ -92,7 +95,7 @@ def quantize( if isinstance(self.layer, transformers.Conv1D): W = W.t() - W = W.float() + W = W.to(device=self.dev, dtype=torch.float) tick = time.time() @@ -229,9 +232,12 @@ def quantize( Q = Q.t() if Q.shape != self.layer.weight.shape: - self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data) + self.layer.weight.data = Q.cpu().reshape(self.layer.weight.shape).type_as(self.layer.weight.data) else: - self.layer.weight.data = Q.type_as(self.layer.weight.data) + self.layer.weight.data = Q.cpu().type_as(self.layer.weight.data) + + # move back to self.dev + self.layer.weight.data = self.layer.weight.data.to(device=self.dev) if os.environ.get("DEBUG"): logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))