@@ -393,7 +393,7 @@ def __init__(
393393 raise ValueError ("You passed model_kwargs to the BCOTrainer. But your model is already instantiated." )
394394 else :
395395 model_init_kwargs = args .model_init_kwargs
396- dtype = model_init_kwargs .get ("dtype" )
396+ dtype = model_init_kwargs .get ("dtype" , "auto" )
397397 if dtype is not None :
398398 # Convert to `torch.dtype` if an str is passed
399399 if isinstance (dtype , str ) and dtype != "auto" :
@@ -403,6 +403,7 @@ def __init__(
403403 f"Invalid `dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got { dtype } ."
404404 )
405405 model_init_kwargs ["dtype" ] = dtype
406+ model_init_kwargs ["device_map" ] = model_init_kwargs .get ("device_map" , "auto" )
406407
407408 if args .ref_model_init_kwargs is None :
408409 ref_model_init_kwargs = {}
@@ -412,7 +413,7 @@ def __init__(
412413 )
413414 else :
414415 ref_model_init_kwargs = args .ref_model_init_kwargs
415- dtype = ref_model_init_kwargs .get ("dtype" )
416+ dtype = ref_model_init_kwargs .get ("dtype" , "auto" )
416417 if dtype is not None :
417418 # Convert to `torch.dtype` if an str is passed
418419 if isinstance (dtype , str ) and dtype != "auto" :
@@ -422,6 +423,7 @@ def __init__(
422423 f"Invalid `dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got { dtype } ."
423424 )
424425 ref_model_init_kwargs ["dtype" ] = dtype
426+ ref_model_init_kwargs ["device_map" ] = ref_model_init_kwargs .get ("device_map" , "auto" )
425427
426428 if isinstance (model , str ):
427429 model = AutoModelForCausalLM .from_pretrained (model , ** model_init_kwargs )
0 commit comments