-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c31fae4
commit 7ef5140
Showing
21 changed files
with
27,663 additions
and
0 deletions.
There are no files selected for viewing
454 changes: 454 additions & 0 deletions
454
lib/style-transfer/hf_datasets/mimic_iii_wt_replaced_tokens/post-processed/dev.jsonl
Large diffs are not rendered by default.
Oops, something went wrong.
4,542 changes: 4,542 additions & 0 deletions
4,542
...nsfer/hf_datasets/mimic_iii_wt_replaced_tokens/post-processed/filtered_entries-utf8.jsonl
Large diffs are not rendered by default.
Oops, something went wrong.
4,542 changes: 4,542 additions & 0 deletions
4,542
...e-transfer/hf_datasets/mimic_iii_wt_replaced_tokens/post-processed/filtered_entries.jsonl
Large diffs are not rendered by default.
Oops, something went wrong.
455 changes: 455 additions & 0 deletions
455
lib/style-transfer/hf_datasets/mimic_iii_wt_replaced_tokens/post-processed/test.jsonl
Large diffs are not rendered by default.
Oops, something went wrong.
3,633 changes: 3,633 additions & 0 deletions
3,633
lib/style-transfer/hf_datasets/mimic_iii_wt_replaced_tokens/post-processed/train.jsonl
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import hydra | ||
import numpy as np | ||
import pandas as pd | ||
import peft | ||
import wandb | ||
from datasets import Dataset | ||
from omegaconf import DictConfig, ListConfig | ||
from style_transfer.rb_gen.utils.utils import CustomWandbCallback | ||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase | ||
from trl import DPOTrainer | ||
|
||
|
||
def add_preferences(data_point: dict) -> dict:
    """Add the preferred and rejected generations to a data point.

    The best and worst generations are selected from the
    ``evaluator_scores_<i>`` entries (added during the evaluation step) and
    looked up in the matching ``generation_<i>`` entries. The deviation score,
    i.e. the difference between the best and the worst score, is stored too.

    Args:
        data_point: The data point to add preferences to. It must contain at
            least one ``evaluator_scores_<i>`` entry together with the
            corresponding ``generation_<i>`` entry.

    Returns:
        The same data point with ``chosen``, ``rejected``, ``chosen_score``,
        ``rejected_score`` and ``deviation_score`` added.

    Raises:
        ValueError: If the data point has no ``evaluator_scores`` entry.
    """
    df_point = pd.DataFrame({k: [v] for k, v in dict(data_point).items()})
    score_columns = df_point.filter(regex="^evaluator_scores")
    if score_columns.empty:
        raise ValueError("data_point has no 'evaluator_scores' entries")

    column_max = score_columns.max()
    column_min = score_columns.min()
    # Take everything after the last underscore as the generation index, so
    # multi-digit indices such as "evaluator_scores_10" resolve to
    # "generation_10" instead of "generation_0".
    best_index = column_max.idxmax().rsplit("_", 1)[-1]
    worst_index = column_min.idxmin().rsplit("_", 1)[-1]
    best_score = column_max.max()
    worst_score = column_min.min()

    data_point["chosen"] = df_point[f"generation_{best_index}"].values[0]
    data_point["rejected"] = df_point[f"generation_{worst_index}"].values[0]
    data_point["chosen_score"] = best_score
    data_point["rejected_score"] = worst_score
    data_point["deviation_score"] = best_score - worst_score
    return data_point
|
||
|
||
@hydra.main(version_base="1.3", config_path="./", config_name="dpo.yaml")
def dpo_train(cfg) -> str:
    """Train a LoRA adapter with Direct Preference Optimization (DPO).

    The scored generations are read from ``./score.csv``, a chosen/rejected
    pair is derived for every row with ``add_preferences``, and only the rows
    whose chosen score lies strictly above the configured percentile are kept
    for training.

    Args:
        cfg: The Hydra configuration. This function reads ``cfg.percentile``,
            ``cfg.training_args`` and ``cfg.model.peft_config``.

    Returns:
        The directory the trained model was saved to (``args.output_dir``).
    """
    # Single source of truth for the checkpoint so that the model and its
    # tokenizer are always loaded from the same identifier.
    base_model_name = "meta-llama/Llama-3.2-3B-Instruct"

    dataset = Dataset.from_pandas(pd.read_csv("./score.csv"))
    dataset = dataset.map(
        add_preferences,
        batched=False,
    )

    # Keep only the rows whose best generation scores strictly above the
    # requested percentile of all best scores.
    percentile = np.percentile(dataset["chosen_score"], cfg.percentile)
    dataset = dataset.filter(lambda x: x["chosen_score"] > percentile)
    dataset = dataset.select_columns(["prompts", "chosen", "rejected"])
    dataset = dataset.rename_column("prompts", "prompt")

    args = hydra.utils.instantiate(cfg.training_args)
    peft_config = hydra.utils.instantiate(cfg.model.peft_config)
    # Turn an OmegaConf ListConfig into a plain list before handing it to PEFT.
    peft_config.target_modules = (
        list(peft_config.target_modules)
        if isinstance(peft_config.target_modules, ListConfig)
        else peft_config.target_modules
    )
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=base_model_name
    )
    model = peft.get_peft_model(
        model,
        peft_config,
    )

    # A second adapter built from the same config; it is handed to the trainer
    # below as the reference adapter.
    model.add_adapter(peft_config=peft_config, adapter_name="reference")
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=base_model_name
    )
    tokenizer.pad_token = tokenizer.eos_token
    model.enable_input_require_grads()
    dpo_trainer = DPOTrainer(
        args=args,
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        model_adapter_name="default",
        ref_adapter_name="reference",
    )
    dpo_trainer.train()

    dpo_path = args.output_dir
    dpo_trainer.save_model(dpo_path)
    del model
    return dpo_path


