Implement squad as retaining dataset. #97

Merged: 6 commits, Jun 9, 2024
7 changes: 7 additions & 0 deletions llm_unlearn_ucl/parse_args.py
@@ -36,6 +36,13 @@ def parse_args() -> argparse.Namespace:
help="Set to True if using quantised models.",
)

parser.add_argument(
"--retaining_dataset",
type=str,
default="truthful_qa",
help="Name of the dataset to retain (NOTE: it needs to have"
" a custom dataloader creation script in utils.py)",
)
parser.add_argument(
"--unlearning_dataset",
type=str,
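
With this flag in place, a hypothetical invocation selecting SQuAD as the retaining dataset might look as follows; only --retaining_dataset comes from this diff, while the other flags and their values are illustrative assumptions:

python llm_unlearn_ucl/unlearn_harm.py \
    --retaining_dataset rajpurkar/squad \
    --batch_size 4 \
    --samples_count 128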
92 changes: 63 additions & 29 deletions llm_unlearn_ucl/unlearn_harm.py
@@ -28,6 +28,7 @@
# Added
import numpy as np
import torch
import wandb
from accelerate import Accelerator
from datasets import load_dataset
from parse_args import parse_args
@@ -39,15 +40,15 @@
compute_kl,
create_mathqa_dataloader_from_dataset,
create_pku_dataloader_from_dataset,
create_squad_dataloader_from_dataset,
create_symbolic_dataloader_from_dataset,
create_truthfulqa_dataloader,
get_answer_loss,
get_rand_ans_loss,
get_squad_answers,
get_truthfulQA_answers_plaintext,
)

import wandb


def set_seed(seed_num: int) -> None:
torch.manual_seed(seed_num)
@@ -374,35 +375,68 @@ def main(args) -> None:
return

# Get normal data.
(
train_normal_loaders,
val_normal_loader,
test_normal_loader,
train_normal_dataset,
) = create_truthfulqa_dataloader(
tokenizer,
batch_size=args.batch_size,
seed=args.shuffle_seed if args.shuffle_seed is not None else args.seed,
num_samples=args.samples_count if args.sequential > 0 else None,
splits=max(args.sequential, 1),
)
normal_sample_path = f"{args.samples_save_dir}/normal_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
with open(normal_sample_path, "w") as fin:
print(f"Writing normal samples to {normal_sample_path}")
json.dump(
[
train_normal_dataset[i]
for i in range(
args.samples_count
if args.sequential > 0
else len(train_normal_dataset)
)
],
fin,
if args.retaining_dataset == "truthful_qa":
(
train_normal_loaders,
val_normal_loader,
test_normal_loader,
train_normal_dataset,
) = create_truthfulqa_dataloader(
tokenizer,
batch_size=args.batch_size,
seed=args.shuffle_seed if args.shuffle_seed is not None else args.seed,
num_samples=args.samples_count if args.sequential > 0 else None,
splits=max(args.sequential, 1),
)
normal_sample_path = f"{args.samples_save_dir}/normal_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
with open(normal_sample_path, "w") as fin:
print(f"Writing normal samples to {normal_sample_path}")
json.dump(
[
train_normal_dataset[i]
for i in range(
args.samples_count
if args.sequential > 0
else len(train_normal_dataset)
)
],
fin,
)

# Load normal answer used for random mismatch.
normal_ans = get_truthfulQA_answers_plaintext()
elif args.retaining_dataset == "rajpurkar/squad":
train_split = "train"
if args.samples_count > 0:
train_split = f"{train_split}[:{args.samples_count}]"
train_normal_dataset = load_dataset("rajpurkar/squad", split=train_split)
train_normal_loaders = create_squad_dataloader_from_dataset(
tokenizer,
train_normal_dataset,
batch_size=args.batch_size,
splits=max(args.sequential, 1),
)
normal_sample_path = f"{args.samples_save_dir}/squad_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
with open(normal_sample_path, "w") as fin:
print(f"Writing normal samples to {normal_sample_path}")
json.dump(
[
train_normal_dataset[i]
for i in range(
args.samples_count
if args.sequential > 0
else len(train_normal_dataset)
)
],
fin,
)

# Load normal answer used for random mismatch.
normal_ans = get_squad_answers(train_normal_dataset)
else:
print(f"Retaining dataset not known! dataset: {args.retaining_dataset}")
return

# Load normal answer used for random mismatch.
normal_ans = get_truthfulQA_answers_plaintext()
data_sample_artifacts = wandb.Artifact(
name="training_batch_raw_data", type="batch_data"
)
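
The rajpurkar/squad branch above relies on the datasets library's split-slicing syntax: split="train[:N]" loads only the first N rows. A minimal sketch of just that mechanism, assuming an illustrative samples count of 32:

from datasets import load_dataset

samples_count = 32  # stands in for args.samples_count
train_split = "train"
if samples_count > 0:
    train_split = f"{train_split}[:{samples_count}]"  # keep only the first N rows
dataset = load_dataset("rajpurkar/squad", split=train_split)
print(len(dataset))  # 32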
79 changes: 79 additions & 0 deletions llm_unlearn_ucl/utils.py
@@ -143,6 +143,78 @@ def preprocess(examples):
return dataloader


def create_squad_dataloader_from_dataset(
tokenizer, dataset, fraction=1.0, batch_size=4, splits: int = 1
):
"""
Given the squad dataset, create the dataloader on the squad Q&A pairs.
Args:
tokenizer: Tokenizer.
dataset: Loaded squad dataset.
fraction: <1 will do downsampling.
batch_size: Batch size used for each step.
splits: The number of splits that the dataset will be sliced into.
Returns:
A List of DataLoader of squad Q&A pairs.
"""

# Preprocessing function.
def preprocess(examples):
"""
Input: Dict[List]
Output: Dict[List]
"""
results = {
"input_ids": [],
"attention_mask": [],
"start_locs": [],
}

for i in range(len(examples["context"])):
prompt = examples["context"][i] + " " + examples["question"][i]
answer = examples["answers"][i]["text"][0]

# Format the Q&A pair as a single training example.
text = f"### Question: {prompt}\n ### Answer: {answer}"
tokenized = tokenizer(text, truncation=True, padding="max_length")

results["input_ids"].append(tokenized["input_ids"])
results["attention_mask"].append(tokenized["attention_mask"])
# Calculate the start index of the answer tokens. Tokenize the prompt
# prefix without padding so its length reflects the real number of
# prompt tokens rather than the padded max_length.
test_text = f"### Question: {prompt}\n ### Answer: "
test_tokenized = tokenizer(test_text, truncation=True)
results["start_locs"].append(len(test_tokenized["input_ids"]) - 1)

return results

# Need to drop all original columns to emit more than one row for each original row https://huggingface.co/docs/datasets/about_map_batch#input-size-output-size.
dataset = dataset.map(
preprocess,
batched=True,
remove_columns=["question", "answers", "context", "id", "title"],
)
dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "start_locs"],
)

# Add labels and make it data loader.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# TODO: data_collator can emit more or fewer processed samples than expected.
# Split sizes must sum to len(dataset); spread the remainder across the
# first splits so random_split does not raise when len(dataset) is not
# evenly divisible by splits.
base, remainder = divmod(len(dataset), splits)
split_sizes = [base + (1 if i < remainder else 0) for i in range(splits)]
dataloaders = [
torch.utils.data.DataLoader(
train_split_dataset, batch_size=batch_size, collate_fn=data_collator
)
for train_split_dataset in torch.utils.data.random_split(dataset, split_sizes)
]

return dataloaders
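
A minimal usage sketch for the helper above; the gpt2 tokenizer and the sample counts are illustrative assumptions rather than choices made in this diff, and utils.py is assumed to be on the path, as in the repo's own scripts:

from datasets import load_dataset
from transformers import AutoTokenizer
from utils import create_squad_dataloader_from_dataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
squad = load_dataset("rajpurkar/squad", split="train[:64]")
loaders = create_squad_dataloader_from_dataset(tokenizer, squad, batch_size=4, splits=2)
batch = next(iter(loaders[0]))
print(batch["input_ids"].shape)  # (4, tokenizer.model_max_length)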


def create_pku_dataloader_from_dataset(
tokenizer, dataset, fraction=1.0, batch_size=4, splits: int = 1
):
@@ -236,6 +308,13 @@ def preproccess(examples):
return dataloaders


def get_squad_answers(dataset):
# Collect every answer string from each SQuAD row into one flat list.
answers = []
for row in dataset:
answers.extend(row["answers"]["text"])
return answers
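
For context, each SQuAD row stores its answers as a dict of parallel lists, so the helper simply flattens every answer string into one list. An illustrative call on a single hand-written row:

row = {"answers": {"text": ["Saint Bernadette Soubirous"], "answer_start": [515]}}
print(get_squad_answers([row]))  # ['Saint Bernadette Soubirous']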


def create_truthfulqa_dataloader(
tokenizer,
batch_size=4,