Skip to content

Commit 799b39b

Browse files
device_map and dtype to "auto" by default (#4509)
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
1 parent a6a2beb commit 799b39b

File tree

13 files changed

+34
-18
lines changed

13 files changed

+34
-18
lines changed

tests/experimental/test_trainers_args.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ def test_bco(self):
6464
# self.assertEqual(trainer.args.generate_during_eval, True)
6565
assert trainer.args.is_encoder_decoder
6666
assert trainer.args.precompute_ref_log_probs
67-
assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
68-
assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True}
67+
assert trainer.args.model_init_kwargs == {"trust_remote_code": True, "device_map": "auto", "dtype": "auto"}
68+
assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True, "device_map": "auto", "dtype": "auto"}
6969
assert trainer.args.dataset_num_proc == 4
7070
assert trainer.args.prompt_sample_size == 512
7171
assert trainer.args.min_density_ratio == 0.2

tests/test_sft_trainer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545

4646
if is_peft_available():
47+
import peft
4748
from peft import (
4849
LoraConfig,
4950
PeftModel,
@@ -537,6 +538,11 @@ def test_train_with_peft_config_prompt_tuning(self, peft_type):
537538
tokenizer_name_or_path="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
538539
)
539540
elif peft_type == "prefix_tuning":
541+
if parse_version(peft.__version__) <= Version("0.17.1"):
542+
pytest.xfail(
543+
"Prefix tuning with device_map='auto' is broken in peft 0.17.1 and below. See "
544+
"https://github.com/huggingface/peft/issues/2821"
545+
)
540546
peft_config = PrefixTuningConfig(
541547
task_type=TaskType.CAUSAL_LM,
542548
num_virtual_tokens=4,

tests/test_trainers_args.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_cpo(self):
7878
assert trainer.args.truncation_mode == "keep_start"
7979
# self.assertEqual(trainer.args.generate_during_eval, True)
8080
assert trainer.args.is_encoder_decoder
81-
assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
81+
assert trainer.args.model_init_kwargs == {"trust_remote_code": True, "device_map": "auto", "dtype": "auto"}
8282
assert trainer.args.dataset_num_proc == 4
8383

8484
def test_dpo(self):
@@ -189,8 +189,8 @@ def test_kto(self):
189189
# self.assertEqual(trainer.args.generate_during_eval, True)
190190
assert trainer.args.is_encoder_decoder
191191
assert trainer.args.precompute_ref_log_probs
192-
assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
193-
assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True}
192+
assert trainer.args.model_init_kwargs == {"trust_remote_code": True, "device_map": "auto", "dtype": "auto"}
193+
assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True, "device_map": "auto", "dtype": "auto"}
194194
assert trainer.args.dataset_num_proc == 4
195195

196196
@pytest.mark.parametrize("mixtures_coef_list", [False, True])

