Implement squad as retaining dataset. (#97)
* Implement nq_open as the retaining dataset.

* Implement squad as the retaining dataset.

* Fix merge error.

* Fix dataset slicing and minor cleanup.

* Keep only one answer per question.
Davidyz authored Jun 9, 2024
1 parent cab4141 commit a64e345
Showing 3 changed files with 149 additions and 29 deletions.
7 changes: 7 additions & 0 deletions llm_unlearn_ucl/parse_args.py
@@ -36,6 +36,13 @@ def parse_args() -> argparse.Namespace:
        help="Set to True if using quantised models.",
    )

    parser.add_argument(
        "--retaining_dataset",
        type=str,
        default="truthful_qa",
        help="Name of the dataset to retain (NOTE: it needs a custom"
        " dataloader-creation function in utils.py)",
    )
    parser.add_argument(
        "--unlearning_dataset",
        type=str,
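For reference, a minimal smoke test of the new flag (a sketch only: it assumes parse_args.py is importable from the llm_unlearn_ucl directory and that every other argument has a default):

# Hypothetical smoke test for --retaining_dataset.
import sys

from parse_args import parse_args

sys.argv = ["unlearn_harm.py", "--retaining_dataset", "rajpurkar/squad"]
args = parse_args()
assert args.retaining_dataset == "rajpurkar/squad"  # omitted -> default "truthful_qa"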
92 changes: 63 additions & 29 deletions llm_unlearn_ucl/unlearn_harm.py
@@ -28,6 +28,7 @@
# Added
import numpy as np
import torch
import wandb
from accelerate import Accelerator
from datasets import load_dataset
from parse_args import parse_args
@@ -39,15 +40,15 @@
    compute_kl,
    create_mathqa_dataloader_from_dataset,
    create_pku_dataloader_from_dataset,
    create_squad_dataloader_from_dataset,
    create_symbolic_dataloader_from_dataset,
    create_truthfulqa_dataloader,
    get_answer_loss,
    get_rand_ans_loss,
    get_squad_answers,
    get_truthfulQA_answers_plaintext,
)

import wandb


def set_seed(seed_num: int) -> None:
    torch.manual_seed(seed_num)
@@ -374,6 +375,68 @@ def main(args) -> None:
        return

    # Get normal data.
    (
        train_normal_loaders,
        val_normal_loader,
        test_normal_loader,
        train_normal_dataset,
    ) = create_truthfulqa_dataloader(
        tokenizer,
        batch_size=args.batch_size,
        seed=args.shuffle_seed if args.shuffle_seed is not None else args.seed,
        num_samples=args.samples_count if args.sequential > 0 else None,
        splits=max(args.sequential, 1),
    )
    normal_sample_path = f"{args.samples_save_dir}/normal_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
    with open(normal_sample_path, "w") as fin:
        print(f"Writing normal samples to {normal_sample_path}")
        json.dump(
            [
                train_normal_dataset[i]
                for i in range(
                    args.samples_count
                    if args.sequential > 0
                    else len(train_normal_dataset)
                )
            ],
            fin,
        )
    if args.retaining_dataset == "truthful_qa":
        (
            train_normal_loaders,
            val_normal_loader,
            test_normal_loader,
            train_normal_dataset,
        ) = create_truthfulqa_dataloader(
            tokenizer,
            batch_size=args.batch_size,
            seed=args.shuffle_seed if args.shuffle_seed is not None else args.seed,
            num_samples=args.samples_count if args.sequential > 0 else None,
            splits=max(args.sequential, 1),
        )
        normal_sample_path = f"{args.samples_save_dir}/normal_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
        with open(normal_sample_path, "w") as fin:
            print(f"Writing normal samples to {normal_sample_path}")
            json.dump(
                [
                    train_normal_dataset[i]
                    for i in range(
                        args.samples_count
                        if args.sequential > 0
                        else len(train_normal_dataset)
                    )
                ],
                fin,
            )

        # Load normal answer used for random mismatch.
        normal_ans = get_truthfulQA_answers_plaintext()
    elif args.retaining_dataset == "rajpurkar/squad":
        train_split = "train"
        if args.samples_count > 0:
            train_split = f"{train_split}[:{args.samples_count}]"
        train_normal_dataset = load_dataset("rajpurkar/squad", split=train_split)
        train_normal_loaders = create_squad_dataloader_from_dataset(
            tokenizer,
            train_normal_dataset,
            batch_size=args.batch_size,
            splits=max(args.sequential, 1),
        )
        normal_sample_path = f"{args.samples_save_dir}/squad_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
        with open(normal_sample_path, "w") as fin:
            print(f"Writing normal samples to {normal_sample_path}")
            json.dump(
                [
                    train_normal_dataset[i]
                    for i in range(
                        args.samples_count
                        if args.sequential > 0
                        else len(train_normal_dataset)
                    )
                ],
                fin,
            )

        # Load normal answer used for random mismatch.
        normal_ans = get_squad_answers(train_normal_dataset)
    else:
        print(f"Unknown retaining dataset: {args.retaining_dataset}")
        return

    # Load normal answer used for random mismatch.
    normal_ans = get_truthfulQA_answers_plaintext()
    data_sample_artifacts = wandb.Artifact(
        name="training_batch_raw_data", type="batch_data"
    )
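The f"{train_split}[:{args.samples_count}]" string built in the squad branch above uses the datasets library's split-slicing syntax; a standalone sketch of that behaviour (the sample size of 8 is arbitrary):

from datasets import load_dataset

# Slicing the split caps the retain set, mirroring the samples_count logic above.
small = load_dataset("rajpurkar/squad", split="train[:8]")
print(len(small))                      # 8
print(small[0]["answers"]["text"][0])  # the single answer kept per question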
79 changes: 79 additions & 0 deletions llm_unlearn_ucl/utils.py
@@ -143,6 +143,78 @@ def preprocess(examples):
    return dataloader


def create_squad_dataloader_from_dataset(
    tokenizer, dataset, fraction=1.0, batch_size=4, splits: int = 1
):
    """
    Given the SQuAD dataset, create dataloaders over its Q&A pairs.

    Args:
        tokenizer: Tokenizer.
        dataset: Loaded SQuAD dataset.
        fraction: <1 will do downsampling (currently unused).
        batch_size: Batch size used for each step.
        splits: The number of splits that the dataset will be sliced into.

    Returns:
        A list of DataLoaders of SQuAD Q&A pairs.
    """

    # Preprocess function.
    def preprocess(examples):
        """
        Input: Dict[List]
        Output: Dict[List]
        """
        results = {
            "input_ids": [],
            "attention_mask": [],
            "start_locs": [],
        }

        for i in range(len(examples["context"])):
            prompt = examples["context"][i] + " " + examples["question"][i]
            # Keep only the first answer for each question.
            answer = examples["answers"][i]["text"][0]

            text = f"### Question: {prompt}\n ### Answer: {answer}"
            tokenized = tokenizer(text, truncation=True, padding="max_length")

            results["input_ids"].append(tokenized["input_ids"])
            results["attention_mask"].append(tokenized["attention_mask"])
            # Calculate the start index of the answer. Tokenize the prefix
            # WITHOUT padding so its length reflects the prompt tokens alone
            # (a padded encoding would always have length max_length).
            test_text = f"### Question: {prompt}\n ### Answer: "
            test_tokenized = tokenizer(test_text, truncation=True)
            results["start_locs"].append(len(test_tokenized["input_ids"]) - 1)

        return results

    # Need to drop all original columns to emit more than one row for each
    # original row: https://huggingface.co/docs/datasets/about_map_batch#input-size-output-size.
    dataset = dataset.map(
        preprocess,
        batched=True,
        remove_columns=["question", "answers", "context", "id", "title"],
    )
    dataset.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "start_locs"],
    )

    # Add labels and make it a data loader.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # TODO: data_collator introduces extra/less processed samples.
    # random_split requires the lengths to sum to len(dataset), so fold any
    # remainder into the last split instead of crashing when the dataset size
    # is not divisible by `splits`.
    split_lens = [len(dataset) // splits] * splits
    split_lens[-1] += len(dataset) % splits
    dataloaders = [
        torch.utils.data.DataLoader(
            train_split_dataset, batch_size=batch_size, collate_fn=data_collator
        )
        for train_split_dataset in torch.utils.data.random_split(dataset, split_lens)
    ]

    return dataloaders
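A hypothetical usage sketch for the helper above (the facebook/opt-1.3b tokenizer, the 512-token cap, and the train[:64] slice are placeholder choices, not taken from this commit; assumes it runs inside llm_unlearn_ucl so utils is importable):

from datasets import load_dataset
from transformers import AutoTokenizer

from utils import create_squad_dataloader_from_dataset

# Placeholder tokenizer; cap max_length so padding="max_length" stays small.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
tokenizer.model_max_length = 512

dataset = load_dataset("rajpurkar/squad", split="train[:64]")
# splits=2 -> two DataLoaders over 32 examples each, batched by 4.
loaders = create_squad_dataloader_from_dataset(tokenizer, dataset, batch_size=4, splits=2)
batch = next(iter(loaders[0]))
print(batch["input_ids"].shape)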


def create_pku_dataloader_from_dataset(
    tokenizer, dataset, fraction=1.0, batch_size=4, splits: int = 1
):
@@ -236,6 +308,13 @@ def preproccess(examples):
    return dataloaders


def get_squad_answers(dataset):
    """Collect all answer strings from a SQuAD-style dataset."""
    answers = []
    for i in dataset:
        answers.extend(i["answers"]["text"])
    return answers
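A quick illustration of get_squad_answers on a toy, SQuAD-shaped record (hypothetical data, not from the dataset):

toy = [{"answers": {"text": ["Denver Broncos"], "answer_start": [177]}}]
print(get_squad_answers(toy))  # ['Denver Broncos']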


def create_truthfulqa_dataloader(
    tokenizer,
    batch_size=4,
