From afbb8c587b332311d7d9c6853f686aea2e72d2d4 Mon Sep 17 00:00:00 2001
From: David <30951234+Davidyz@users.noreply.github.com>
Date: Sun, 9 Jun 2024 18:33:03 +0100
Subject: [PATCH] Implement dataloader constructor for piaf dataset. (#93)

* Implement dataloader constructor for piaf dataset.

* Fix for batched mapping.

* Fix for batched mapping.

* Black formatting.

* Prepare for rebasing.

---
 llm_unlearn_ucl/unlearn_harm.py | 46 +++++++++++++++++---
 llm_unlearn_ucl/utils.py        | 74 +++++++++++++++++++++++++++++++++
 2 files changed, 115 insertions(+), 5 deletions(-)

diff --git a/llm_unlearn_ucl/unlearn_harm.py b/llm_unlearn_ucl/unlearn_harm.py
index d94707b..7da03ad 100644
--- a/llm_unlearn_ucl/unlearn_harm.py
+++ b/llm_unlearn_ucl/unlearn_harm.py
@@ -39,6 +39,7 @@
 from utils import (
     compute_kl,
     create_mathqa_dataloader_from_dataset,
+    create_piaf_dataloader_from_dataset,
     create_pku_dataloader_from_dataset,
     create_squad_dataloader_from_dataset,
     create_symbolic_dataloader_from_dataset,
@@ -315,6 +316,46 @@ def main(args) -> None:
         # ADDITONALLY: create_truthfulqa_dataloader() is also using this pattern!!!
         question_prefix_str = "### Question:"
         answer_prefix_str = "### Answer:"
+    elif args.unlearning_dataset == "AgentPublic/piaf":
+        # keep only entries that have at least one answer; draw samples from the remaining dataset.
+        full_bad_dataset = load_dataset("AgentPublic/piaf", split="train").filter(
+            lambda entry: len(entry["answers"]["text"]) != 0
+        )
+        if args.shuffle_seed:
+            # shuffle the dataset with a given seed for reproducibility
+            full_bad_dataset = full_bad_dataset.shuffle(seed=args.shuffle_seed)
+        if args.sequential > 0:
+            # NOTE: sequential/batch unlearning using sliced dataset.
+            train_bad_dataset = full_bad_dataset.select(range(args.samples_count))
+        else:
+            # NOTE: full dataset like bytedance.
+            train_bad_dataset = full_bad_dataset
+
+        Path(args.samples_save_dir).mkdir(exist_ok=True)
+        bad_sample_path = f"{args.samples_save_dir}/piaf_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
+        with open(bad_sample_path, "w") as fin:  # NOTE(review): "fin" names an *output* handle — consider "fout"
+            print(f"Writing bad samples to {bad_sample_path}")
+            json.dump(
+                [
+                    train_bad_dataset[i]
+                    for i in range(
+                        args.samples_count
+                        if args.sequential > 0
+                        else len(train_bad_dataset)
+                    )
+                ],
+                fin,
+            )
+
+        train_bad_loaders = create_piaf_dataloader_from_dataset(
+            tokenizer,
+            train_bad_dataset,
+            batch_size=args.batch_size,
+            splits=max(args.sequential, 1),
+        )
+
+        question_prefix_str = "### Question:"
+        answer_prefix_str = "### Réponse:"
     elif args.unlearning_dataset == "sail/symbolic-instruction-tuning":
         # filter entries with harmful responses and draw random samples from the remaining dataset.
         full_bad_dataset = load_dataset(
@@ -353,11 +394,6 @@ def main(args) -> None:
             splits=max(args.sequential, 1),
         )
 
-        # XXX: for now this is the prefix that is added before each q and answer,
-        # it is used by get_rand_ans_loss() to extract just the question part and
-        # add a random answer to it.
-        # !!!! Has additional sideffect of model unlearning this pattern!!!!
-        # ADDITONALLY: create_truthfulqa_dataloader() is also using this pattern!!!
         question_prefix_str = "### Question:"
         answer_prefix_str = "### Answer:"
     elif args.unlearning_dataset == "math_qa":
diff --git a/llm_unlearn_ucl/utils.py b/llm_unlearn_ucl/utils.py
index 31b3d35..ffb2a0b 100644
--- a/llm_unlearn_ucl/utils.py
+++ b/llm_unlearn_ucl/utils.py
@@ -74,6 +74,80 @@ def preprocess(examples):
 
     return dataloaders
 
 
+def create_piaf_dataloader_from_dataset(
+    tokenizer, dataset, fraction=1.0, batch_size=4, splits: int = 1
+):
+    """
+    Given the piaf dataset, create the dataloader on the unlearned French Q&A pairs.
+
+    Args:
+        tokenizer: Tokenizer.
+        dataset: Loaded piaf dataset.
+        fraction: <1 will do downsampling.
+        batch_size: Batch size used for each step.
+        splits: The number of splits that the dataset will be sliced into.
+    Returns:
+        A List of DataLoader of piaf French Q&A pairs.
+    """
+
+    def preprocess(examples):
+        """
+        Input: Dict[List]
+        Output: Dict[List]
+        """
+        results = {
+            "input_ids": [],
+            "attention_mask": [],
+            "start_locs": [],
+        }
+        for i in range(len(examples["answers"])):
+            prompt = examples["context"][i] + " " + examples["question"][i]
+            response = examples["answers"][i]["text"][0]
+
+            text = f"### Question: {prompt}\n ### Réponse: {response}"
+            tokenized = tokenizer(text, truncation=True, padding="max_length")
+
+            results["input_ids"].append(tokenized["input_ids"])
+            results["attention_mask"].append(tokenized["attention_mask"])
+            # Calculate start idx for answer
+            test_text = f"### Question: {prompt}\n ### Réponse: "
+            test_tokenized = tokenizer(test_text, truncation=True, padding="max_length")
+            # NOTE(review): padding="max_length" pads test_text too, so this length is
+            # always max_length and the recorded answer start looks constant/wrong —
+            # tokenize test_text WITHOUT padding here; confirm against the other loaders.
+            results["start_locs"].append(len(test_tokenized["input_ids"]) - 1)
+
+        return results
+
+    dataset = dataset.map(
+        preprocess,
+        batched=True,
+        remove_columns=[
+            "answers",
+            "context",
+            "id",
+            "question",
+            "title",
+        ],
+    )
+    dataset.set_format(
+        type="torch",
+        columns=["input_ids", "attention_mask", "start_locs"],
+    )
+
+    # Add labels and make it data loader.
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+    # TODO: data_collator introduces extra/less processed samples.
+    # NOTE(review): random_split requires the lengths to sum to len(dataset);
+    # this raises when len(dataset) is not divisible by splits — confirm callers.
+    dataloaders = [
+        torch.utils.data.DataLoader(
+            train_split_dataset, batch_size=batch_size, collate_fn=data_collator
+        )
+        for train_split_dataset in torch.utils.data.random_split(
+            dataset, tuple(len(dataset) // splits for i in range(splits))
+        )
+    ]
+
+    return dataloaders
+
+
 def create_mathqa_dataloader_from_dataset(
     tokenizer, dataset, fraction=1.0, batch_size=4
 ):