Skip to content

Commit

Permalink
🐛 fix DPO merging issues see huggingface/trl#742
Browse files Browse the repository at this point in the history
  • Loading branch information
simonmeoni committed Nov 4, 2024
1 parent 104d095 commit 129ad98
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
23 changes: 19 additions & 4 deletions lib/style-transfer/style_transfer/rb_gen/steps/dpo.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import hydra
import numpy as np
import pandas as pd
import peft
import wandb
from datasets import Dataset
from peft import AutoPeftModelForCausalLM
from omegaconf import ListConfig
from style_transfer.rb_gen.utils.utils import CustomWandbCallback
from transformers import PreTrainedTokenizerBase
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
from trl import DPOTrainer


Expand Down Expand Up @@ -62,15 +63,29 @@ def dpo_train(
cfg.dpo.training_args.output_dir = f"models/{wandb.run.id}/dpo/{step}"
args = hydra.utils.instantiate(cfg.dpo.training_args)
args.padding_value = tokenizer.eos_token_id
model = AutoPeftModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_path)
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_on_path=f"models/{wandb.run.id}/merged/"
)
model.enable_input_require_grads()
peft_config = hydra.utils.instantiate(cfg.model.peft_config)
peft_config.target_modules = (
list(peft_config.target_modules)
if isinstance(peft_config.target_modules, ListConfig)
else peft_config.target_modules
)
model = peft.get_peft_model(
model,
peft_config,
)
model.add_adapter(peft_config=peft_config, adapter_name="reference")
dpo_trainer = DPOTrainer(
args=args,
ref_model=None,
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
callbacks=[CustomWandbCallback],
model_adapter_name="default",
ref_adapter_name="reference",
)
dpo_trainer.train()

Expand Down
2 changes: 0 additions & 2 deletions lib/style-transfer/style_transfer/rb_gen/steps/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
import logging
import os
import shutil
import sqlite3
from typing import Callable

Expand Down Expand Up @@ -78,7 +77,6 @@ def generate(
del llm
gc.collect()
torch.cuda.empty_cache()
shutil.rmtree(f"models/{wandb.run.id}/merged/")
return gen_pred_dataset


Expand Down
5 changes: 4 additions & 1 deletion lib/style-transfer/style_transfer/run_rb_gen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
import shutil

import hydra
import wandb
Expand All @@ -14,7 +15,7 @@
logger = logging.getLogger(__name__)

os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
os.environ["WANDB_LOG_MODEL"] = "none"
os.environ["WANDB_START_METHOD"] = "thread"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
tqdm.pandas()
Expand Down Expand Up @@ -65,6 +66,7 @@ def main(cfg: DictConfig):
sft_train(cfg, sft_dataset, test_dataset, current_model_path)
logger.info("Bootstrapping done, Iterative Reward-based Generation Training begins...")
for step in range(cfg.max_steps):
logger.info(f"🔄 Step {step} ...")
sth_dataset = generate(
cfg,
step,
Expand Down Expand Up @@ -97,6 +99,7 @@ def main(cfg: DictConfig):
sth_dataset,
checkpoint=eval_model_path,
)
shutil.rmtree(f"models/{wandb.run.id}/merged/")
wandb.finish()


Expand Down

0 comments on commit 129ad98

Please sign in to comment.