From b25fc745ee5f7d9fd74c0f480d9e714d00106e61 Mon Sep 17 00:00:00 2001
From: Zhe Yu
Date: Thu, 6 Jun 2024 19:39:19 +0100
Subject: [PATCH 1/5] Implement nq_open as the retaining dataset.

---
 llm_unlearn_ucl/parse_args.py   |  7 +++
 llm_unlearn_ucl/unlearn_harm.py | 86 ++++++++++++++++++++++-----------
 llm_unlearn_ucl/utils.py        | 82 +++++++++++++++++++++++++++++++
 3 files changed, 148 insertions(+), 27 deletions(-)

diff --git a/llm_unlearn_ucl/parse_args.py b/llm_unlearn_ucl/parse_args.py
index f9de604..54ae58e 100644
--- a/llm_unlearn_ucl/parse_args.py
+++ b/llm_unlearn_ucl/parse_args.py
@@ -36,6 +36,13 @@ def parse_args() -> argparse.Namespace:
         help="Set to True if using quantised models.",
     )
 
+    parser.add_argument(
+        "--retaining_dataset",
+        type=str,
+        default="truthful_qa",
+        help="Name of the dataset to retain (NOTE: it needs to have"
+        " a custom dataloader creation script in utils.py)",
+    )
     parser.add_argument(
         "--unlearning_dataset",
         type=str,
diff --git a/llm_unlearn_ucl/unlearn_harm.py b/llm_unlearn_ucl/unlearn_harm.py
index 0afd25a..1344b6e 100644
--- a/llm_unlearn_ucl/unlearn_harm.py
+++ b/llm_unlearn_ucl/unlearn_harm.py
@@ -39,9 +39,11 @@
 from utils import (
     compute_kl,
     create_mathqa_dataloader_from_dataset,
+    create_nq_open_dataloader_from_dataset,
     create_pku_dataloader_from_dataset,
     create_truthfulqa_dataloader,
     get_answer_loss,
+    get_nq_open_answers,
     get_rand_ans_loss,
     get_truthfulQA_answers_plaintext,
 )
@@ -328,35 +330,65 @@ def main(args) -> None:
         return
 
     # Get normal data.
-    (
-        train_normal_loaders,
-        val_normal_loader,
-        test_normal_loader,
-        train_normal_dataset,
-    ) = create_truthfulqa_dataloader(
-        tokenizer,
-        batch_size=args.batch_size,
-        seed=args.shuffle_seed if args.shuffle_seed is not None else args.seed,
-        num_samples=args.samples_count if args.sequential > 0 else None,
-        splits=max(args.sequential, 1),
-    )
-    normal_sample_path = f"{args.samples_save_dir}/normal_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
-    with open(normal_sample_path, "w") as fin:
-        print(f"Writing normal samples to {normal_sample_path}")
-        json.dump(
-            [
-                train_normal_dataset[i]
-                for i in range(
-                    args.samples_count
-                    if args.sequential > 0
-                    else len(train_normal_dataset)
-                )
-            ],
-            fin,
+    if args.retaining_dataset == "truthful_qa":
+        (
+            train_normal_loaders,
+            val_normal_loader,
+            test_normal_loader,
+            train_normal_dataset,
+        ) = create_truthfulqa_dataloader(
+            tokenizer,
+            batch_size=args.batch_size,
+            seed=args.shuffle_seed if args.shuffle_seed is not None else args.seed,
+            num_samples=args.samples_count if args.sequential > 0 else None,
+            splits=max(args.sequential, 1),
+        )
+        normal_sample_path = f"{args.samples_save_dir}/normal_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
+        with open(normal_sample_path, "w") as fin:
+            print(f"Writing normal samples to {normal_sample_path}")
+            json.dump(
+                [
+                    train_normal_dataset[i]
+                    for i in range(
+                        args.samples_count
+                        if args.sequential > 0
+                        else len(train_normal_dataset)
+                    )
+                ],
+                fin,
+            )
+
+        # Load normal answer used for random mismatch.
+        normal_ans = get_truthfulQA_answers_plaintext()
+    if args.retaining_dataset == "google-research-datasets/nq_open":
+        train_normal_dataset = load_dataset(
+            "google-research-datasets/nq_open", split="train"
+        )
+        normal_dataset_copy = train_normal_dataset.copy()
+        train_normal_loaders = create_nq_open_dataloader_from_dataset(
+            tokenizer,
+            train_normal_dataset,
+            batch_size=args.batch_size,
+            splits=max(args.sequential, 1),
         )
+        normal_sample_path = f"{args.samples_save_dir}/nq_open_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
+        with open(normal_sample_path, "w") as fin:
+            print(f"Writing normal samples to {normal_sample_path}")
+            json.dump(
+                [
+                    train_normal_dataset[i]
+                    for i in range(
+                        args.samples_count
+                        if args.sequential > 0
+                        else len(train_normal_dataset)
+                    )
+                ],
+                fin,
+            )
+
+        # Load normal answer used for random mismatch.
+        normal_ans = get_nq_open_answers(normal_dataset_copy)
 
-    # Load normal answer used for random mismatch.
-    normal_ans = get_truthfulQA_answers_plaintext()
     data_sample_artifacts = wandb.Artifact(
         name="training_batch_raw_data", type="batch_data"
     )
diff --git a/llm_unlearn_ucl/utils.py b/llm_unlearn_ucl/utils.py
index c5f0fc7..8e4ccca 100644
--- a/llm_unlearn_ucl/utils.py
+++ b/llm_unlearn_ucl/utils.py
@@ -88,6 +88,81 @@ def preprocess(examples):
     return dataloader
 
 
+def create_nq_open_dataloader_from_dataset(
+    tokenizer, dataset, fraction=1.0, batch_size=4, splits: int = 1
+):
+    """
+    Given the PKU dataset, create the dataloader on the unlearned harmful Q&A pairs.
+
+    Args:
+        tokenizer: Tokenizer.
+        dataset: Loaded PKU dataset.
+        fraction: <1 will do downsampling.
+        batch_size: Batch size used for each step.
+        splits: The number of splits that the dataset will be sliced into.
+
+    Returns:
+        A List of DataLoader of PKU harmful Q&A pairs.
+    """
+
+    # Preproccess function.
+    def preproccess(examples):
+        """
+        Input: Dict[List]
+        Output: Dict[List]
+        """
+        results = {
+            "input_ids": [],
+            "attention_mask": [],
+            "start_locs": [],
+        }
+
+        for i in range(len(examples["answer"])):
+            prompt = examples["question"]
+            answers = examples["answer"]
+
+            # Add all responses to results or skip if none.
+            for answer in answers:
+                text = f"### Question: {prompt}\n ### Answer: {answer}"
+                tokenized = tokenizer(text, truncation=True, padding="max_length")
+
+                results["input_ids"].append(tokenized["input_ids"])
+                results["attention_mask"].append(tokenized["attention_mask"])
+                # Calculate start idx for answer
+                test_text = f"### Question: {prompt}\n ### Answer: "
+                test_tokenized = tokenizer(
+                    test_text, truncation=True, padding="max_length"
+                )
+                results["start_locs"].append(len(test_tokenized["input_ids"]) - 1)
+
+        return results
+
+    # Need to drop all original columns to emit more than one row for each original row https://huggingface.co/docs/datasets/about_map_batch#input-size-output-size.
+    dataset = dataset.map(
+        preproccess,
+        batched=True,
+        remove_columns=["question", "answer"],
+    )
+    dataset.set_format(
+        type="torch",
+        columns=["input_ids", "attention_mask", "start_locs"],
+    )
+
+    # Add labels and make it data loader.
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+    # TODO: data_collator introduces extra/less processed samples.
+    dataloaders = [
+        torch.utils.data.DataLoader(
+            train_split_dataset, batch_size=batch_size, collate_fn=data_collator
+        )
+        for train_split_dataset in torch.utils.data.random_split(
+            dataset, tuple(len(dataset) // splits for i in range(splits))
+        )
+    ]
+
+    return dataloaders
+
+
 def create_pku_dataloader_from_dataset(
     tokenizer, dataset, fraction=1.0, batch_size=4, splits: int = 1
 ):
@@ -181,6 +256,13 @@ def preproccess(examples):
     return dataloaders
 
 
+def get_nq_open_answers(dataset):
+    answers = []
+    for i in dataset:
+        answers.extend(i["answer"])
+    return answers
+
+
 def create_truthfulqa_dataloader(
     tokenizer,
     batch_size=4,

From e4a4f9eb22c5a39e82fa4f3332b0686f719a6d1d Mon Sep 17 00:00:00 2001
From: Zhe Yu
Date: Fri, 7 Jun 2024 13:05:55 +0100
Subject: [PATCH 2/5] Implement squad as the retaining dataset.

---
 llm_unlearn_ucl/unlearn_harm.py | 21 +++++++++++----------
 llm_unlearn_ucl/utils.py        | 20 ++++++++++----------
 2 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/llm_unlearn_ucl/unlearn_harm.py b/llm_unlearn_ucl/unlearn_harm.py
index 1344b6e..295300b 100644
--- a/llm_unlearn_ucl/unlearn_harm.py
+++ b/llm_unlearn_ucl/unlearn_harm.py
@@ -39,12 +39,12 @@
 from utils import (
     compute_kl,
     create_mathqa_dataloader_from_dataset,
-    create_nq_open_dataloader_from_dataset,
     create_pku_dataloader_from_dataset,
+    create_squad_dataloader_from_dataset,
     create_truthfulqa_dataloader,
     get_answer_loss,
-    get_nq_open_answers,
     get_rand_ans_loss,
+    get_squad_answers,
     get_truthfulQA_answers_plaintext,
 )
 
@@ -360,18 +360,16 @@ def main(args) -> None:
 
         # Load normal answer used for random mismatch.
         normal_ans = get_truthfulQA_answers_plaintext()
-    if args.retaining_dataset == "google-research-datasets/nq_open":
-        train_normal_dataset = load_dataset(
-            "google-research-datasets/nq_open", split="train"
-        )
-        normal_dataset_copy = train_normal_dataset.copy()
-        train_normal_loaders = create_nq_open_dataloader_from_dataset(
+    elif args.retaining_dataset == "rajpurkar/squad":
+        train_normal_dataset = load_dataset("rajpurkar/squad", split="train")
+        normal_dataset_copy = train_normal_dataset
+        train_normal_loaders = create_squad_dataloader_from_dataset(
             tokenizer,
             train_normal_dataset,
             batch_size=args.batch_size,
             splits=max(args.sequential, 1),
         )
-        normal_sample_path = f"{args.samples_save_dir}/nq_open_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
+        normal_sample_path = f"{args.samples_save_dir}/squad_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
         with open(normal_sample_path, "w") as fin:
             print(f"Writing normal samples to {normal_sample_path}")
             json.dump(
                 [
                     train_normal_dataset[i]
                     for i in range(
                         args.samples_count
                         if args.sequential > 0
                         else len(train_normal_dataset)
                     )
                 ],
                 fin,
             )
 
         # Load normal answer used for random mismatch.
-        normal_ans = get_nq_open_answers(normal_dataset_copy)
+        normal_ans = get_squad_answers(normal_dataset_copy)
+    else:
+        print(f"Retaining dataset not known! dataset: {args.retaining_dataset}")
+        return
 
     data_sample_artifacts = wandb.Artifact(
         name="training_batch_raw_data", type="batch_data"
     )
diff --git a/llm_unlearn_ucl/utils.py b/llm_unlearn_ucl/utils.py
index 8e4ccca..7bfbd4f 100644
--- a/llm_unlearn_ucl/utils.py
+++ b/llm_unlearn_ucl/utils.py
@@ -88,20 +88,20 @@ def preprocess(examples):
     return dataloader
 
 
-def create_nq_open_dataloader_from_dataset(
+def create_squad_dataloader_from_dataset(
     tokenizer, dataset, fraction=1.0, batch_size=4, splits: int = 1
 ):
     """
-    Given the PKU dataset, create the dataloader on the unlearned harmful Q&A pairs.
+    Given the squad dataset, create the dataloader on the squad Q&A pairs.
 
     Args:
         tokenizer: Tokenizer.
-        dataset: Loaded PKU dataset.
+        dataset: Loaded squad dataset.
         fraction: <1 will do downsampling.
         batch_size: Batch size used for each step.
         splits: The number of splits that the dataset will be sliced into.
 
     Returns:
-        A List of DataLoader of PKU harmful Q&A pairs.
+        A List of DataLoader of squad Q&A pairs.
     """
 
     # Preproccess function.
@@ -116,9 +116,9 @@ def preproccess(examples):
             "start_locs": [],
         }
 
-        for i in range(len(examples["answer"])):
-            prompt = examples["question"]
-            answers = examples["answer"]
+        for i in range(len(examples["context"])):
+            prompt = examples["context"][i] + " " + examples["question"][i]
+            answers = examples["answers"][i]["text"][0]
 
             # Add all responses to results or skip if none.
             for answer in answers:
@@ -140,7 +140,7 @@ def preproccess(examples):
     dataset = dataset.map(
         preproccess,
         batched=True,
-        remove_columns=["question", "answer"],
+        remove_columns=["question", "answers", "context", "id", "title"],
     )
     dataset.set_format(
         type="torch",
@@ -256,10 +256,10 @@ def preproccess(examples):
     return dataloaders
 
 
-def get_nq_open_answers(dataset):
+def get_squad_answers(dataset):
     answers = []
     for i in dataset:
-        answers.extend(i["answer"])
+        answers.extend(i["answers"]["text"])
     return answers
 
 
 def create_truthfulqa_dataloader(
     tokenizer,
     batch_size=4,

From 8550b3b403d836cb9654f5d8498b1fca4c54547b Mon Sep 17 00:00:00 2001
From: Zhe Yu
Date: Sun, 9 Jun 2024 12:45:05 +0100
Subject: [PATCH 3/5] Fix merge error.

---
 llm_unlearn_ucl/unlearn_harm.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llm_unlearn_ucl/unlearn_harm.py b/llm_unlearn_ucl/unlearn_harm.py
index 472d430..1dcc044 100644
--- a/llm_unlearn_ucl/unlearn_harm.py
+++ b/llm_unlearn_ucl/unlearn_harm.py
@@ -28,6 +28,7 @@
 # Added
 import numpy as np
 import torch
+import wandb
 from accelerate import Accelerator
 from datasets import load_dataset
 from parse_args import parse_args
@@ -39,7 +40,8 @@
     compute_kl,
     create_mathqa_dataloader_from_dataset,
     create_pku_dataloader_from_dataset,
-    create_squad_dataloader_from_dataset,`
+    create_squad_dataloader_from_dataset,
+    create_symbolic_dataloader_from_dataset,
     create_truthfulqa_dataloader,
     get_answer_loss,
     get_rand_ans_loss,
     get_squad_answers,
     get_truthfulQA_answers_plaintext,
 )
 
-import wandb
-
 
 def set_seed(seed_num: int) -> None:
     torch.manual_seed(seed_num)

From c9a44b2b3f884687bf1ea3cdef3653f4f3d688cd Mon Sep 17 00:00:00 2001
From: Zhe Yu
Date: Sun, 9 Jun 2024 14:41:47 +0100
Subject: [PATCH 4/5] Fix dataset slicing and minor cleanup.

---
 llm_unlearn_ucl/unlearn_harm.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/llm_unlearn_ucl/unlearn_harm.py b/llm_unlearn_ucl/unlearn_harm.py
index 1dcc044..d94707b 100644
--- a/llm_unlearn_ucl/unlearn_harm.py
+++ b/llm_unlearn_ucl/unlearn_harm.py
@@ -406,8 +406,10 @@ def main(args) -> None:
         # Load normal answer used for random mismatch.
         normal_ans = get_truthfulQA_answers_plaintext()
     elif args.retaining_dataset == "rajpurkar/squad":
-        train_normal_dataset = load_dataset("rajpurkar/squad", split="train")
-        normal_dataset_copy = train_normal_dataset
+        train_split = "train"
+        if args.samples_count > 0:
+            train_split = f"{train_split}[:{args.samples_count}]"
+        train_normal_dataset = load_dataset("rajpurkar/squad", split=train_split)
         train_normal_loaders = create_squad_dataloader_from_dataset(
             tokenizer,
             train_normal_dataset,
@@ -430,7 +432,7 @@ def main(args) -> None:
         )
 
         # Load normal answer used for random mismatch.
-        normal_ans = get_squad_answers(normal_dataset_copy)
+        normal_ans = get_squad_answers(train_normal_dataset)
     else:
         print(f"Retaining dataset not known! dataset: {args.retaining_dataset}")
         return

From a6a8c7785d622c87600628246cf3ea6c9869f91a Mon Sep 17 00:00:00 2001
From: Zhe Yu
Date: Sun, 9 Jun 2024 15:00:57 +0100
Subject: [PATCH 5/5] Keep only one answer per question.

---
 llm_unlearn_ucl/utils.py | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/llm_unlearn_ucl/utils.py b/llm_unlearn_ucl/utils.py
index 9147f56..31b3d35 100644
--- a/llm_unlearn_ucl/utils.py
+++ b/llm_unlearn_ucl/utils.py
@@ -173,21 +173,18 @@ def preproccess(examples):
 
         for i in range(len(examples["context"])):
             prompt = examples["context"][i] + " " + examples["question"][i]
-            answers = examples["answers"][i]["text"][0]
+            answer = examples["answers"][i]["text"][0]
 
             # Add all responses to results or skip if none.
-            for answer in answers:
-                text = f"### Question: {prompt}\n ### Answer: {answer}"
-                tokenized = tokenizer(text, truncation=True, padding="max_length")
+            text = f"### Question: {prompt}\n ### Answer: {answer}"
+            tokenized = tokenizer(text, truncation=True, padding="max_length")
 
-                results["input_ids"].append(tokenized["input_ids"])
-                results["attention_mask"].append(tokenized["attention_mask"])
-                # Calculate start idx for answer
-                test_text = f"### Question: {prompt}\n ### Answer: "
-                test_tokenized = tokenizer(
-                    test_text, truncation=True, padding="max_length"
-                )
-                results["start_locs"].append(len(test_tokenized["input_ids"]) - 1)
+            results["input_ids"].append(tokenized["input_ids"])
+            results["attention_mask"].append(tokenized["attention_mask"])
+            # Calculate start idx for answer
+            test_text = f"### Question: {prompt}\n ### Answer: "
+            test_tokenized = tokenizer(test_text, truncation=True, padding="max_length")
+            results["start_locs"].append(len(test_tokenized["input_ids"]) - 1)
 
         return results
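
A minimal sketch of what the final SQuAD preprocessing (PATCH 5/5) does to a
single record, for reviewers who want to check the prompt template and the
start_locs bookkeeping outside the training script. The "gpt2" tokenizer, the
max_length of 64, and the toy record are illustrative assumptions, not part of
the patches; note also that the helper in utils.py tokenizes the probe prefix
with padding="max_length", whereas this sketch uses the unpadded prefix length.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default.

    # Toy row shaped like one rajpurkar/squad record (hypothetical values).
    example = {
        "context": "Normandy is a region in France.",
        "question": "Where is Normandy?",
        "answers": {"text": ["France"], "answer_start": [27]},
    }

    prompt = example["context"] + " " + example["question"]
    answer = example["answers"]["text"][0]  # keep only the first answer, as in PATCH 5/5

    text = f"### Question: {prompt}\n ### Answer: {answer}"
    tokenized = tokenizer(text, truncation=True, padding="max_length", max_length=64)

    # start_locs records where the answer span begins, so the unlearning loss
    # can be restricted to answer tokens rather than the whole prompt.
    prefix = f"### Question: {prompt}\n ### Answer: "
    start_loc = len(tokenizer(prefix)["input_ids"]) - 1
    print(start_loc, tokenized["input_ids"][:10])

With the series applied, the retaining set is selected on the command line,
e.g. python llm_unlearn_ucl/unlearn_harm.py --retaining_dataset rajpurkar/squad
(truthful_qa remains the default).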