@@ -40,10 +40,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
4040 "torchao is not installed. Please install it to use MXFP8 linear layers."
4141 )
4242 torchao_version = version ("torchao" )
43- mxfp8_min_version = "0.11.0"
44- if torchao_version < mxfp8_min_version :
43+
44+ # Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git...
45+ is_nightly_build = torchao_version .startswith ("0.13.0" )
46+ if not is_nightly_build :
4547 raise ImportError (
46- f"torchao version { torchao_version } is too old, please install torchao { mxfp8_min_version } or later and try again"
48+ f"torchao version { torchao_version } is too old, please install torchao nightly build and try again"
4749 )
4850
4951 # Can be removed if we enable the emulated versions
@@ -56,12 +58,17 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.filter_fqns = mx_job_config.filter_fqns
 
         # Configure MXFP8
-        from torchao.prototype.mx_formats.config import MXLinearConfig
+        from torchao.prototype.mx_formats.config import (
+            MXFP8Dim1CastKernelChoice,
+            MXLinearConfig,
+        )
 
         config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name])
-        config.use_fp8_dim1_cast_triton_kernel = (
-            mx_job_config.use_fp8_dim1_cast_triton_kernel
-        )
+
+        # Convert the string config value to the corresponding enum member
+        config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[
+            mx_job_config.mxfp8_dim1_cast_kernel_choice.upper()
+        ]
         self.config = config
 
         logger.info(f"Float8 training active with recipe {mx_job_config.recipe_name}")
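
For reference, the new assignment relies on Python's Enum name lookup: indexing an Enum class with a string returns the member whose name matches. A minimal self-contained sketch of the same pattern follows; the member names below are stand-ins for illustration, since the real MXFP8Dim1CastKernelChoice is defined in torchao.prototype.mx_formats.config.

from enum import Enum, auto

# Stand-in mirroring torchao's MXFP8Dim1CastKernelChoice; the member
# names here are assumptions, not torchao's actual definition.
class MXFP8Dim1CastKernelChoice(Enum):
    TRITON = auto()
    CUDA = auto()
    TORCH = auto()

# Same string-to-enum mapping the diff performs: uppercase the
# configured string, then index the enum by member name.
choice = MXFP8Dim1CastKernelChoice["triton".upper()]
assert choice is MXFP8Dim1CastKernelChoice.TRITON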