if __name__ == "__main__":
    dpo_train()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Hydra configuration, presumably the ``dpo.yaml`` loaded by the DPO training
# script.
# NOTE(review): the nesting below was reconstructed; only ``training_args``,
# ``model.peft_config`` and the top-level ``percentile`` are confirmed by the
# training script. Confirm the placement of the remaining keys.

# Instantiated through Hydra's ``_target_`` mechanism as a ``trl.DPOConfig``.
training_args:
  _target_: trl.DPOConfig
  per_device_train_batch_size: 2
  logging_steps: 2
  save_steps: 50
  gradient_accumulation_steps: 8
  gradient_checkpointing: false
  learning_rate: 5e-5
  weight_decay: 1e-7
  eval_strategy: "no"
  num_train_epochs: 5
  output_dir: "models/dpo/"
  optim: "adafactor"
  save_only_model: true
  remove_unused_columns: false
  save_safetensors: false
  bf16: true
  seed: 0
  max_length: 1024
  max_prompt_length: 512
  report_to: "none"

model:
  # Instantiated through Hydra's ``_target_`` mechanism as a ``peft.LoraConfig``.
  peft_config:
    _target_: peft.LoraConfig
    task_type: CAUSAL_LM
    r: 16
    lora_alpha: 16
    lora_dropout: 0
    bias: none
    target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
  # NOTE(review): presumably siblings of ``peft_config`` under ``model`` - confirm.
  quantization_config: null
  # Prompt template; ``{}`` is a placeholder for the keywords.
  prompt: >
    As a doctor, you must write an original History of Present Illness (HPI) section for a discharge
    summary. Your response should capture the essence of a patient's health journey and recent
    medical experiences, while strictly using all the provided keywords conserving the order. You
    must adopt a medical telegraphic style, abbreviated, characterized by concise and direct
    language. Keywords: {}
# NOTE(review): ``beta``, ``checkpoint`` and ``dataset`` are assumed to be
# top-level keys like ``percentile`` - confirm.
beta: 0.1
checkpoint: null
dataset: null
# Percentile passed to ``np.percentile`` over the chosen scores; rows at or
# below it are dropped before training.
percentile: 50
13,215 changes: 13,215 additions & 0 deletions
13,215
lib/style-transfer/style_transfer/rb_gen/test/score.csv
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import hydra | ||
import numpy as np | ||
import pandas as pd | ||
import torch | ||
from datasets import Dataset | ||
from sentence_transformers import InputExample, SentenceTransformer, losses, util | ||
|
||
|
||
def load_data(csv_path):
    """Read the CSV file at ``csv_path`` and wrap it in a ``Dataset``.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        The ``Dataset`` built from the parsed data frame.
    """
    frame = pd.read_csv(csv_path)
    return Dataset.from_pandas(frame)
|
||
|
||
def split_dataset(dataset, train_size=0.6):
    """Split ``dataset`` into a train part and a test part.

    Args:
        dataset: Object exposing ``train_test_split``.
        train_size: Value forwarded to ``train_test_split`` as ``train_size``.

    Returns:
        A ``(train, test)`` tuple taken from the ``"train"`` and ``"test"``
        entries of the split result.
    """
    parts = dataset.train_test_split(train_size=train_size)
    train_part = parts["train"]
    test_part = parts["test"]
    return train_part, test_part
