@@ -73,8 +73,9 @@ def __init__(self, model, config):
         if hasattr(self.module, "config"):
             TransformerPolicy.hf_model_config = self.module.config
 
-        if config.dtype == torch.half and not get_accelerator().is_fp16_supported():
-            raise ValueError("Type fp16 is not supported.")
+        if config.dtype not in get_accelerator().supported_dtypes():
+            raise ValueError(
+                f"Data type {config.dtype} is not supported by {get_accelerator().device_name()} accelerator")
 
         # todo: keep this self.injection_dict because we don't use to change config.injection_policy API
         # todo: this will get changed when Molly's PR on auto injection dict is merged
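
This change delegates dtype validation to the accelerator abstraction instead of hard-coding an fp16 check. A minimal sketch of the resulting behavior, using the get_accelerator() interface with the supported_dtypes() and device_name() methods the diff relies on (the example dtype list is an assumption, not taken from this diff):

import torch
from deepspeed.accelerator import get_accelerator

requested = torch.bfloat16  # dtype taken from the inference config
# Each accelerator reports the dtypes it can execute; a CUDA accelerator
# might return, e.g., [torch.float, torch.half, torch.bfloat16, torch.int8].
if requested not in get_accelerator().supported_dtypes():
    raise ValueError(
        f"Data type {requested} is not supported by {get_accelerator().device_name()} accelerator")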
@@ -324,7 +325,7 @@ def _validate_args(self, mpu, replace_with_kernel_inject):
         if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
             raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")
 
-        supported_dtypes = [None, torch.half, torch.int8, torch.float]
+        supported_dtypes = [None, torch.half, torch.int8, torch.float, torch.bfloat16]
         if self._config.dtype not in supported_dtypes:
             raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
 
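
With torch.bfloat16 added to the whitelist, _validate_args no longer rejects bf16 inference configs. A usage sketch, assuming DeepSpeed is installed on a bf16-capable accelerator and using a stand-in module in place of a real model:

import torch
import torch.nn as nn
import deepspeed

model = nn.Linear(8, 8)  # placeholder; any torch.nn.Module works here
# dtype=torch.bfloat16 now passes both the accelerator-level check and
# the supported_dtypes whitelist validated above:
engine = deepspeed.init_inference(model, dtype=torch.bfloat16)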