|
22 | 22 | preprocess_data, |
23 | 23 | preprocess_scale, |
24 | 24 | ) |
25 | | -from torchao.quantization.granularity import PerRow |
| 25 | +from torchao.quantization.granularity import PerRow, PerTensor |
26 | 26 | from torchao.quantization.observer import get_block_size |
27 | 27 | from torchao.quantization.quant_primitives import ( |
28 | 28 | _choose_scale_float8, |
@@ -177,32 +177,61 @@ def to_float8( |
177 | 177 | block_size = get_block_size(hp_tensor.shape, granularity) |
178 | 178 | block_size = list(block_size) |
179 | 179 |
|
180 | | - # for per row quantization and kernel_preference default setting, we'll use triton kernel for best performance |
| 180 | + kernel_choice = None |
181 | 181 | if ( |
182 | 182 | kernel_preference == KernelPreference.AUTO |
183 | 183 | and _is_fbgemm_genai_gpu_available() |
184 | | - and ( |
185 | | - tuple(block_size) |
186 | | - == (1,) * (hp_tensor.ndim - 1) + (hp_tensor.shape[-1],) |
187 | | - ) |
| 184 | + and is_sm_at_least_90() |
| 185 | + and isinstance(granularity, PerRow) |
| 186 | + and float8_dtype == torch.float8_e4m3fn |
| 187 | + and hp_value_lb is None |
188 | 188 | ): |
189 | | - assert float8_dtype == torch.float8_e4m3fn, ( |
190 | | - f"Only torch.float8_e4m3fn is supported, got: {float8_dtype}" |
| 189 | + # if kernel_preference is AUTO and per row quantization |
| 190 | + # we'll use fbgemm quantize kernel for best performance |
| 191 | + kernel_choice = "fbgemm" |
| 192 | + elif kernel_preference == KernelPreference.FBGEMM: |
| 193 | + # if user explicitly chose FBGEMM kernel preference, we'll also use fbgemm kernel |
| 194 | + assert _is_fbgemm_genai_gpu_available() and is_sm_at_least_90(), ( |
| 195 | +            "Specified fbgemm but fbgemm_gpu_genai is not installed or hardware is not >= SM 9.0 (H100 or newer)"
| 196 | + ) |
| 197 | + assert hp_value_lb is None, ( |
| 198 | +            "hp_value_lb should not be specified when using KernelPreference.FBGEMM"
191 | 199 | ) |
| 200 | + kernel_choice = "fbgemm" |
| 201 | + else: |
| 202 | + # fallback quantize kernel for everything else will be torch |
| 203 | + kernel_choice = "torch" |
| 204 | + |
| 205 | + if kernel_choice == "fbgemm": |
| 206 | + assert hp_value_lb is None, f"{hp_value_lb=} is not supported" |
192 | 207 | if hp_value_ub is not None: |
193 | 208 | maybe_hp_value_ub_tensor = torch.tensor( |
194 | 209 | hp_value_ub, dtype=torch.float, device=hp_tensor.device |
195 | 210 | ) |
196 | 211 | else: |
197 | 212 | maybe_hp_value_ub_tensor = None |
198 | | - data, scale = torch.ops.triton.quantize_fp8_row( |
199 | | - hp_tensor, scale_ub=maybe_hp_value_ub_tensor |
200 | | - ) |
201 | | - scale_shape = [] |
202 | | - for i in range(hp_tensor.ndim): |
203 | | - scale_shape.append(hp_tensor.shape[i] // block_size[i]) |
204 | | - scale = scale.reshape(*scale_shape) |
| 213 | + if isinstance(granularity, PerRow): |
| 214 | + data, scale = torch.ops.triton.quantize_fp8_row( |
| 215 | + hp_tensor, scale_ub=maybe_hp_value_ub_tensor |
| 216 | + ) |
| 217 | + scale_shape = [] |
| 218 | + for i in range(hp_tensor.ndim): |
| 219 | + scale_shape.append(hp_tensor.shape[i] // block_size[i]) |
| 220 | + scale = scale.reshape(*scale_shape) |
| 221 | + else: |
| 222 | + assert isinstance(granularity, PerTensor), ( |
| 223 | + f"Expected per tensor, got {granularity}" |
| 224 | + ) |
| 225 | + # current error: torch.AcceleratorError: CUDA error: an illegal memory access was encountered |
| 226 | + # TODO: enable after this is working |
| 227 | + # data, scale = torch.ops.fbgemm.quantize_fp8_per_tensor( |
| 228 | + # hp_tensor, num_tokens, scale_ub=maybe_hp_value_ub_tensor |
| 229 | + # ) |
| 230 | + raise NotImplementedError( |
| 231 | + "Currently KernelPreference.FBGEMM does not work for per tensor float8 quant" |
| 232 | + ) |
205 | 233 | else: |
| 234 | + assert kernel_choice == "torch", f"Expected torch, got {kernel_choice}" |
206 | 235 | scale = _choose_scale_float8( |
207 | 236 | hp_tensor, |
208 | 237 | float8_dtype=float8_dtype, |
|
0 commit comments