|
||
|
||
def encode_and_score_texts(model, texts1, texts2, batch_size=8):
    """Score each pair of texts by the cosine similarity of their encodings.

    Args:
        model: Object exposing ``encode(texts, batch_size=...)``.
        texts1: First texts of the pairs.
        texts2: Second texts of the pairs.
        batch_size: Batch size forwarded to ``model.encode``.

    Returns:
        One float per pair. Pairs are formed with ``zip``, so the result is as
        long as the shorter of the two encoded inputs.
    """
    embeddings_a = model.encode(texts1, batch_size=batch_size)
    embeddings_b = model.encode(texts2, batch_size=batch_size)
    similarities = []
    for emb_a, emb_b in zip(embeddings_a, embeddings_b):
        # Only the [0][0] entry of the cos_sim result is read for a pair.
        similarities.append(util.cos_sim(emb_a, emb_b)[0][0].item())
    return similarities
|
||
|
||
def create_input_examples(texts1, texts2, scores, offset=1):
    """Build one ``InputExample`` per text pair.

    Args:
        texts1: First texts of the pairs.
        texts2: Second texts of the pairs.
        scores: One score per pair.
        offset: Value added to every score to obtain the example label.

    Returns:
        A list of ``InputExample`` objects whose label is ``score + offset``.
    """
    examples = []
    for first, second, score in zip(texts1, texts2, scores):
        examples.append(InputExample(texts=[first, second], label=score + offset))
    return examples
|
||
|
||
def process_ground_predictions(dataset, model):
    """Pair up the two halves of the ground texts and score every pair.

    The ``ground_texts`` column is cut in the middle; the i-th text of the
    first half is paired with the i-th text of the second half.

    Args:
        dataset: Mapping-like object with a ``"ground_texts"`` column.
        model: Encoder forwarded to ``encode_and_score_texts``.

    Returns:
        A dict with ``ground_scores``, ``text_1`` and ``text_2``. All three
        lists have the same length.
    """
    # Read the column once instead of once per use.
    ground_texts = dataset["ground_texts"]
    split_point = len(ground_texts) // 2
    split1 = ground_texts[:split_point]
    # With an odd number of texts the last one has no partner; leave it out so
    # that ``text_2`` lines up with ``text_1`` and ``ground_scores``.
    split2 = ground_texts[split_point : 2 * split_point]
    scores = encode_and_score_texts(model, split1, split2)
    return {"ground_scores": scores, "text_1": split1, "text_2": split2}
|
||
|
||
def process_generation_scores(dataset, model, num_generations=4, batch_size=8):
    """Score every generation column against the ground texts.

    Args:
        dataset: Mapping-like object with a ``"ground_texts"`` column and one
            ``"generation_<i>"`` column per generation.
        model: Object exposing ``encode(texts, batch_size=...)``.
        num_generations: Number of ``generation_<i>`` columns to score,
            starting at index 0.
        batch_size: Batch size forwarded to ``model.encode``.

    Returns:
        A dict mapping ``"evaluator_scores_<i>"`` to the list of cosine
        similarities between each ground text and the i-th generation.
    """
    # The ground texts are encoded once and reused for every generation.
    ground_enc = model.encode(dataset["ground_texts"], batch_size=batch_size)
    score_dict = {}
    for seq in range(num_generations):
        prediction_enc = model.encode(
            dataset[f"generation_{seq}"], batch_size=batch_size
        )
        score_dict[f"evaluator_scores_{seq}"] = [
            util.cos_sim(g_enc, p_enc)[0][0].item()
            for g_enc, p_enc in zip(ground_enc, prediction_enc)
        ]
    return score_dict
|
||
|
||
def train_model(model, train_examples, batch_size=8, epochs=2, warmup_steps=50):
    """Fit ``model`` on ``train_examples`` with a contrastive loss.

    Args:
        model: Model exposing ``fit``; it is also passed to the loss.
        train_examples: Examples wrapped in a ``DataLoader``.
        batch_size: Batch size of the data loader.
        epochs: Number of epochs forwarded to ``model.fit``.
        warmup_steps: Warm-up steps forwarded to ``model.fit``.
    """
    loader = torch.utils.data.DataLoader(train_examples, batch_size=batch_size)
    objective = (loader, losses.ContrastiveLoss(model))
    model.fit(train_objectives=[objective], epochs=epochs, warmup_steps=warmup_steps)
