Bug Description
Full fine-tuning runs fail after the first chunk with a serialization error when the model is saved to shared memory. The failure happens because the pickler cannot serialize the local function that PreTrainedModel.enable_input_require_grads registers as a forward hook; a minimal standalone sketch follows.
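The root cause can be shown outside RapidFire AI in a few lines. The sketch below is plain PyTorch, not RapidFire AI or transformers code; the helper mimic_enable_input_require_grads is ours and only mirrors the shape of the real method. It registers a local function as a forward hook and then pickles the module, which fails the same way multiprocessing's ForkingPickler does in the traceback below.

# Minimal standalone sketch (plain PyTorch, not RapidFire AI code) of the failure:
# a forward hook defined as a *local* function cannot be pickled, and transformers'
# PreTrainedModel.enable_input_require_grads registers exactly such a hook.
import pickle

import torch.nn as nn


def mimic_enable_input_require_grads(module: nn.Module) -> None:
    # Mirrors the shape of the transformers implementation: the hook is a function
    # defined inside the method, i.e. a local object the pickler cannot look up.
    def make_inputs_require_grads(mod, inputs, output):
        output.requires_grad_(True)

    module.register_forward_hook(make_inputs_require_grads)


model = nn.Linear(4, 4)
mimic_enable_input_require_grads(model)

try:
    pickle.dumps(model)  # what multiprocessing's ForkingPickler effectively does
except (AttributeError, pickle.PicklingError) as exc:
    # On Python 3.12 this prints "Can't get local object '...make_inputs_require_grads'",
    # matching the traceback below; older Python versions word it "Can't pickle local object".
    print(exc)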
To Reproduce
Steps to reproduce the behavior:
- Configure a full fine-tuning setup with peft_config=None as shown below.
- Use the DPO training configuration with the provided settings.
- Execute the training run.
- Training fails after the first chunk completes, during model serialization.
import torch
from peft import TaskType  # unused in this repro since peft_config=None; kept from the original script
# RFDPOConfig and RFModelConfig come from the rapidfireai package; import them as in your environment.

MODEL_NAME_OR_PATH_2 = "mistralai/Mistral-7B-Instruct-v0.3"

fft_dpo_config = RFDPOConfig(
    model_adapter_name="default",
    ref_adapter_name="reference",
    force_use_ref_model=False,
    loss_type=["sigmoid", "bco_pair", "sft"],
    loss_weights=[0.8, 0.2, 1.0],
    beta=0.3,
    max_prompt_length=1024,
    max_completion_length=1024,
    max_length=2048,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=5e-6,
    warmup_ratio=0.1,
    weight_decay=0,
    lr_scheduler_type="linear",
    num_train_epochs=1,
    optim="adamw_8bit",
    bf16=True,
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=1,
)

model_config = RFModelConfig(
    model_name=MODEL_NAME_OR_PATH_2,
    ref_model_name=None,
    peft_config=None,  # Full fine-tuning
    training_args=fft_dpo_config,
    model_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16, "attn_implementation": "flash_attention_2"},
    tokenizer_kwargs={"model_max_length": 2048, "padding_side": "right", "truncation": True},
)
Expected Behavior
Full fine-tuning training should complete successfully across all chunks without serialization errors when saving models to shared memory.
Environment
- OS: Ubuntu
- Python version: 3.12
- RapidFire AI version: 0.12.6
- Browser (if applicable): Chrome
Additional Context
- Issue occurs specifically with full fine-tuning (peft_config=None)
- Error happens during model serialization to shared memory between chunks; a possible mitigation is sketched after this list
- Uses Mistral-7B-Instruct model with Flash Attention 2
- Configuration uses DPO training with multiple loss types
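Not RapidFire AI's API, but a sketch of one possible direction for a fix, using only transformers' own enable_input_require_grads / disable_input_require_grads and a plain pickle round-trip; the helper name pickle_model_safely is hypothetical.

# Hypothetical helper (not part of RapidFire AI): temporarily remove the
# unpicklable local-function hook before pickling the model, then restore it.
import pickle

from transformers import PreTrainedModel


def pickle_model_safely(model: PreTrainedModel) -> bytes:
    # enable_input_require_grads() stores its hook handle on _require_grads_hook;
    # if it was never called, there is nothing to remove.
    had_hook = getattr(model, "_require_grads_hook", None) is not None
    if had_hook:
        model.disable_input_require_grads()  # removes the local-function forward hook
    try:
        return pickle.dumps(model)
    finally:
        if had_hook:
            model.enable_input_require_grads()  # restore gradient-checkpointing behaviour

# Alternatively, pickling only model.state_dict() (plain tensors) sidesteps the
# problem entirely, since no hook objects are serialized.

Either way the local closure never reaches multiprocessing's pickler, which is what fails in the logs below.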
Error Logs
Started 1 worker processes successfully
Created workers
Run 1 has failed: Can't get local object 'PreTrainedModel.enable_input_require_grads.<locals>.make_inputs_require_grads'
Traceback (most recent call last):
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/site-packages/rapidfireai/fit/backend/worker.py", line 273, in serve_forever
    self.run_fit(run_id, chunk_id, create_model_fn)
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/site-packages/rapidfireai/fit/backend/worker.py", line 206, in run_fit
    save_model_to_shared_memory(
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/site-packages/rapidfireai/fit/ml/checkpoint_utils.py", line 270, in save_model_to_shared_memory
    shm_manager.save_model_object(model_id, model_type, model_data)
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/site-packages/rapidfireai/fit/utils/shm_manager.py", line 390, in save_model_object
    self._save_full_model(model_id, model_object, model_object_type)
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/site-packages/rapidfireai/fit/utils/shm_manager.py", line 276, in _save_full_model
    self._registry[model_id] = model_entry
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/multiprocessing/managers.py", line 827, in _callmethod
    conn.send((self._id, methodname, args, kwds))
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/multiprocessing/connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
AttributeError: Can't get local object 'PreTrainedModel.enable_input_require_grads.<locals>.make_inputs_require_grads'