sm/better-evaluator #49
Merged · 2 commits · Nov 14, 2024
@@ -0,0 +1,15 @@
export CUDA_VISIBLE_DEVICES=$1
python style_transfer/run_rb_gen.py model.name=meta-llama/Llama-3.2-3B-Instruct \
model.peft_config.target_modules='["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]' \
dataset.name=bio-datasets/mimic-iii-gpt4o-tokens \
max_steps=7 \
dataset.num_generated_samples=1000 \
score.model.model_name_or_path=sentence-transformers/all-mpnet-base-v2 \
dataset.sft_ratio=0.06 \
dataset.gen_ratio=0.7 \
sft.training_args.eval_steps=30 \
score.train.train_size=0.6 \
dataset.sft_dataset.size=977 \
dpo.training_args.num_train_epochs=20 \
dpo.percentile=70 \
score.batch_size=8
2 changes: 1 addition & 1 deletion lib/style-transfer/configs/rb_gen/score/default.yaml
@@ -1,6 +1,6 @@
batch_size: 8
is_logged: false

method: vanilla
model:
_target_: sentence_transformers.SentenceTransformer
model_name_or_path: "sentence-transformers/all-mpnet-base-v2"
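The score.model entry above configures the SentenceTransformer that both scoring paths use to compare ground-truth and generated text. Below is a minimal sketch of that cosine-similarity scoring, assuming sentence-transformers is installed; the texts are illustrative stand-ins for what the pipeline reads from the generated dataset.

from sentence_transformers import SentenceTransformer, util

# Same model as score.model.model_name_or_path in the config above.
eval_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# Illustrative texts; the pipeline takes these from the generated dataset.
ground_texts = ["The patient was discharged in stable condition."]
generations = ["Patient discharged, condition stable."]

ground_enc = eval_model.encode(ground_texts, batch_size=8)
pred_enc = eval_model.encode(generations, batch_size=8)
scores = [util.cos_sim(g, p)[0][0].item() for g, p in zip(ground_enc, pred_enc)]
print(scores)  # one cosine-similarity score per (ground, generation) pair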
1 change: 1 addition & 0 deletions lib/style-transfer/style_transfer/rb_gen/steps/__init__.py
@@ -1,4 +1,5 @@
from style_transfer.rb_gen.steps.dpo import dpo_train
from style_transfer.rb_gen.steps.generate import generate
from style_transfer.rb_gen.steps.recalibrate_scoring import recalibrate_scoring
from style_transfer.rb_gen.steps.score import score
from style_transfer.rb_gen.steps.sft import sft_train
2 changes: 1 addition & 1 deletion lib/style-transfer/style_transfer/rb_gen/steps/dpo.py
@@ -64,7 +64,7 @@ def dpo_train(
args = hydra.utils.instantiate(cfg.dpo.training_args)
args.padding_value = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_on_path=f"models/{wandb.run.id}/merged/"
pretrained_model_name_or_path=f"models/{wandb.run.id}/merged/"
)
model.enable_input_require_grads()
peft_config = hydra.utils.instantiate(cfg.model.peft_config)
6 changes: 5 additions & 1 deletion lib/style-transfer/style_transfer/rb_gen/steps/generate.py
@@ -61,7 +61,11 @@ def generate(
logging.info("🎉 And it's done!")

shuffled_gen_dataset = gen_dataset.shuffle(seed=cfg.seed + step)
subset_gen_dataset = shuffled_gen_dataset.select(range(cfg.dataset.num_generated_samples))
subset_gen_dataset = (
shuffled_gen_dataset.select(range(cfg.dataset.num_generated_samples))
if step != 0
else shuffled_gen_dataset
)
gen_dataloader = torch.utils.data.DataLoader(
subset_gen_dataset,
batch_size=cfg.gen.batch_size,
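The change above keeps the full shuffled generation set at step 0 and only subsamples dataset.num_generated_samples rows on later steps. A minimal sketch of the new selection rule with a toy dataset (the sizes are illustrative; the real values come from cfg.dataset):

from datasets import Dataset

# Toy dataset and sample count; cfg.dataset provides the real values.
gen_dataset = Dataset.from_dict({"text": [f"sample {i}" for i in range(10)]})
num_generated_samples = 4

for step in range(2):
    shuffled = gen_dataset.shuffle(seed=42 + step)
    subset = shuffled.select(range(num_generated_samples)) if step != 0 else shuffled
    print(step, len(subset))  # step 0 -> 10 rows, step 1 -> 4 rows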
185 changes: 185 additions & 0 deletions lib/style-transfer/style_transfer/rb_gen/steps/recalibrate_scoring.py
@@ -0,0 +1,185 @@
import logging
import os

import hydra
import pandas as pd
import torch
import wandb
from datasets import Dataset
from omegaconf import DictConfig
from sentence_transformers import InputExample, SentenceTransformer, util
from style_transfer.rb_gen.utils.utils import CustomWandbCallback

os.environ["WANDB_LOG_MODEL"] = "checkpoint"


def train_eval_model(
cfg: DictConfig, eval_model: SentenceTransformer, gen_dataset: Dataset
) -> Dataset:
"""Train the evaluator model.

Args:
cfg: The configuration for the training.
eval_model: The model to train.
gen_dataset: The dataset to use for training.
"""
logging.info("🎲 Training Semantic Model ...")
logging.info("🪚 Splitting Dataset for training...")
gen_dataset = gen_dataset.train_test_split(train_size=cfg.score.train.train_size)
train_gen_dataset = gen_dataset["train"]
gen_dataset = gen_dataset["test"]
train_examples = []

train_ground_dict = process_ground_predictions(train_gen_dataset, eval_model)
train_examples.extend(
create_input_examples(
train_ground_dict["text_1"],
train_ground_dict["text_2"],
train_ground_dict["ground_scores"],
)
)

train_score_dict = process_generation_scores(train_gen_dataset, eval_model)
for seq in range(4):
train_examples.extend(
create_input_examples(
train_gen_dataset[f"generation_{seq}"],
train_gen_dataset["ground_texts"],
train_score_dict[f"evaluator_scores_{seq}"],
offset=-0.5,
)
)

train_gen_dataloader = torch.utils.data.DataLoader(
train_examples,
batch_size=cfg.score.batch_size,
)

train_loss = hydra.utils.instantiate(cfg.score.train.loss, eval_model)()
eval_model.fit(
train_objectives=[(train_gen_dataloader, train_loss)],
epochs=cfg.score.train.epochs,
warmup_steps=cfg.score.train.warmup_steps,
callback=[CustomWandbCallback],
)
logging.info("🎉 Semantic Model Trained !")
return gen_dataset


def create_input_examples(texts1, texts2, scores, offset=0.5):
return [
InputExample(texts=[t1, t2], label=score + offset)
for t1, t2, score in zip(texts1, texts2, scores)
]


def process_ground_predictions(dataset, model):
split_point = len(dataset["ground_texts"]) // 2
split1 = dataset["ground_texts"][:split_point]
split2 = dataset["ground_texts"][split_point:]
scores = encode_and_score_texts(model, split1, split2)
return {"ground_scores": scores, "text_1": split1, "text_2": split2}


def encode_and_score_texts(model, texts1, texts2, batch_size=8):
enc1 = model.encode(texts1, batch_size=batch_size)
enc2 = model.encode(texts2, batch_size=batch_size)
return [util.cos_sim(e1, e2)[0][0].item() for e1, e2 in zip(enc1, enc2)]


def process_generation_scores(dataset, model):
ground_enc = model.encode(dataset["ground_texts"], batch_size=8)
score_dict = {}
for seq in range(4):
prediction_enc = model.encode(dataset[f"generation_{seq}"], batch_size=8)
scores = [
util.cos_sim(g_enc, p_enc)[0][0].item()
for g_enc, p_enc in zip(ground_enc, prediction_enc)
]
score_dict[f"evaluator_scores_{seq}"] = scores
return score_dict


def score_gen_dataset(cfg: DictConfig, dataset: dict, eval_model: SentenceTransformer) -> dict:
"""Score the dataset. Using the evaluator model and cosine similarity.
We score the dataset by calculating the cosine similarity between the ground truth a
nd the generated text.
We iterate over the number of generated sequences and calculate the cosine similarity
for each sequence.

Args:
cfg: The configuration for the scoring.
dataset: The dataset to score.
eval_model: The model to use for scoring.

Returns:
The scored dataset.
"""

logging.info("🔍 Scoring the dataset ...")
score_dict: dict = {}
ground_encoding = eval_model.encode(
dataset["ground_texts"],
batch_size=cfg.score.batch_size,
)
for seq in range(cfg.model.num_generated_sequences):
prediction_enc = eval_model.encode(
dataset[f"generation_{seq}"],
batch_size=cfg.score.batch_size,
)
scores = [
util.cos_sim(ground_enc, pred_enc)[0][0].item()
for ground_enc, pred_enc in zip(ground_encoding, prediction_enc)
]
score_dict.setdefault(f"evaluator_scores_{seq}", []).extend(scores)
logging.info("🎉 Dataset Scored !")
return score_dict


def recalibrate_scoring(
cfg, step: int, is_trainable: bool, dataset: Dataset, checkpoint: str
) -> Dataset:
"""Score the dataset and log the results.

Args:
cfg: The configuration for the scoring.
step: The current step.
is_trainable: Whether the model is trainable.
dataset: The dataset to score.
checkpoint: The checkpoint path to save the model.

Returns:
The scored dataset.
"""
wandb.config.update({"state": f"score/{step}"}, allow_val_change=True)
logging.info("🐈 Loading the Semantic Model ...")
if step == 0:
eval_model = hydra.utils.instantiate(cfg.score.model)
else:
eval_model = hydra.utils.instantiate(
cfg.score.model,
model_name_or_path=checkpoint,
)

gen_dataset = (
train_eval_model(cfg, eval_model, dataset) if is_trainable else dataset
).to_dict()
eval_model.save(checkpoint)
gen_dict_scores = score_gen_dataset(cfg, gen_dataset, eval_model)
gen_dataset.update(gen_dict_scores)
gen_df = pd.DataFrame.from_dict(gen_dataset)
generated_sequences = [
f"evaluator_scores_{seq}" for seq in range(cfg.model.num_generated_sequences)
]

gen_df["max_score"] = gen_df[generated_sequences].max(axis=1)
gen_df["min_score"] = gen_df[generated_sequences].min(axis=1)
gen_df["mean_score"] = gen_df[generated_sequences].mean(axis=1)

wandb.log({f"{wandb.config['state']}/dataset/score": wandb.Table(dataframe=gen_df)})
wandb.log({f"{wandb.config['state']}/max/mean": gen_df["max_score"].mean()})
wandb.log({f"{wandb.config['state']}/min/mean": gen_df["min_score"].mean()})
wandb.log({f"{wandb.config['state']}/mean": gen_df["mean_score"].mean()})

del eval_model
return Dataset.from_pandas(gen_df)
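In train_eval_model above, the evaluator is fine-tuned on deliberately offset similarity labels: ground-truth/ground-truth pairs keep their cosine score shifted up by +0.5, while generation/ground-truth pairs are shifted down by -0.5, which appears intended to teach the evaluator to separate human text from model generations. Below is a minimal, self-contained sketch of that fine-tuning; the loss is assumed to be CosineSimilarityLoss here, whereas the PR instantiates it from cfg.score.train.loss (not shown in this diff), and the texts are illustrative.

import torch
from sentence_transformers import InputExample, SentenceTransformer, losses, util

eval_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# Illustrative pairs; the pipeline builds these from the generated dataset.
ground_pairs = [("note A", "note B")]        # ground-truth vs. ground-truth
gen_pairs = [("generated note", "note A")]   # generation vs. ground-truth

def cos(a, b):
    # Cosine similarity between two single texts.
    return util.cos_sim(eval_model.encode(a), eval_model.encode(b)).item()

# Ground-truth pairs are pushed up (+0.5), generated pairs pushed down (-0.5),
# mirroring the offsets used with create_input_examples() above.
examples = [InputExample(texts=[a, b], label=cos(a, b) + 0.5) for a, b in ground_pairs]
examples += [InputExample(texts=[a, b], label=cos(a, b) - 0.5) for a, b in gen_pairs]

loader = torch.utils.data.DataLoader(examples, batch_size=8)
loss = losses.CosineSimilarityLoss(eval_model)  # assumed loss; cfg.score.train.loss in the PR
eval_model.fit(train_objectives=[(loader, loss)], epochs=1, warmup_steps=10)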
41 changes: 27 additions & 14 deletions lib/style-transfer/style_transfer/run_rb_gen.py
@@ -7,7 +7,7 @@
import wandb
from datasets import Dataset
from omegaconf import DictConfig, OmegaConf, omegaconf
from style_transfer.rb_gen.steps import dpo_train, generate, score, sft_train
from style_transfer.rb_gen.steps import dpo_train, generate, recalibrate_scoring, score, sft_train
from style_transfer.rb_gen.utils import build_dataset, split_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizerBase, set_seed
@@ -74,12 +74,22 @@ def main(cfg: DictConfig):
tokenizer,
gen_dataset,
)
score_dataset = score(
cfg,
step,
True if step == 0 else False,
sth_dataset,
checkpoint=eval_model_path,
score_dataset = (
score(
cfg,
step,
True if step == 0 else False,
sth_dataset,
checkpoint=eval_model_path,
)
if cfg.score.method == "classic"
else recalibrate_scoring(
cfg,
step,
True if step == 0 else False,
sth_dataset,
checkpoint=eval_model_path,
)
)
current_model_path = dpo_train(cfg, step, current_model_path, tokenizer, score_dataset)

@@ -92,13 +102,16 @@
tokenizer,
gen_dataset,
)
score(
cfg,
cfg.max_steps,
False,
sth_dataset,
checkpoint=eval_model_path,
)
if cfg.score.method == "classic":
score(
cfg,
cfg.max_steps,
False,
sth_dataset,
checkpoint=eval_model_path,
)
else:
recalibrate_scoring(cfg, cfg.max_steps, False, sth_dataset, checkpoint=eval_model_path)
shutil.rmtree(f"models/{wandb.run.id}/merged/")
wandb.finish()
