From cac234935a08d6a44603db48041f4a5eab0b4e22 Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Thu, 10 Nov 2022 14:06:05 -0500 Subject: [PATCH 1/4] cringe --- parlai/mutators/notokonly.py | 29 ++ projects/cringe/README.md | 68 ++++ projects/cringe/cringe_loss.py | 331 ++++++++++++++++++++ projects/cringe/safety_filter_world_logs.py | 74 +++++ projects/cringe/teachers.py | 67 ++++ 5 files changed, 569 insertions(+) create mode 100644 parlai/mutators/notokonly.py create mode 100644 projects/cringe/README.md create mode 100644 projects/cringe/cringe_loss.py create mode 100644 projects/cringe/safety_filter_world_logs.py create mode 100644 projects/cringe/teachers.py diff --git a/parlai/mutators/notokonly.py b/parlai/mutators/notokonly.py new file mode 100644 index 00000000000..91745817434 --- /dev/null +++ b/parlai/mutators/notokonly.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import random +from typing import List +from parlai.core.message import Message +from parlai.core.mutators import ManyEpisodeMutator, register_mutator + + +@register_mutator("notokonly") +class NotOKMutator(ManyEpisodeMutator): + """ + Flattens the entire conversation history. + + Simply concatenates all turns in the conversation with a newline. Frequently useful + when composed with other mutators. + """ + + def many_episode_mutation(self, episode: List[Message]) -> List[List[Message]]: + history = [] + for message in episode: + history.append(message.pop('text')) + message['text'] = '\n'.join(history) + if message['labels'][0] == '__notok__': + yield [message] + history.append(random.choice(message['labels'])) diff --git a/projects/cringe/README.md b/projects/cringe/README.md new file mode 100644 index 00000000000..686a889732f --- /dev/null +++ b/projects/cringe/README.md @@ -0,0 +1,68 @@ +# The CRINGE Loss: Learning what language *not* to model + +Leonard Adolphs, Tianyu Gao, Jing Xu, Kurt Shuster, Sainbayar Sukhbaatar, Jason Weston + + +## Abstract +Standard language model training employs gold human documents or human-human interaction data, and +treats all training data as positive examples. +Growing evidence shows that even with very large amounts of positive training data, issues remain +that can be alleviated with relatively small amounts of negative data -- examples of what the model should not do. +In this work, we propose a novel procedure to train with such data called the Cringe loss +(ContRastive Iterative Negative GEneration). + We show the effectiveness of this approach across three different experiments on the tasks of safe generation, + contradiction avoidance, and open-domain dialogue. Our models outperform multiple strong baselines and are + conceptually simple, easy to train and implement. + +## Paper Link + +Coming soon + + +## Train a CRINGE (single iter.) 
model on the safe generation task +``` +# Train a 3B parameter BB1 model +parlai train -t blended_skill_talk:mutators=flatten,projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+filter_want_to_talk_about_labels+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY --multitask-weights 5,1,1,1,1,1 --model projects.cringe.cringe_loss:ContrastiveTransformerGeneratorAgent --learn-positional-embeddings True --embedding-size 2560 --ffn-size 10240 --n-decoder-layers 24 --n-encoder-layers 2 --n-heads 32 --n-positions 128 --variant prelayernorm --text-truncate 128 --truncate 128 --dict-tokenizer bytelevelbpe --optimizer adam --update-freq 2 --history-add-global-end-token end --lr-scheduler-patience 3 --warmup-updates 100 --batchsize 8 --gradient-clip 10.0 --fp16 True -lr 5e-05 --load-from-checkpoint True --save-after-valid True --aggregate-micro True --attention-dropout 0.1 --dropout 0.1 --label-truncate 512 --relu-dropout 0.0 --fp16-impl mem_efficient --init-model zoo:blender/blender_3B/model --dict-file zoo:blender/blender_3B/model.dict --model-file .models/cringe/safe_bb1/model --model-parallel true + +``` + + +## Evaluate the CRINGE (single iter.) model on the safe generation task + +### Train the evaluation classifier +To evaluate if the model only generates safe utterances, we use an independently trained classifier. Here, we use the training +script from the [DIRECTOR](https://parl.ai/projects/director/): +``` + parlai train --task projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+neg_only -et 
projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+neg_only -vtim 120 --model transformer/classifier --load-from-pretrained-ranker True --init-model zoo:pretrained_transformers/bi_model_huge_reddit/model --dict-file zoo:pretrained_transformers/bi_model_huge_reddit/model.dict --history-size 20 --label-truncate 72 --text-truncate 360 --dict-tokenizer bpe --dict-lower True --optimizer adamax --output-scaling 0.06 --variant xlm --reduction-type mean --share-encoders False --learn-positional-embeddings True --n-layers 12 --n-heads 12 --ffn-size 3072 --attention-dropout 0.1 --relu-dropout 0.0 --dropout 0.1 --n-positions 1024 --embedding-size 768 --activation gelu --embeddings-scale False --n-segments 2 --learn-embeddings True --share-word-embeddings False --dict-endtoken __start__ -vp 30 -stim 60 --lr-scheduler fixed --lr-scheduler-patience 3 --lr-scheduler-decay 0.9 --warmup_updates 1000 --fp16 true -lr 5e-05 --classes pos neg -bs 20 --validation-metric f1 --validation-metric-mode max --validation-max-exs 3000 --validation-patience 200 --log-every-n-secs 10 -ttim 34200 --load-from-checkpoint true --save-after-valid true --tensorboard-log true --aggregate-micro True --model-file ./models/safety/eval_model +``` + +### Evaluate the model checkpoint +``` +parlai em --batchsize 8 --log-every-n-secs 30 --fp16 True --metrics all --inference beam --beam-size 10 --beam-min-length 20 --beam-block-ngram 3 --beam-context-block-ngram 3 --beam-block-full-context True --skip-generation False --task projects.director.tasks.safety:SafeWikiToxicEvalTeacher:mutators=flatten+safety_relabel_classes+neg_only:eval_classifier_model_file=models/safety/eval_model:include_label_cand_only=true -dt valid --num-examples 1000 --model-file ./models/cringe/safe_bb1/model +``` + +## Iterative Training + +### Generate unsafe generations on the training examples +We use the model that we trained previously to generate episodes on the WikiToxic training data. We log all the results as WikiToxic_world_logs.jsonl. 
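+Each line of the resulting world-log file is one episode serialized as JSON. As a rough sketch (structure inferred from how `safety_filter_world_logs.py` below indexes the logs; field values are illustrative, not real output), the fields the later filtering step relies on are:
+```python
+# Hypothetical world-log entry, shown only to illustrate the expected layout.
+episode = {
+    "dialog": [
+        [
+            {"text": "..."},  # teacher/context turn
+            {
+                "text": "...",  # the model's generation
+                "metrics": {"classifier_accuracy": 1.0},  # 1.0 -> kept as positive, 0.0 -> negative
+            },
+        ]
+    ]
+}
+```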
+``` +parlai em --batchsize 16 --log-every-n-secs 30 --fp16 True --metrics all --inference beam --beam-size 10 --beam-min-length 20 --beam-block-ngram 3 --beam-context-block-ngram 3 --beam-block-full-context True --skip-generation False --task projects.director.tasks.safety:SafeWikiToxicEvalTeacher:mutators=flatten+safety_relabel_classes+neg_only:eval_classifier_model_file=models/safety/eval_model:include_label_cand_only=true --num-examples 10 --datatype train:evalmode --model-file ./models/cringe/safe_bb1/model --world-logs ./models/cringe/safe_bb1/WikiToxic_world_logs.jsonl +``` + +### Filter the world logs +We filter the world logs to contain 50/50 negative and positive examples. The previously trained classifier determines the label. +``` +python projects/cringe/safety_filter_world_logs.py --world-logs-file ./models/cringe/safe_bb1/WikiToxic_world_logs.jsonl --filtered-world-logs-file ./models/cringe/safe_bb1/WikiToxic_world_logs_filtered.jsonl +``` + +### Display the filtered iterative training data +We display the new training data generated from the model. We prepend each generation with its label predicted by the classifier for easier inspection. +``` +parlai dd -t projects.cringe.teachers:IterativeTeacher -jfdp ./models/cringe/safe_bb1/WikiToxic_world_logs_filtered.jsonl --prepend-classifier-label true +``` + +### Iterative model finetuning +We finetune the model on the multitask dataset augmented with the generated utterances from the bot. It's the same finetuning command as before with the difference that we added the filtered generations as part of the dataset and we initialize the weights from the previous model. +``` +parlai train -t blended_skill_talk:mutators=flatten,projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+filter_want_to_talk_about_labels+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,parlai_internal.projects.scones_director.teachers:IterativeTeacher:mutators=flatten:jsonfile_datapath=models/cringe/safe_bb1/WikiToxic_world_logs_filtered.jsonl --multitask-weights 5,1,1,1,1,1,1 --model projects.cringe.cringe_loss:ContrastiveTransformerGeneratorAgent --learn-positional-embeddings True --embedding-size 2560 --ffn-size 10240 --n-decoder-layers 24 --n-encoder-layers 2 --n-heads 32 --n-positions 128 --variant prelayernorm --text-truncate 128 --truncate 128 --dict-tokenizer bytelevelbpe --optimizer adam --update-freq 2 --history-add-global-end-token end --lr-scheduler-patience 3 --warmup-updates 100 -bs 8 --gradient-clip 10.0 --fp16 True -lr 5e-05 --load-from-checkpoint True --save-after-valid True --aggregate-micro True --attention-dropout 0.1 --dropout 0.1 --label-truncate 512 --relu-dropout 0.0 --fp16-impl mem_efficient --init-model ./models/cringe/safe_bb1/model --dict-file ./models/cringe/safe_bb1/model.dict --model-file .models/cringe/safe_bb1_iterative/model --model-parallel true +``` \ No newline at end of file diff --git a/projects/cringe/cringe_loss.py b/projects/cringe/cringe_loss.py new file mode 100644 index 00000000000..51f9cb3df93 --- /dev/null +++ b/projects/cringe/cringe_loss.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 +# 
Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Transformer Agent with a contrastive loss. +""" +import torch +from torch.nn import CrossEntropyLoss +from torch.distributions.categorical import Categorical +from typing import Optional, Dict, Union +from parlai.core.message import Message + +from parlai.core.opt import Opt +from parlai.core.params import ParlaiParser +from parlai.agents.transformer.transformer import TransformerGeneratorAgent +from parlai.core.torch_generator_agent import PPLMetric +from parlai.core.metrics import AverageMetric + +from parlai.agents.fid.fid import ( + WizIntGoldDocRetrieverFiDAgent, +) +from projects.blenderbot2.agents.blenderbot2 import ( + BlenderBot2FidAgent, + BlenderBot2FidModel, + T5BlenderBot2FidModel, +) + + +class ContrastiveCrossEntropyLoss(CrossEntropyLoss): + def __init__( + self, + ct_loss_weight=1.0, + num_pos_predictions=1, + detach_positives_during_ct=False, + train_ct_on_positive_examples=False, + **kwargs, + ): + super().__init__(**kwargs) + self.ct_loss_weight = ct_loss_weight + self.num_pos_predictions = num_pos_predictions + self.detach_positives_during_ct = detach_positives_during_ct + self.train_ct_on_positive_examples = train_ct_on_positive_examples + + def __call__(self, x, y, classifier_labels=None, **kwargs): + if classifier_labels is None: + classifier_labels = -torch.ones_like(y).to(y.device) + + # turn no-class provided label (-1) into positive label (1) + classifier_labels_ce = torch.abs(classifier_labels) + + if self.train_ct_on_positive_examples: + # no-class (-1 to 0), positive (1 to 1), negative (0 to 1) + classifier_labels_ct = torch.clamp(classifier_labels + 1, max=1) + else: + # no-class (-1 to 0), positive (1 to 0), negative (0 to 1) + classifier_labels_ct = torch.abs(torch.abs(classifier_labels) - 1) + + ce_loss = super().__call__(x, y, **kwargs) + # multiply with classifier labels to not train with negative feedback (0) + ce_loss *= classifier_labels_ce + + # compute the contrastive loss part for the negative labels + # first, get the positives as the top predictions != target + preds = torch.topk(x, k=self.num_pos_predictions + 1, axis=-1) + y_rep = y.unsqueeze(1).repeat(1, self.num_pos_predictions + 1) + logits = preds.values - (preds.indices == y_rep) * 1e10 + + # if the positive is not in the first k predictions, mask out + # the final (k+1)'s logit + prediction_mask = torch.cat( + ( + torch.zeros_like(logits)[:, :-1], + torch.abs((preds.indices == y_rep).sum(-1).unsqueeze(1) - 1), + ), + 1, + ) + logits -= prediction_mask * 1e10 + + # Sample from the categorical distribution of the top-k predictions + # (with the label masked out). 
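+        # After the two masking steps above, exactly `num_pos_predictions`
+        # candidates keep a usable logit: the gold label's slot is suppressed
+        # when it appears in the top k+1 predictions, and otherwise the spare
+        # (k+1)-th slot is. The token sampled below is therefore always a
+        # non-gold prediction of the model, and it acts as the contrastive
+        # "positive" that an unwanted (negative-example) label token is pushed
+        # below via x_ct / y_ct further down.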
+ preds_dist = Categorical(logits=logits) + idx_sample = preds_dist.sample() + sample_preds_values = preds.values[torch.arange(x.shape[0]), idx_sample] + + if self.detach_positives_during_ct: + sample_preds_values = sample_preds_values.detach() + + # concatenate the logits of the preds with the actual label's logits + x_target = x[torch.arange(x.shape[0]), y] + x_ct = torch.concat( + [x_target.unsqueeze(1), sample_preds_values.unsqueeze(1)], -1 + ) + # get the y's for the x_ct (the correct label is index 0 if + # the target is positive and index 1 if the target is negative) + y_ct = torch.abs(torch.abs(classifier_labels) - 1).type(y.dtype).to(x_ct.device) + # y_ct = (torch.ones(y.shape) * ).type(y.dtype).to(x_ct.device) + # compute the contrastive loss as cross entropy loss between x_ct, y_ct + ct_loss = super().__call__(x_ct, y_ct, **kwargs) + ct_loss *= classifier_labels_ct + + # remove loss from ignore index + notnull = y.ne(self.ignore_index) + ce_loss *= notnull + ct_loss *= notnull + + loss = ce_loss + self.ct_loss_weight * ct_loss + + return loss, ce_loss, ct_loss + + +class ContrastiveTransformerGeneratorAgent(TransformerGeneratorAgent): + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + """ + Add command line arguments. + """ + agent = parser.add_argument_group( + 'ContrastiveTransformerGeneratorAgent arguments' + ) + parser.add_argument( + '--ct-loss-weight', + type=float, + help='Coefficient for the contrastive loss (negative examples).', + default=1.0, + ) + parser.add_argument( + '--ct-num-pos-predictions', + type=int, + help='How many top predictions do we consider as positives for the contrastive loss?', + default=1, + ) + parser.add_argument( + '--ct-detach-positives', + type=bool, + help='If true, we block the gradient for the positives during the contrastive loss.', + default=False, + ) + parser.add_argument( + '--train-ct-on-positive_examples', + type=bool, + help='If true, we train with the positive examples in the contrastive loss' + ' (with the negatives being the top-k sampled from the model).', + default=False, + ) + super().add_cmdline_args(parser, partial_opt=partial_opt) + return agent + + def build_criterion(self): + return ContrastiveCrossEntropyLoss( + ct_loss_weight=self.opt['ct_loss_weight'], + num_pos_predictions=self.opt['ct_num_pos_predictions'], + detach_positives_during_ct=self.opt['ct_detach_positives'], + ignore_index=self.NULL_IDX, + train_ct_on_positive_examples=self.opt['train_ct_on_positive_examples'], + reduction='none', + ) + + def _v2t(self, vec): + """ + This method wraps the vec2txt call in a try catch to ensure that sequences with + generation errors are ignored. + + We return a empty string instead in that scenario. 
+ """ + try: + return super()._v2t(vec) + except AssertionError: + return '' + + def observe(self, observation: Union[Dict, Message]) -> Message: + observation = super().observe(observation) + if 'is_ltr' not in observation: + observation['is_ltr'] = False + observation['classifier_label'] = 'none' + observation['classifier_label_idx'] = -1 + return observation + + classifier_label = observation['classifier_label'] + if classifier_label == 'pos': + observation['classifier_label_idx'] = 1 + elif classifier_label == 'neg': + observation['classifier_label_idx'] = 0 + return observation + + def batchify(self, obs_batch, sort=False): + """ + This method calls the parent class's batchify method and then add + classifier_label and is_ltr property to the the batch. + """ + batch = super().batchify(obs_batch, sort=sort) + + if batch.valid_indices is None: + return batch + + batch.classifier_label = torch.tensor( + [ + [obs_batch[i].get('classifier_label_idx', -1)] + for i in batch.valid_indices + ] + ) + batch.is_ltr = torch.tensor( + [[obs_batch[i].get('is_ltr', False)] for i in batch.valid_indices] + ) + return batch + + def compute_loss(self, batch, return_output=False): + if batch.label_vec is None: + raise ValueError('Cannot compute loss without a label.') + model_output = self.model(*self._model_input(batch), ys=batch.label_vec) + scores, preds, *_ = model_output + score_view = scores.reshape(-1, scores.size(-1)) + (loss, ce_loss, ct_loss,) = self.criterion( + score_view, + batch.label_vec.view(-1), + batch.classifier_label.repeat(1, scores.shape[1]) + .view(-1) + .to(batch.label_vec.device), + ) + + def loss_reshape(loss): + return loss.view(scores.shape[:-1]).sum(dim=1) + + loss = loss_reshape(loss) + ce_loss = loss_reshape(ce_loss) + ct_loss = loss_reshape(ct_loss) + notnull = batch.label_vec.ne(self.NULL_IDX) + target_tokens = notnull.long().sum(dim=-1) + correct = ((batch.label_vec == preds) * notnull).sum(dim=-1) + + pos_labels = (torch.abs(batch.classifier_label) == 1).view(-1) + neg_labels = (torch.abs(batch.classifier_label) == 0).view(-1) + correct_pos = torch.where(pos_labels, correct, -1) + correct_neg = torch.where(neg_labels, correct, -1) + + # record loss + self.record_local_metric('loss', AverageMetric.many(loss, target_tokens)) + self.record_local_metric( + 'ce_loss', + [ + metric if metric > 0.0 else None + for metric in AverageMetric.many(ce_loss, target_tokens) + ], + ) + self.record_local_metric( + 'ct_loss', + [ + metric if metric > 0.0 else None + for metric in AverageMetric.many(ct_loss, target_tokens) + ], + ) + # token-wise accuracy + self.record_local_metric( + 'token_acc', AverageMetric.many(correct, target_tokens) + ) + self.record_local_metric( + 'token_acc_pos', + [ + metric if metric >= 0 else None + for metric in AverageMetric.many(correct_pos, target_tokens) + ], + ) + self.record_local_metric( + 'token_acc_neg', + [ + metric if metric >= 0 else None + for metric in AverageMetric.many(correct_neg, target_tokens) + ], + ) + # perplexity + self.record_local_metric( + 'ppl_debug', PPLMetric.many(ce_loss + ct_loss, target_tokens) + ) + self.record_local_metric( + 'ppl_pos', + [ + metric if pos_label else None + for pos_label, metric in zip( + pos_labels, PPLMetric.many(ce_loss, target_tokens) + ) + ], + ) + self.record_local_metric( + 'ppl_ct', + [ + metric if neg_label else None + for neg_label, metric in zip( + neg_labels, PPLMetric.many(ct_loss, target_tokens) + ) + ], + ) + + # actually do backwards loss + loss = loss.sum() + loss /= target_tokens.sum() # 
average loss per token + if return_output: + return (loss, model_output) + else: + return loss + + +class ContrastiveBB2Agent(ContrastiveTransformerGeneratorAgent, BlenderBot2FidAgent): + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + """ + Add command line arguments. + """ + super().add_cmdline_args(parser, partial_opt=partial_opt) + BlenderBot2FidAgent.add_cmdline_args(parser, partial_opt=partial_opt) + return parser + + def build_model(self) -> BlenderBot2FidModel: + if self.generation_model == 't5': + model = T5BlenderBot2FidModel(self.opt, self.dict) + else: + model = BlenderBot2FidModel(self.opt, self.dict) + if self.opt['embedding_type'] != 'random': + self._copy_embeddings( + model.encoder.embeddings.weight, self.opt['embedding_type'] + ) + return model + + +class ContrastiveBB2WizIntGoldDocRetrieverFiDAgent( + WizIntGoldDocRetrieverFiDAgent, ContrastiveBB2Agent +): + pass diff --git a/projects/cringe/safety_filter_world_logs.py b/projects/cringe/safety_filter_world_logs.py new file mode 100644 index 00000000000..7f3cf13d43e --- /dev/null +++ b/projects/cringe/safety_filter_world_logs.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List +import argparse +import json +import os +from tqdm import tqdm +import random +import collections + + +def filter_world_logs_for_classifier_accuracy( + world_logs_files: List[str], filtered_world_logs_file: str +): + filtered_data_pos = [] + filtered_data_neg = [] + + all_world_logs_files = [] + for file in world_logs_files: + if not file.endswith('jsonl'): + # Is it a directory that we can expand? + all_world_logs_files.extend( + [ + os.path.join(file, f) + for f in os.listdir(file) + if 'world_logs' in f and f.endswith('jsonl') + ] + ) + else: + all_world_logs_files.append(file) + world_logs_files = all_world_logs_files + + count = 0 + pos_labels_count = collections.defaultdict(int) + for world_logs_file in tqdm(world_logs_files): + with open(world_logs_file, 'r') as f: + for line in f.readlines(): + count += 1 + line_dict = json.loads(line) + if line_dict['dialog'][0][1]['metrics']['classifier_accuracy'] == 0.0: + filtered_data_neg.append(line) + elif line_dict['dialog'][0][1]['metrics']['classifier_accuracy'] == 1.0: + label = line_dict['dialog'][0][1]['text'] + # Allow at most twice the same positive generation. 
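+                    # Capping duplicates keeps the positive half of the 50/50
+                    # pos/neg split assembled below diverse, instead of letting
+                    # a few generations that the model repeats across many
+                    # contexts dominate the retained positives.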
+ if label in pos_labels_count and pos_labels_count[label] >= 2: + continue + pos_labels_count[label] += 1 + filtered_data_pos.append(line) + + num_filtered_data = min(len(filtered_data_neg), len(filtered_data_pos)) + filtered_data = ( + filtered_data_neg[:num_filtered_data] + filtered_data_pos[:num_filtered_data] + ) + random.shuffle(filtered_data) + + with open(filtered_world_logs_file, 'w') as f: + for line in filtered_data: + f.write(line) + + print( + f'Wrote {len(filtered_data)}/{count} examples to filtered log file: {filtered_world_logs_file}' + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world-logs-file', type=str, help='') + parser.add_argument('--filtered-world-logs-file', type=str, help='') + args = parser.parse_args() + world_logs = args.world_logs_file.split(',') + filter_world_logs_for_classifier_accuracy(world_logs, args.filtered_world_logs_file) diff --git a/projects/cringe/teachers.py b/projects/cringe/teachers.py new file mode 100644 index 00000000000..de86de66033 --- /dev/null +++ b/projects/cringe/teachers.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import Optional + +from parlai.core.params import ParlaiParser +from parlai.core.opt import Opt +from parlai.tasks.jsonfile.agents import JsonTeacher + + +class IterativeTeacher(JsonTeacher): + delete_tokens = ['_POTENTIALLY_UNSAFE__'] + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + super().add_cmdline_args(parser, partial_opt) + agent = parser.add_argument_group('IterativeTeacher options') + agent.add_argument( + '--prepend-classifier-label', + type=bool, + default=False, + help='If true, prepend the classifier label to the generation label. This should only be used to inspect the data', + ) + return parser + + def __init__(self, opt, shared=None): + assert ( + 'jsonfile_datapath' in opt + ), 'You need to provide the --jsonfile-datapath flag for the IterativeTeacher.' 
+ super().__init__(opt, shared) + + def setup_data(self, path): + for example, episode_end in super().setup_data(path): + example['is_ltr'] = True + labels = example.pop('labels') + for word in self.delete_tokens: + labels = [l.replace(word, '').strip() for l in labels] + example['labels'] = [l.strip() for l in labels] + + if self.opt.get('prepend_classifier_label', False): + example['labels'][0] = ( + example['classifier_label'] + ': ' + example['labels'][0] + ) + + yield example, episode_end + + def _get_ep_from_turns(self, xturns, yturns): + eps = [] + for xturn, yturn in zip(xturns, yturns): + turn = {} + turn['text'] = xturn.get('text').strip() + turn['labels'] = [yturn.get('text').strip()] + if 'pos_classifier_prediction' in yturn['metrics']: + class_label_int = int(yturn['metrics']['pos_classifier_prediction']) + turn['classifier_label'] = 'neg' if class_label_int == 0 else 'pos' + elif 'classifier_accuracy' in yturn['metrics']: + class_label_int = int(yturn['metrics']['classifier_accuracy']) + turn['classifier_label'] = 'neg' if class_label_int == 0 else 'pos' + elif 'f1' in yturn['metrics']: + turn['classifier_label'] = ( + 'neg' if float(yturn['metrics']['f1']) == 0.0 else 'pos' + ) + eps.append(turn) + return eps From 2853f8a83924cfa63ce944d6385fbced618d657b Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Thu, 10 Nov 2022 14:07:32 -0500 Subject: [PATCH 2/4] oops --- parlai/mutators/notokonly.py | 29 ----------------------------- 1 file changed, 29 deletions(-) delete mode 100644 parlai/mutators/notokonly.py diff --git a/parlai/mutators/notokonly.py b/parlai/mutators/notokonly.py deleted file mode 100644 index 91745817434..00000000000 --- a/parlai/mutators/notokonly.py +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import random -from typing import List -from parlai.core.message import Message -from parlai.core.mutators import ManyEpisodeMutator, register_mutator - - -@register_mutator("notokonly") -class NotOKMutator(ManyEpisodeMutator): - """ - Flattens the entire conversation history. - - Simply concatenates all turns in the conversation with a newline. Frequently useful - when composed with other mutators. - """ - - def many_episode_mutation(self, episode: List[Message]) -> List[List[Message]]: - history = [] - for message in episode: - history.append(message.pop('text')) - message['text'] = '\n'.join(history) - if message['labels'][0] == '__notok__': - yield [message] - history.append(random.choice(message['labels'])) From 76d1ec218306ccafc3822cab71c6ae7606146b32 Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Thu, 10 Nov 2022 14:31:42 -0500 Subject: [PATCH 3/4] bold --- projects/cringe/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/cringe/README.md b/projects/cringe/README.md index 686a889732f..12807f0b348 100644 --- a/projects/cringe/README.md +++ b/projects/cringe/README.md @@ -8,7 +8,7 @@ Standard language model training employs gold human documents or human-human int treats all training data as positive examples. Growing evidence shows that even with very large amounts of positive training data, issues remain that can be alleviated with relatively small amounts of negative data -- examples of what the model should not do. 
-In this work, we propose a novel procedure to train with such data called the Cringe loss +In this work, we propose a novel procedure to train with such data called the CRINGE loss (ContRastive Iterative Negative GEneration). We show the effectiveness of this approach across three different experiments on the tasks of safe generation, contradiction avoidance, and open-domain dialogue. Our models outperform multiple strong baselines and are From f4acd3586c14ff30524e60af76076739be6d9290 Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Thu, 10 Nov 2022 14:36:29 -0500 Subject: [PATCH 4/4] scones fix --- projects/cringe/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/cringe/README.md b/projects/cringe/README.md index 12807f0b348..c93c9285050 100644 --- a/projects/cringe/README.md +++ b/projects/cringe/README.md @@ -64,5 +64,5 @@ parlai dd -t projects.cringe.teachers:IterativeTeacher -jfdp ./models/cringe/saf ### Iterative model finetuning We finetune the model on the multitask dataset augmented with the generated utterances from the bot. It's the same finetuning command as before with the difference that we added the filtered generations as part of the dataset and we initialize the weights from the previous model. ``` -parlai train -t blended_skill_talk:mutators=flatten,projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+filter_want_to_talk_about_labels+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,parlai_internal.projects.scones_director.teachers:IterativeTeacher:mutators=flatten:jsonfile_datapath=models/cringe/safe_bb1/WikiToxic_world_logs_filtered.jsonl --multitask-weights 5,1,1,1,1,1,1 --model projects.cringe.cringe_loss:ContrastiveTransformerGeneratorAgent --learn-positional-embeddings True --embedding-size 2560 --ffn-size 10240 --n-decoder-layers 24 --n-encoder-layers 2 --n-heads 32 --n-positions 128 --variant prelayernorm --text-truncate 128 --truncate 128 --dict-tokenizer bytelevelbpe --optimizer adam --update-freq 2 --history-add-global-end-token end --lr-scheduler-patience 3 --warmup-updates 100 -bs 8 --gradient-clip 10.0 --fp16 True -lr 5e-05 --load-from-checkpoint True --save-after-valid True --aggregate-micro True --attention-dropout 0.1 --dropout 0.1 --label-truncate 512 --relu-dropout 0.0 --fp16-impl mem_efficient --init-model ./models/cringe/safe_bb1/model --dict-file ./models/cringe/safe_bb1/model.dict --model-file .models/cringe/safe_bb1_iterative/model --model-parallel true +parlai train -t 
blended_skill_talk:mutators=flatten,projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+filter_want_to_talk_about_labels+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_EMPTY,projects.cringe.teachers:IterativeTeacher:mutators=flatten:jsonfile_datapath=models/cringe/safe_bb1/WikiToxic_world_logs_filtered.jsonl --multitask-weights 5,1,1,1,1,1,1 --model projects.cringe.cringe_loss:ContrastiveTransformerGeneratorAgent --learn-positional-embeddings True --embedding-size 2560 --ffn-size 10240 --n-decoder-layers 24 --n-encoder-layers 2 --n-heads 32 --n-positions 128 --variant prelayernorm --text-truncate 128 --truncate 128 --dict-tokenizer bytelevelbpe --optimizer adam --update-freq 2 --history-add-global-end-token end --lr-scheduler-patience 3 --warmup-updates 100 -bs 8 --gradient-clip 10.0 --fp16 True -lr 5e-05 --load-from-checkpoint True --save-after-valid True --aggregate-micro True --attention-dropout 0.1 --dropout 0.1 --label-truncate 512 --relu-dropout 0.0 --fp16-impl mem_efficient --init-model ./models/cringe/safe_bb1/model --dict-file ./models/cringe/safe_bb1/model.dict --model-file .models/cringe/safe_bb1_iterative/model --model-parallel true ``` \ No newline at end of file