diff --git a/examples/models/llama/source_transformation/pre_quantization.py b/examples/models/llama/source_transformation/pre_quantization.py index d284512e71..0a980bb08f 100644 --- a/examples/models/llama/source_transformation/pre_quantization.py +++ b/examples/models/llama/source_transformation/pre_quantization.py @@ -146,7 +146,7 @@ def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: scales_key = f"{cur_fqn}.scales" if isinstance(child, nn.Embedding) and scales_key in checkpoint: assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8 - assert checkpoint[scales_key].dtype == torch.float32 + assert checkpoint[scales_key].dtype == dtype return True return False