fix dtype converting issue (#403)
wenhuach21 authored Jan 6, 2025
1 parent 4fd9bbd commit 6ac7a2b
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions auto_round/auto_quantizer.py
@@ -285,8 +285,6 @@ def validate_environment(self, *args, **kwargs):
     def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         if torch_dtype is None:
             torch_dtype = torch.float16
-        elif torch_dtype != torch.float16 and not is_hpu_supported():
-            logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AutoRound.")
         return torch_dtype
 
     def find_backend(self, target_backend: str):
@@ -406,10 +404,9 @@ def convert_model(self, model: nn.Module):
         if ("hpu" == target_device or "cpu" == target_device) and model.dtype != torch.bfloat16:
             logger.info(f"Change the dtype to `bfloat16` as {target_device.upper()} does not support float16")
             model = model.to(torch.bfloat16)
-        else:
-            if model.dtype != torch.float16:
-                logger.info(f"Change the dtype to `float16` for better performance")
-                model = model.to(torch.float16)
+        elif "cuda" == target_device and model.dtype != torch.float16:
+            logger.info(f"Change the dtype to `float16` for better performance")
+            model = model.to(torch.float16)
 
         bits = quantization_config.bits
         group_size = quantization_config.group_size
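For context, a minimal, self-contained sketch of the dtype-selection behavior this commit leaves in convert_model: HPU and CPU targets are moved to bfloat16, CUDA targets are moved to float16, and any other device keeps the model's existing dtype (previously the float16 fallback ran for every non-HPU/CPU device). The helper name select_dtype and reading the dtype via next(model.parameters()).dtype are illustrative assumptions; the repository code operates on a transformers model and checks model.dtype directly.

# Sketch only, not code from auto_round. Assumes a plain torch module whose
# dtype is read from its parameters.
import torch
import torch.nn as nn

def select_dtype(model: nn.Module, target_device: str) -> nn.Module:
    current_dtype = next(model.parameters()).dtype
    if target_device in ("hpu", "cpu") and current_dtype != torch.bfloat16:
        # HPU/CPU do not support float16 here, so fall back to bfloat16.
        model = model.to(torch.bfloat16)
    elif target_device == "cuda" and current_dtype != torch.float16:
        # float16 is preferred on CUDA for performance.
        model = model.to(torch.float16)
    # Any other target device leaves the model's dtype unchanged.
    return model

# Example: a float32 model targeted at CPU is converted to bfloat16,
# while an unrecognized target such as "xpu" keeps float32.
print(next(select_dtype(nn.Linear(4, 4), "cpu").parameters()).dtype)  # torch.bfloat16
print(next(select_dtype(nn.Linear(4, 4), "xpu").parameters()).dtype)  # torch.float32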
