 from torchao.core.config import AOBaseConfig
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP4,
+    DTYPE_FP6_E2M3,
+    DTYPE_FP6_E3M2,
     DTYPE_TO_SHORT_STR,
     SUPPORTED_ELEM_DTYPES,
 )
@@ -41,6 +43,31 @@ class MXLinearRecipeName(Enum):
     MXFP4_CUTLASS = "mxfp4_cutlass"
 
 
+def _validate_elem_dtype(elem_dtype):
+    assert (
+        elem_dtype in SUPPORTED_ELEM_DTYPES
+    ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {elem_dtype}"
+
+
+def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
+    if gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
+        assert (
+            block_size == 32
+        ), f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {block_size}"
+        valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4]
+        assert (
+            elem_dtype in valid_dtypes
+        ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
+    elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
+        assert (
+            block_size == 32
+        ), f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}"
+        valid_dtypes = [torch.float8_e4m3fn]
+        assert (
+            elem_dtype in valid_dtypes
| 68 | + ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}" |
+
+
 @dataclass
 class MXLinearConfig(AOBaseConfig):
     # block size for scaling, default is 32 to match
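
For context, here is a minimal usage sketch of the new module-level validators. This is illustrative only and not part of the diff; the import paths are assumptions based on this file living at `torchao/prototype/mx_formats/config.py`.

```python
# Illustrative sketch of the new validators; import paths are assumed.
import torch

from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    _validate_elem_dtype,
    _validate_gemm_kernel_choice,
)
from torchao.prototype.mx_formats.constants import DTYPE_FP4

# Passes: fp8 e4m3 is one of SUPPORTED_ELEM_DTYPES.
_validate_elem_dtype(torch.float8_e4m3fn)

# Passes: the CUTLASS path accepts fp8 e4m3 and fp4 with block_size=32.
_validate_gemm_kernel_choice(MXGemmKernelChoice.CUTLASS, 32, DTYPE_FP4)

# Raises AssertionError: the cuBLAS path only accepts torch.float8_e4m3fn.
try:
    _validate_gemm_kernel_choice(MXGemmKernelChoice.CUBLAS, 32, DTYPE_FP4)
except AssertionError as e:
    print(e)
```

Factoring the checks out of `MXLinearConfig.__post_init__` lets the new inference config below reuse them without duplicating the assertions.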
@@ -68,53 +95,17 @@ class MXLinearConfig(AOBaseConfig):
     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False
 
-    # If True, packs 4xFP6 into 3xuint8 containers for inference, using custom triton
-    # kernels (fused unpack/dequantize). Training not currently supported.
-    pack_fp6 = True if hasattr(torch.library, "custom_op") else False
-
     def __post_init__(self):
-        # validate elem_dtype and its overrides
-        assert (
-            self.elem_dtype in SUPPORTED_ELEM_DTYPES
-        ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
+        _validate_elem_dtype(self.elem_dtype)
+        _validate_gemm_kernel_choice(
+            self.gemm_kernel_choice, self.block_size, self.elem_dtype
+        )
         if self.elem_dtype_weight_override is not None:
-            assert (
-                self.elem_dtype_weight_override in SUPPORTED_ELEM_DTYPES
-            ), f"elem_dtype_weight_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
+            _validate_elem_dtype(self.elem_dtype_weight_override)
+            assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported"
         if self.elem_dtype_grad_output_override is not None:
-            assert (
-                self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES
-            ), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
-
-        # validate that block size and elem_dtype matches kernel choice
-        if self.gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
-            assert (
-                self.block_size == 32
-            ), f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {self.block_size}"
-            valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4]
-            assert (
-                self.elem_dtype in valid_dtypes
-            ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {self.elem_dtype}"
-            assert (
-                self.elem_dtype_weight_override is None
-            ), "elem_dtype_weight_override not supported for CUTLASS MX gemm kernels"
-            assert (
-                self.elem_dtype_grad_output_override is None
-            ), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels"
-        elif self.gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
-            assert (
-                self.block_size == 32
-            ), f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {self.block_size}"
-            valid_dtypes = [torch.float8_e4m3fn]
-            assert (
-                self.elem_dtype in valid_dtypes
-            ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {self.elem_dtype}"
-            assert (
-                self.elem_dtype_weight_override is None
-            ), "elem_dtype_weight_override not supported for CUTLASS MX gemm kernels"
-            assert (
-                self.elem_dtype_grad_output_override is None
-            ), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels"
+            _validate_elem_dtype(self.elem_dtype_grad_output_override)
+            assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported"
 
     @staticmethod
     def from_recipe_name(
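
A quick sketch of the behavior change in `__post_init__` (illustrative, not part of this diff; import paths assumed): the dtype overrides are still validated against `SUPPORTED_ELEM_DTYPES`, and now additionally require the emulated gemm kernel.

```python
# Illustrative sketch; import paths are assumed.
import torch

from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
from torchao.prototype.mx_formats.constants import DTYPE_FP4

# OK: a weight-only dtype override together with the emulated gemm kernel.
cfg = MXLinearConfig(
    elem_dtype=torch.float8_e4m3fn,
    elem_dtype_weight_override=DTYPE_FP4,
    gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
)

# AssertionError("unsupported"): overrides are rejected for the real gemm kernels.
try:
    MXLinearConfig(
        elem_dtype=torch.float8_e4m3fn,
        elem_dtype_weight_override=DTYPE_FP4,
        gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
    )
except AssertionError as e:
    print(e)
```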
@@ -162,5 +153,47 @@ def short_str(self) -> str:
             s += ", use_fp8_dim1_cast_triton_kernel=True"
         if self.use_fp4_custom_triton_dequant_kernel:
             s += ", use_fp4_custom_triton_dequant_kernel=True"
-        # TODO(future PR): split training from inference and add fp6 here
         return s
+
+
+@dataclass
+class MXInferenceLinearConfig(AOBaseConfig):
+    # block size for scaling, default is 32 to match
+    # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
+    # section 5.2
+    block_size: int = 32
+
+    # element dtype, used for activations and weights
+    elem_dtype: Any = torch.float8_e4m3fn
+    # TODO(future PR): support different elem_dtype for activations vs weights
+
+    # defines the gemm kernel choice; if the chosen kernel is not supported
+    # on the given hardware, an exception will be thrown
+    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
+
+    # If True, uses a custom triton kernel for fp4 dequantize
+    use_fp4_custom_triton_dequant_kernel: bool = False
+
+    # If True, packs 4xFP6 into 3xuint8 containers for inference, using custom triton
+    # kernels (fused unpack/dequantize).
+    pack_fp6: bool = True
+
+    def __post_init__(self):
+        _validate_elem_dtype(self.elem_dtype)
+        _validate_gemm_kernel_choice(
+            self.gemm_kernel_choice, self.block_size, self.elem_dtype
+        )
+
+    def short_str(self) -> str:
+        """
+        Returns a concise representation of the current config.
+        """
+        s = f"bl_sz={self.block_size}, lp_dtype={DTYPE_TO_SHORT_STR[self.elem_dtype]}"
+        s += f", kernel={self.gemm_kernel_choice.value}"
+        if self.use_fp4_custom_triton_dequant_kernel:
+            s += ", use_fp4_custom_triton_dequant_kernel=True"
+        if self.elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2) and self.pack_fp6:
+            s += ", pack_fp6=True"
+        return s
+
+    # TODO(future PR): add a recipe to config API for inference
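
And a usage sketch for the new inference-only config (illustrative, not part of this diff; the import paths and the exact `short_str()` output are assumptions):

```python
# Illustrative sketch; import paths are assumed.
from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    MXInferenceLinearConfig,
)
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3

# fp6 inference config; pack_fp6 defaults to True, so 4x fp6 values are packed
# into 3x uint8 containers and unpacked/dequantized by the custom triton kernels.
cfg = MXInferenceLinearConfig(
    elem_dtype=DTYPE_FP6_E2M3,
    gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
)
print(cfg.short_str())
# prints something like: bl_sz=32, lp_dtype=fp6_e2m3, kernel=emulated, pack_fp6=True
```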