
[BUG] Serialization error in saving full-finetuned model - DPO #125

@anay-rfai

Bug Description

Full fine-tuning training fails after the first chunk with a serialization error when the model is saved to shared memory. The pickler cannot serialize the local function make_inputs_require_grads defined inside PreTrainedModel.enable_input_require_grads, which that method registers as a forward hook on the input embeddings.
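
For reference, the failure reproduces outside RapidFire AI: enable_input_require_grads registers a forward hook whose callback is a local closure, and Python's pickler cannot serialize local functions. A minimal standalone sketch (the tiny checkpoint is just for illustration, not part of my setup):

import pickle
from transformers import AutoModelForCausalLM

# Any causal LM shows the problem; a tiny checkpoint keeps the sketch cheap.
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

# Registers a forward hook on the input embeddings whose callback is a local
# function defined inside PreTrainedModel.enable_input_require_grads.
model.enable_input_require_grads()

try:
    pickle.dumps(model)  # same failure multiprocessing's ForkingPickler hits
except AttributeError as err:
    print(err)  # Can't get local object 'PreTrainedModel.enable_input_require_grads.<locals>.make_inputs_require_grads'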

To Reproduce

Steps to reproduce the behavior:

  1. Configure a full fine-tuning setup with peft_config=None, as shown in the snippet below.
  2. Use the DPO training configuration with the settings provided below.
  3. Execute the training run.
  4. Training fails after the first chunk completes, during model serialization.

import torch  # needed for torch.bfloat16 in model_kwargs below
from peft import TaskType

MODEL_NAME_OR_PATH_2 = "mistralai/Mistral-7B-Instruct-v0.3"

fft_dpo_config = RFDPOConfig(
    model_adapter_name="default",
    ref_adapter_name="reference",
    force_use_ref_model=False, 
    loss_type=["sigmoid", "bco_pair", "sft"], 
    loss_weights=[0.8, 0.2, 1.0],
    beta=0.3, 
    max_prompt_length=1024,
    max_completion_length=1024,
    max_length=2048,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=5e-6,
    warmup_ratio=0.1,
    weight_decay=0,
    lr_scheduler_type="linear",
    num_train_epochs=1, 
    optim="adamw_8bit",
    bf16=True,
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=1
)

fft_model_config = RFModelConfig(
    model_name=MODEL_NAME_OR_PATH_2,
    ref_model_name=None,
    peft_config=None,  # Full fine-tuning
    training_args=fft_dpo_config,
    model_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16, "attn_implementation": "flash_attention_2"},
    tokenizer_kwargs={"model_max_length": 2048, "padding_side": "right", "truncation": True}
)

Expected Behavior

Full fine-tuning training should complete successfully across all chunks without serialization errors when saving models to shared memory.

Environment

  • OS: Ubuntu
  • Python version: 3.12
  • RapidFire AI version: 0.12.6
  • Browser (if applicable): Chrome

Additional Context

  • Issue occurs specifically with full fine-tuning (peft_config=None)
  • Error happens during model serialization to shared memory between chunks
  • Uses Mistral-7B-Instruct model with Flash Attention 2
  • Configuration uses DPO training with multiple loss types

Error Logs

Started 1 worker processes successfully
Created workers
Run 1 has failed: Can't get local object 'PreTrainedModel.enable_input_require_grads.<locals>.make_inputs_require_grads'
Traceback (most recent call last):
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/site-packages/rapidfireai/fit/backend/worker.py", line 273, in serve_forever
    self.run_fit(run_id, chunk_id, create_model_fn)
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/site-packages/rapidfireai/fit/backend/worker.py", line 206, in run_fit
    save_model_to_shared_memory(
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/site-packages/rapidfireai/fit/ml/checkpoint_utils.py", line 270, in save_model_to_shared_memory
    shm_manager.save_model_object(model_id, model_type, model_data)
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/site-packages/rapidfireai/fit/utils/shm_manager.py", line 390, in save_model_object
    self._save_full_model(model_id, model_object, model_object_type)
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/site-packages/rapidfireai/fit/utils/shm_manager.py", line 276, in _save_full_model
    self._registry[model_id] = model_entry
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/multiprocessing/managers.py", line 827, in _callmethod
    conn.send((self._id, methodname, args, kwds))
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/multiprocessing/connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "/home/palebluedot/miniconda3/envs/bench/lib/python3.12/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
AttributeError: Can't get local object 'PreTrainedModel.enable_input_require_grads.<locals>.make_inputs_require_grads'
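
A possible workaround sketch (my own assumption, not an official RapidFire AI fix): detach the unpicklable hook before the model object is pickled into the shared-memory registry and restore it afterwards. disable_input_require_grads() is the PreTrainedModel method that removes the hook installed by enable_input_require_grads():

import pickle

def pickle_with_hook_detached(model):
    # Hypothetical helper, named here only for illustration.
    # Assumes enable_input_require_grads() was called earlier (e.g. for
    # gradient checkpointing); otherwise skip the disable/enable calls.
    model.disable_input_require_grads()   # removes the local-closure forward hook
    payload = pickle.dumps(model)         # now serializable
    model.enable_input_require_grads()    # restore the hook for training
    return payload

Registering a module-level function as the hook callback instead of the local closure would also make the model picklable.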