trl/experimental/bco/bco_trainer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def __init__(
393393
raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
394394
else:
395395
model_init_kwargs = args.model_init_kwargs
396-
dtype = model_init_kwargs.get("dtype")
396+
dtype = model_init_kwargs.get("dtype", "auto")
397397
if dtype is not None:
398398
# Convert to `torch.dtype` if an str is passed
399399
if isinstance(dtype, str) and dtype != "auto":
@@ -403,6 +403,7 @@ def __init__(
403403
f"Invalid `dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
404404
)
405405
model_init_kwargs["dtype"] = dtype
406+
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
406407

407408
if args.ref_model_init_kwargs is None:
408409
ref_model_init_kwargs = {}
@@ -412,7 +413,7 @@ def __init__(
412413
)
413414
else:
414415
ref_model_init_kwargs = args.ref_model_init_kwargs
415-
dtype = ref_model_init_kwargs.get("dtype")
416+
dtype = ref_model_init_kwargs.get("dtype", "auto")
416417
if dtype is not None:
417418
# Convert to `torch.dtype` if an str is passed
418419
if isinstance(dtype, str) and dtype != "auto":
@@ -422,6 +423,7 @@ def __init__(
422423
f"Invalid `dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
423424
)
424425
ref_model_init_kwargs["dtype"] = dtype
426+
ref_model_init_kwargs["device_map"] = ref_model_init_kwargs.get("device_map", "auto")
425427

426428
if isinstance(model, str):
427429
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

trl/trainer/cpo_trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def __init__(
160160
raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
161161
else:
162162
model_init_kwargs = args.model_init_kwargs
163-
dtype = model_init_kwargs.get("dtype")
163+
dtype = model_init_kwargs.get("dtype", "auto")
164164
if dtype is not None:
165165
# Convert to `torch.dtype` if an str is passed
166166
if isinstance(dtype, str) and dtype != "auto":
@@ -170,6 +170,7 @@ def __init__(
170170
f"Invalid `dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
171171
)
172172
model_init_kwargs["dtype"] = dtype
173+
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
173174

174175
if isinstance(model, str):
175176
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

trl/trainer/gkd_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def __init__(
187187
if teacher_model_init_kwargs["dtype"] in ["auto", None]
188188
else getattr(torch, teacher_model_init_kwargs["dtype"])
189189
)
190+
teacher_model_init_kwargs["device_map"] = teacher_model_init_kwargs.get("device_map", "auto")
190191

191192
if isinstance(teacher_model, str):
192193
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)

trl/trainer/grpo_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def __init__(
261261
model_init_kwargs = args.model_init_kwargs or {}
262262
if isinstance(model, str):
263263
model_id = model
264-
dtype = model_init_kwargs.get("dtype")
264+
dtype = model_init_kwargs.get("dtype", "auto")
265265
if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
266266
pass # dtype is already a torch.dtype or "auto" or None
267267
elif isinstance(dtype, str): # it's a str, but not "auto"
@@ -272,7 +272,7 @@ def __init__(
272272
"Invalid `dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
273273
f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
274274
)
275-
# Disable caching if gradient checkpointing is enabled (not supported)
275+
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
276276
config = AutoConfig.from_pretrained(model_id)
277277
architecture = getattr(transformers, config.architectures[0])
278278
model = architecture.from_pretrained(model_id, **model_init_kwargs)

trl/trainer/kto_trainer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def __init__(
382382
raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
383383
else:
384384
model_init_kwargs = args.model_init_kwargs
385-
dtype = model_init_kwargs.get("dtype")
385+
dtype = model_init_kwargs.get("dtype", "auto")
386386
if dtype is not None:
387387
# Convert to `torch.dtype` if an str is passed
388388
if isinstance(dtype, str) and dtype != "auto":
@@ -392,6 +392,7 @@ def __init__(
392392
f"Invalid `dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
393393
)
394394
model_init_kwargs["dtype"] = dtype
395+
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
395396

396397
if args.ref_model_init_kwargs is None:
397398
ref_model_init_kwargs = {}
@@ -401,7 +402,7 @@ def __init__(
401402
)
402403
else:
403404
ref_model_init_kwargs = args.ref_model_init_kwargs
404-
dtype = ref_model_init_kwargs.get("dtype")
405+
dtype = ref_model_init_kwargs.get("dtype", "auto")
405406
if dtype is not None:
406407
# Convert to `torch.dtype` if an str is passed
407408
if isinstance(dtype, str) and dtype != "auto":
@@ -411,6 +412,7 @@ def __init__(
411412
f"Invalid `dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
412413
)
413414
ref_model_init_kwargs["dtype"] = dtype
415+
ref_model_init_kwargs["device_map"] = ref_model_init_kwargs.get("device_map", "auto")
414416

415417
if isinstance(model, str):
416418
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

trl/trainer/online_dpo_trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def __init__(
304304
model_id = model
305305

306306
# Handle dtype in model_init_kwargs
307-
dtype = model_init_kwargs.get("dtype")
307+
dtype = model_init_kwargs.get("dtype", "auto")
308308
if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
309309
pass
310310
elif isinstance(dtype, str):
@@ -315,6 +315,7 @@ def __init__(
315315
"Invalid `dtype` passed to `OnlineDPOConfig`. Expected either 'auto' or a string "
316316
f"representing a `torch.dtype` (e.g., 'float32'), but got {dtype}."
317317
)
318+
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
318319

319320
model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
320321
else:

trl/trainer/orpo_trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def __init__(
162162
raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
163163
else:
164164
model_init_kwargs = args.model_init_kwargs
165-
dtype = model_init_kwargs.get("dtype")
165+
dtype = model_init_kwargs.get("dtype", "auto")
166166
if dtype is not None:
167167
# Convert to `torch.dtype` if an str is passed
168168
if isinstance(dtype, str) and dtype != "auto":
@@ -172,6 +172,7 @@ def __init__(
172172
f"Invalid `dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
173173
)
174174
model_init_kwargs["dtype"] = dtype
175+
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
175176

176177
if isinstance(model, str):
177178
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

0 commit comments

Comments
 (0)