diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 25c33417a..8d62a842f 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -596,7 +596,7 @@ def double_quantize_scalers( return ( quantized_scaler_blocks.flatten().to(torch.int8), - quantization_factor.view(n_scaler_blocks), + quantization_factor.view(n_scaler_blocks).contiguous(), scalers_1_mean, )