|
||
|
||
def main():
    """Fine-tune the sentence encoder and print score statistics around it.

    Training examples are built from ground-text pairs of both splits (label
    ``score + 1``) and from generation/ground pairs of the train split (label
    ``score - 1``).
    """

    def _flatten(score_dict):
        # Concatenate the score lists of the four generations, in index order.
        return [
            value
            for seq in range(4)
            for value in score_dict[f"evaluator_scores_{seq}"]
        ]

    def _report(title, scores):
        print(title)
        print(f"Global mean: {np.mean(scores):.4f}, Global std: {np.std(scores):.4f}")

    full_dataset = load_data("./test/score.csv")
    train_dataset, test_dataset = split_dataset(full_dataset)

    eval_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

    # Ground-text pairs from the train split first, then from the test split.
    train_examples = []
    for subset in (train_dataset, test_dataset):
        ground = process_ground_predictions(subset, eval_model)
        train_examples.extend(
            create_input_examples(
                ground["text_1"], ground["text_2"], ground["ground_scores"]
            )
        )

    # NOTE(review): the "before" statistics come from the train split while
    # the "after" statistics below come from the test split.
    train_score_dict = process_generation_scores(train_dataset, eval_model)
    _report("before learning:", _flatten(train_score_dict))

    for seq in range(4):
        train_examples.extend(
            create_input_examples(
                train_dataset[f"generation_{seq}"],
                train_dataset["ground_texts"],
                train_score_dict[f"evaluator_scores_{seq}"],
                offset=-1,
            )
        )

    train_model(eval_model, train_examples)

    test_score_dict = process_generation_scores(test_dataset, eval_model)
    _report("after learning:", _flatten(test_score_dict))


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import random | ||
|
||
import pandas as pd | ||
from datasets import load_dataset | ||
|
||
|
||
def main():
    """Sample 1500 training rows and store them as a parquet file.

    Each sampled row is reshaped: a non-empty ``input`` is appended to
    ``instruction``, ``input`` is dropped, and ``output`` is renamed to
    ``response``.
    """
    # Fixed seed so that the same rows are sampled on every run.
    random.seed(42)

    dataset = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k")
    sampled_data = random.sample(list(dataset["train"]), 1500)

    for sample in sampled_data:
        if sample.get("input"):
            sample["instruction"] = f"{sample['instruction']} {sample['input']}"
        # The ``input`` field is dropped whether or not it was merged above.
        sample.pop("input", None)

        if "output" in sample:
            sample["response"] = sample.pop("output")

    frame = pd.DataFrame(sampled_data)
    frame.to_parquet("datasets/train/1_raw/raw_data.parquet", index=False)


if __name__ == "__main__":
    main()
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import os | ||
from typing import List | ||
|
||
import pandas as pd | ||
from dotenv import load_dotenv | ||
from quickumls import QuickUMLS | ||
|
||
load_dotenv() | ||
|
||
|
||
class KeywordExtractor:
    """Extract matched terms from free text with a QuickUMLS matcher."""

    # Instruction boilerplate removed from the start of a text before matching.
    _INSTRUCTION_PREFIX = (
        "If you are a doctor, please answer the medical questions based on "
        "the patient's description."
    )

    def __init__(self):
        """Load the ``.env`` file and build the matcher from ``QUICKUMLS_PATH``."""
        load_dotenv()
        quickumls_path = os.getenv("QUICKUMLS_PATH")
        self.matcher = QuickUMLS(quickumls_fp=quickumls_path)

    def extract_keywords(self, text: str) -> List[str]:
        """Return the term of the first candidate of every match in ``text``.

        Args:
            text: The text to analyse; a missing value (``pd.isna``) is allowed.

        Returns:
            The list of matched terms, or an empty list for a missing value.
        """
        if pd.isna(text):
            return []
        cleaned = text.removeprefix(self._INSTRUCTION_PREFIX).strip()
        matches = self.matcher.match(cleaned, best_match=True, ignore_syntax=False)
        terms = []
        for candidates in matches:
            terms.append(candidates[0]["term"])
        return terms
|
||
|
||
def run(input_path: str, output_path: str) -> None:
    """Add keyword columns to a parquet file and write the result.

    Args:
        input_path: Parquet file with ``instruction`` and ``response`` columns.
        output_path: Destination of the parquet file with the added
            ``instruction_keywords`` and ``response_keywords`` columns.
    """
    extractor = KeywordExtractor()
    df = pd.read_parquet(input_path)

    # One keyword column per text column.
    column_pairs = (
        ("instruction", "instruction_keywords"),
        ("response", "response_keywords"),
    )
    for text_column, keyword_column in column_pairs:
        df[keyword_column] = df[text_column].apply(extractor.extract_keywords)

    # Show the first rows as a quick sanity check.
    print(df.head(10))

    df.to_parquet(output_path)


if __name__ == "__main__":
    run(
        input_path="datasets/train/1_raw/raw_data.parquet",
        output_path="datasets/train/2_keyword/keyword_data.parquet",
    )
Binary file not shown.
Binary file not shown.
Oops, something went wrong.