diff --git a/.gitignore b/.gitignore index b6e4761..b8da8a4 100644 --- a/.gitignore +++ b/.gitignore @@ -103,6 +103,7 @@ celerybeat.pid # Environments .env +.cenv .venv env/ venv/ diff --git a/agents/__init__.py b/agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agents/history.py b/agents/history.py new file mode 100755 index 0000000..3bf19f8 --- /dev/null +++ b/agents/history.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This file is derived from parlai/core/seq2seq/seq2seq.py. +In particular, it's derived from an older version that inherits from TorchAgent rather +than TorchGeneratorAgent. +It should be possible to refactor this file to be comparable to the current +parlai/core/seq2seq/seq2seq.py, i.e. inherit from TorchGeneratorAgent - this would +probably reduce the amount of boilerplate in this file. +However, for simplicity and to keep things as similar as possible to the version used +for the paper, we have kept this file mostly the same. +""" + +from parlai.core.torch_agent import Batch, History, TorchAgent +from parlai.core.torch_generator_agent import TorchGeneratorAgent +from parlai.utils.torch import padded_tensor, argsort +# from .base_controllable_seq2seq import BaseControllableSeq2seqAgent +# from .util import ConvAI2History +# from .controls import get_ctrl_vec + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from collections import defaultdict, namedtuple, Counter, deque +from operator import attrgetter + +import os +import math +import json +import tempfile +import copy +from itertools import chain + + +def list_to_matrix(l, n): + return [l[i:i+n] for i in range(0, len(l), n)] + + +class SelfConsciousHistory(History): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + opt = args[0] + if opt['eval_type'] == 'convai2': + self.add_person_tokens = True + elif opt['eval_type'] == 'dnli': + self.add_person_tokens = False + else: + raise ValueError + + self.world_cardinality = opt.get('world_cardinality', 5) + self.history_distractor_strings = [[] for _ in range(self.world_cardinality)] + self.history_distractor_raw_strings = [[] for _ in range(self.world_cardinality)] + self.history_distractor_vecs = [[] for _ in range(self.world_cardinality)] + # Will be used for TransferTransfo + self.history_token_type_ids = [] + self.history_distractor_token_type_ids = [[] for _ in range(self.world_cardinality)] + + def reset(self): + """Clear the history""" + super().reset() + self.history_distractor_strings = [[] for _ in range(self.world_cardinality)] + self.history_distractor_raw_strings = [[] for _ in range(self.world_cardinality)] + self.history_distractor_vecs = [[] for _ in range(self.world_cardinality)] + self.history_token_type_ids = [] + self.history_distractor_token_type_ids = [[] for _ in range(self.world_cardinality)] + + def _update_distractor_strings(self, text, idx): + history_strings = self.history_distractor_strings[idx] + if self.size > 0: + while len(history_strings) >= self.size: + history_strings.pop(0) + history_strings.append(text) + + def _update_distractor_raw_strings(self, text, idx): + history_raw_strings = self.history_distractor_raw_strings[idx] + if self.size > 0: + while len(history_raw_strings) >= self.size: + history_raw_strings.pop(0) + history_raw_strings.append(text) + + def 
_update_distractor_vecs(self, text, idx):
+        history_vecs = self.history_distractor_vecs[idx]
+        if self.size > 0:
+            while len(history_vecs) >= self.size:
+                history_vecs.pop(0)
+        history_vecs.append(self.parse(text))
+
+    def _update_token_type_ids(self, text, idx):
+        pass
+
+    def add_reply_to_distractors(self, model_reply):
+
+        # Add the model's response to the history
+        if model_reply is not None:
+            for idx in range(self.world_cardinality):
+                self._update_distractor_raw_strings(model_reply, idx)
+                # add the p2 token only once (idx == 0); adding it on every
+                # loop iteration would repeat it inside model_reply
+                if self.add_person_tokens and idx == 0:
+                    model_reply = self._add_person_tokens(model_reply, self.p2_token)
+                self._update_distractor_strings(model_reply, idx)
+                self._update_distractor_vecs(model_reply, idx)
+
+    # def update_history(self, obs, add_next=None):
+    def update_history(self, obs, temp_history=None):
+        """
+        Update the history with the given observation.
+        :param temp_history:
+            optional string to temporarily append to the end of the history
+        """
+        # super().update_history(obs, add_next)
+        super().update_history(obs, temp_history=temp_history)
+
+        # Update the previous turn's own response
+        # if add_next is not None:
+        #     for idx in range(self.world_cardinality):
+        #         self._update_distractor_raw_strings(add_next, idx)
+        #         # this is causing the repetition of p2 token.
+        #         # need to do this only once. not every loop
+        #         if self.add_person_tokens and idx == 0:
+        #             add_next = self._add_person_tokens(add_next, self.p2_token)
+        #         self._update_distractor_strings(add_next, idx)
+        #         self._update_distractor_vecs(add_next, idx)
+
+        # Update current turn's opponent's response
+        if 'distractor_text' in obs:
+            assert len(obs['distractor_text']) == self.world_cardinality, \
+                f"Number of distractor_text must be equal to world_cardinality. 
({len(obs['distractor_text'])} vs {self.world_cardinality})" + for idx, distractor_text in enumerate(obs['distractor_text']): + if self.split_on_newln: + next_texts = distractor_text.split('\n') + else: + next_texts = [distractor_text] + for text in next_texts: + self._update_distractor_raw_strings(text, idx) + if self.add_person_tokens: + text = self._add_person_tokens( + distractor_text, self.p1_token, self.add_p1_after_newln + ) + self._update_distractor_strings(text, idx) + self._update_distractor_vecs(text, idx) + + def get_history_distractor_str(self): + """Return the list of string version of the distractor histories.""" + if len(self.history_distractor_strings[0]) > 0: + return [ + self.delimiter.join(history_strings) + for history_strings in self.history_distractor_strings + ] + return None + + def get_history_distractor_vec(self): + """Return a vectorized version of the distractor histories.""" + if len(self.history_distractor_vecs[0]) == 0: + return None + + histories = [] + for idx in range(self.world_cardinality): + history_vecs = self.history_distractor_vecs[idx] + + # if self.vec_type == 'deque': + # history = deque(maxlen=self.max_len) + # for vec in history_vecs[:-1]: + # history.extend(vec) + # history.extend(self.delimiter_tok) + # history.extend(history_vecs[-1]) + # else: + # vec type is a list + history = [] + for vec in history_vecs[:-1]: + history += vec + history += self.delimiter_tok + history += history_vecs[-1] + + histories.append(history) + return histories + + def get_token_type_ids(self): + """ + Return a vectorized version of the token_type_ids and + distractor_token_type_ids + """ + pass + + +class ContextConsciousHistory(History): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + opt = args[0] + if opt['eval_type'] == 'convai2': + self.add_person_tokens = True + elif opt['eval_type'] == 'dnli': + self.add_person_tokens = False + else: + raise ValueError + + self.world_cardinality = opt.get('world_cardinality', 5) + self.history_distractor_strings = [[] for _ in range(self.world_cardinality)] + self.history_distractor_raw_strings = [[] for _ in range(self.world_cardinality)] + self.history_distractor_vecs = [[] for _ in range(self.world_cardinality)] + # Will be used for TransferTransfo + self.history_token_type_ids = [] + self.history_distractor_token_type_ids = [[] for _ in range(self.world_cardinality)] + self.eval_type = opt.get('eval_type') + + def reset(self): + """Clear the history""" + super().reset() + self.history_distractor_strings = [[] for _ in range(self.world_cardinality)] + self.history_distractor_raw_strings = [[] for _ in range(self.world_cardinality)] + self.history_distractor_vecs = [[] for _ in range(self.world_cardinality)] + self.history_token_type_ids = [] + self.history_distractor_token_type_ids = [[] for _ in range(self.world_cardinality)] + + def _update_distractor_strings(self, text, idx): + history_strings = self.history_distractor_strings[idx] + if self.size > 0: + while len(history_strings) >= self.size: + history_strings.pop(0) + history_strings.append(text) + + def _update_distractor_raw_strings(self, text, idx): + history_raw_strings = self.history_distractor_raw_strings[idx] + if self.size > 0: + while len(history_raw_strings) >= self.size: + history_raw_strings.pop(0) + history_raw_strings.append(text) + + def _update_distractor_vecs(self, text, idx): + history_vecs = self.history_distractor_vecs[idx] + if self.size > 0: + while len(history_vecs) >= self.size: + history_vecs.pop(0) + 
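# (hedged editor note) same ring-buffer policy as the parent History class:
+            # evict the oldest turns until there is room, then parse and append
+            # the new one, so each of the world_cardinality distractor histories
+            # stays bounded by the configured history size.
+        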
history_vecs.append(self.parse(text)) + + def _update_token_type_ids(self, text, idx): + pass + + def add_reply_to_distractors(self, model_reply, obs=None): + + # Update model's response along with distractor responses to the history + if model_reply is not None and 'distractor_text' in obs: + distractor_responses = obs['distractor_text'] + assert len(obs['distractor_text']) == self.world_cardinality + + for idx in range(self.world_cardinality): + self._update_distractor_raw_strings(distractor_responses[idx], idx) + if self.add_person_tokens: + distractor_responses[idx] = self._add_person_tokens(distractor_responses[idx], self.p2_token) + self._update_distractor_strings(distractor_responses[idx], idx) + self._update_distractor_vecs(distractor_responses[idx], idx) + + # def update_history(self, obs, add_next=None): + def update_history(self, obs, temp_history=None): + """ + Update the history with the given observation. + :param add_next: + string to append to history prior to updating it with the + observation + """ + super().update_history(obs, temp_history=temp_history) + + # Update current turn's opponent's response + if self.eval_type == 'convai2': + if 'text' in obs: + for idx in range(self.world_cardinality): + if self.split_on_newln: + next_texts = obs['text'].split('\n') + else: + next_texts = [obs['text']] + for text in next_texts: + self._update_distractor_raw_strings(text, idx) + if self.add_person_tokens: + text = self._add_person_tokens( + obs['text'], self.p1_token, self.add_p1_after_newln + ) + self._update_distractor_strings(text, idx) + self._update_distractor_vecs(text, idx) + else: + if 'distractor_text' in obs: + distractor_texts = obs['distractor_text'] + for idx, distractor in enumerate(distractor_texts): + self._update_distractor_raw_strings(distractor, idx) + self._update_distractor_strings(distractor, idx) + self._update_distractor_vecs(distractor, idx) + + def get_history_distractor_str(self): + """Return the list of string version of the distractor histories.""" + if len(self.history_distractor_strings[0]) > 0: + return [ + self.delimiter.join(history_strings) + for history_strings in self.history_distractor_strings + ] + return None + + def get_history_distractor_vec(self): + """Return a vectorized version of the distractor histories.""" + if len(self.history_distractor_vecs[0]) == 0: + return None + + histories = [] + for idx in range(self.world_cardinality): + history_vecs = self.history_distractor_vecs[idx] + + # if self.vec_type == 'deque': + # history = deque(maxlen=self.max_len) + # for vec in history_vecs[:-1]: + # history.extend(vec) + # history.extend(self.delimiter_tok) + # history.extend(history_vecs[-1]) + # else: + # vec type is a list + history = [] + for vec in history_vecs[:-1]: + history += vec + history += self.delimiter_tok + history += history_vecs[-1] + + histories.append(history) + return histories + + def get_token_type_ids(self): + """ + Return a vectorized version of the token_type_ids and + distractor_token_type_ids + """ + pass diff --git a/agents/modules.py b/agents/modules.py new file mode 100644 index 0000000..60cf07e --- /dev/null +++ b/agents/modules.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Implements NN code for transformers. + +Original paper: https://arxiv.org/abs/1706.03762. (Vaswani, 2017). 
The +`Annotated Transformer` (Rush, 2018) is an excellent reading guide which explains +much of the mechanics of the Transformer model +(http://nlp.seas.harvard.edu/2018/04/03/attention.html). + +This module also supports special segments (ala BERT; +https://arxiv.org/abs/1810.04805), and a few different variations seen in the +literature (BERT and XLM; https://arxiv.org/abs/1901.07291). +""" + +import math +from typing import Dict, Tuple, Optional + +import numpy as np +import torch +import torch.cuda +import torch.nn as nn +import torch.nn.functional as F + +from parlai.core.torch_generator_agent import TorchGeneratorModel +from parlai.agents.transformer.modules import TransformerGeneratorModel + + +class SelfConsciousTransformerModel(TransformerGeneratorModel): + """ + Implements a full transformer generator model, with pragmatic self-consciousness. + """ + + def __init__(self, opt, dictionary): + super().__init__(opt, dictionary) + + self.alpha = 0.0 if opt['conscious_target'] == 'none' else opt['alpha'] + self.beta = opt['beta'] + self.world_cardinality = opt['world_cardinality'] + self.worldprior = opt['worldprior'] + self.target_persona = 0 + self.fp16 = opt['fp16'] + + def _initialize_worldpriors(self, bsz, seqlen): + """ + initialize the world prior with a uniform distribution + """ + cardinality = self.world_cardinality + torch_dtype=torch.half if self.fp16 else torch.float + ones = torch.ones(1, seqlen, cardinality, dtype=torch_dtype, requires_grad=False).cuda() + uniform_world_prior = torch.log(ones / cardinality) + world_priors = uniform_world_prior.repeat(bsz, 1, 1).detach() + + return world_priors + + def _pragmatic_reasoning(self, s0_t, worldprior): + """ + run pragmatic reasoning with the base speaker and its imaginary listener + """ + + vocab_size = self.embeddings.num_embeddings + + # log-scale + log_score = nn.functional.log_softmax(s0_t, dim=2) + log_score = log_score.squeeze() # (bpsz, vocab) + + # (bsz, world_cardinality, vocab) + log_score = log_score.view(-1, self.world_cardinality, vocab_size) + + # S_0 for L_1 + _literal_speaker = log_score.clone() + _literal_speaker, _literal_s_next_token_idxs = torch.max(_literal_speaker, dim=-1, keepdim=True) + + # S_0 for the actual given persona (bsz, vocab) + speaker_prior = log_score.select(1, self.target_persona) # target persona is always index 0 + + # S_0 for L_0 + # (bsz, vocab, world_cardinality) + log_score = log_score.transpose(dim0=1, dim1=2).contiguous() + log_score = log_score * self.beta + + # L_0 \propto S_0 * p(i) + # worldprior should be broadcasted to all the tokens + # (bsz, vocab, world_cardinality) + listener_posterior = (log_score + worldprior) - torch.logsumexp(log_score + worldprior, 2, keepdim=True) + + # (bsz, vocab) + listener_score = listener_posterior.select(2, self.target_persona) # target persona is always index 0 + listener_score = listener_score * self.alpha + + speaker_posterior = (listener_score + speaker_prior) - torch.logsumexp(listener_score + speaker_prior, 1, keepdim=True) + + # need to unsqueeze in the dimension 1 + speaker_posterior = speaker_posterior.unsqueeze(1) # (bsz, 1, vocab) + + # L_0 for L_1 + _literal_listener = listener_posterior.transpose(dim0=1, dim1=2).contiguous() + _literal_listener = torch.gather(_literal_listener, -1, _literal_s_next_token_idxs) + + pragmatic_listener = (_literal_speaker + _literal_listener) - torch.logsumexp(_literal_speaker + _literal_listener, 1, keepdim=True) + pragmatic_listener = pragmatic_listener.squeeze() + + return speaker_posterior, 
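listener_posterior, pragmatic_listener
+
+    # ---- hedged editor summary (illustration, not part of the original code) ----
+    # _pragmatic_reasoning is one step of the Rational Speech Acts recursion,
+    # computed in log space with logsumexp normalization. Here i indexes the
+    # world_cardinality personas, with the ground truth fixed at index 0, and
+    # alpha/beta come from the --alpha/--beta flags:
+    #   L0(i | w, h)    proportional to  S0(w | i, h)^beta * prior(i)
+    #   S1(w | i=0, h)  proportional to  S0(w | 0, h) * L0(i=0 | w, h)^alpha
+    # the returned triple is speaker_posterior,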
listener_posterior, pragmatic_listener + + def selfconscious_decode(self, encoder_states, maxlen): + """ + greedy decoding with pragmatic self-consciousness + """ + bpsz = encoder_states[0].size(0) + bsz = bpsz // self.world_cardinality + + inputs_t = self.START.detach().expand(bpsz, 1) + worldpriors = self._initialize_worldpriors(bsz, maxlen).detach() + + s1_scores = [] + incr_state = None + + for t in range(maxlen): + worldprior_t = worldpriors.select(1, t).unsqueeze(1) + + latent, incr_state = self.decoder(inputs_t, encoder_states, incr_state) + _logits = self.output(latent) + # only get the last timestep's logit + s0_t = _logits.select(dim=1, index=-1).unsqueeze(1) # logits shape: (bpsz, 1, vocab) + + # s1_t: (bsz, 1, vocab) + # listener_posterior: (bsz, vocab, world_cardinality) + s1_t, l0_t, l1_t = self._pragmatic_reasoning(s0_t, worldprior_t) + s1_scores.append(s1_t) + + next_token = s1_t.max(2)[1].clone().detach() # next input is current predicted output idx + + idx_for_tile = torch.arange(bsz).repeat(self.world_cardinality, 1).transpose(0, 1).reshape(-1).cuda() + inputs_next_t = torch.index_select(next_token, 0, idx_for_tile) + next_token = next_token.unsqueeze(2) + tiled_next_token = next_token.repeat(1, 1, self.world_cardinality) + + if self.worldprior != 'uniform': + # (bsz, vocab, world_cardinality) -> (bsz, 1, world_cardinality) + updated_world_prior = torch.gather(l0_t, 1, tiled_next_token).clone().detach() + if t + 1 < maxlen: + if self.worldprior == 'L0': + worldpriors[:, t + 1, :] = updated_world_prior.squeeze() + elif self.worldprior == 'L1': + worldpriors[:, t + 1, :] = l1_t + else: + raise NotImplementedError + + # update inputs for next timestep + inputs_t = torch.cat((inputs_t, inputs_next_t), dim=1) + + s1_scores = torch.cat(s1_scores, dim=1) # (bsz, seqlen, vocab) + _, preds = s1_scores.max(dim=2) + + return preds, s1_scores + + def selfconscious_decode_forced(self, encoder_states, ys): + """ + faster teacher-forced decoding with pragmatic self-consciousness + """ + + bsz = ys.size(0) + seqlen = ys.size(1) + self.longest_label = max(self.longest_label, seqlen) + emb_size = self.encoder.embedding_size + enc_outputs = encoder_states[0].view(bsz * self.world_cardinality, -1, emb_size).contiguous() + enc_outputs_mask = encoder_states[1].view(bsz * self.world_cardinality, -1).contiguous() + enc_states = (enc_outputs, enc_outputs_mask) + bpsz = enc_outputs.size(0) + + # tile ys as much as the world_cardinality + idx_for_tile = torch.arange(bsz).repeat(self.world_cardinality, 1).transpose(0, 1).reshape(-1).cuda() + tiled_ys = torch.index_select(ys, 0, idx_for_tile) + + inputs = tiled_ys.narrow(1, 0, seqlen - 1) + inputs = torch.cat([self.START.detach().expand(bpsz, 1), inputs], 1) + worldpriors = self._initialize_worldpriors(bsz, seqlen).detach() + s1_scores = [] + + latent, _ = self.decoder(inputs, enc_states) + base_speaker = self.output(latent) + + for t in range(seqlen): + + s0_t = base_speaker.select(dim=1, index=t).unsqueeze(1) # s0_t: (bpsz, 1, vocab) + worldprior_t = worldpriors.select(dim=1, index=t).unsqueeze(1) + + # s1_t: (bsz, 1, vocab) + # l0_t: (bsz, vocab, world_cardinality) + s1_t, l0_t, l1_t = self._pragmatic_reasoning(s0_t, worldprior_t) + s1_scores.append(s1_t) + + # Update world_prior with listener posterior + if t + 1 < seqlen: + next_tokens = inputs.select(1, t + 1).view(-1, 1) # (bpsz, 1): the next tokens for each bpsz instance + next_tokens = next_tokens.unsqueeze(2) + # [0, 1*world_cardinality, 2*wc, 3*wc, ..., bpsz - 1wc] -> to get the 
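rows of the
+            # ground-truth persona in the flattened batch, where
+            # bpsz = bsz * world_cardinality. Hedged worked example with
+            # hypothetical sizes: for bsz=2 and world_cardinality=3 the rows
+            # are ordered [ex0/p0, ex0/p1, ex0/p2, ex1/p0, ex1/p1, ex1/p2],
+            # so torch.arange(2) * 3 = [0, 3] gives the row offsets of each
+            # example's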
ground-truth personas + target_persona_idxs = torch.arange(bsz).cuda() * (self.world_cardinality) + + # we only need the next token of the ground-truth persona + next_token = torch.index_select(next_tokens, 0, target_persona_idxs) # (bsz, 1, 1) + tiled_next_token = next_token.repeat(1, 1, self.world_cardinality) # (bsz, 1, world_cardinality) + + if self.worldprior != 'uniform': + # (bsz, vocab, world_cardinality) -> (bsz, 1, world_cardinality) + updated_world_prior = torch.gather(l0_t, 1, tiled_next_token).clone().detach() + if self.worldprior == 'L0': + worldpriors[:, t + 1, :] = updated_world_prior.squeeze() + elif self.worldprior == 'L1': + worldpriors[:, t + 1, :] = l1_t + else: + raise NotImplementedError + + s1_scores = torch.cat(s1_scores, 1) # (bsz, seqlen, vocab) + _, preds = s1_scores.max(dim=2) + + return s1_scores, preds diff --git a/agents/selfconscious_blender.py b/agents/selfconscious_blender.py new file mode 100644 index 0000000..8b938d6 --- /dev/null +++ b/agents/selfconscious_blender.py @@ -0,0 +1,543 @@ +import os +import random +import numpy as np +from itertools import chain + +import torch +import torch.nn.functional as F + +from parlai.core.opt import Opt +from parlai.core.message import Message +from parlai.core.torch_agent import Batch, Output +from parlai.core.torch_generator_agent import PPLMetric +from parlai.core.metrics import SumMetric, AverageMetric +from parlai.utils.torch import padded_tensor +from parlai.utils.misc import warn_once +from parlai.agents.transformer.transformer import ( + TransformerGeneratorAgent, + add_common_cmdline_args +) + +from agents.modules import SelfConsciousTransformerModel +from modules.dnli_bert import DnliBert +from agents.history import SelfConsciousHistory, ContextConsciousHistory + + +def list_to_matrix(l, n): + return [l[i:i+n] for i in range(0, len(l), n)] + + +class SelfConsciousBlenderAgent(TransformerGeneratorAgent): + """ + Implementation of the Self-Conscious Blender Agent. + """ + + @classmethod + def add_cmdline_args(cls, argparser): + """ + Add command-line arguments specifically for this agent. 
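A hedged editor summary of the flags defined below: --conscious-target
+        selects what the agent is conscious of ('self', 'context', or 'none');
+        --alpha and --beta are the rationality coefficients of the pragmatic
+        speaker S_1 and the imagined listener L_0; --world_cardinality sets
+        the number of personas in the world I (ground truth included); and
+        --worldprior chooses how the prior over worlds is updated ('uniform',
+        'L0', or 'L1').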
+ """ + agent = argparser.add_argument_group('Self-conscious Blender Arguments') + agent.add_argument( + '--conscious-target', + type=str, + choices=['none', 'self', 'context'], + default='self', + help='The target which the agent will be concerned about.', + ) + agent.add_argument( + '-a', + '--alpha', + type=float, + default=0, + help='Rationality parameter for S_1(speaker_1)', + ) + agent.add_argument( + '-b', + '--beta', + type=float, + default=1, + help='Rationality parameter for Listener', + ) + agent.add_argument( + '--world_cardinality', + type=int, + default=3, + help='Cardinality of world I:= Number of persona to use RSA model (including GT)', + ) + agent.add_argument( + '--worldprior', + type=str, + choices=['uniform', 'L0', 'L1'], + default='L0', + help='Update world prior with a `uniform` distribution or `L0` or `L1`.', + ) + agent.add_argument( + '--use_dnli', + type=bool, + default=True, + help='Whether to use dnli model to measure consistency-score in Convai2 or rerank candidates in DNLI' + ) + add_common_cmdline_args(agent) + cls.dictionary_class().add_cmdline_args(argparser) + + super(SelfConsciousBlenderAgent, cls).add_cmdline_args(argparser) + return agent + + def __init__(self, opt: Opt, shared=None): + + self.task = str.lower(opt['task'].split(':')[-1]) + + if opt['conscious_target'] != 'none': + assert opt['conscious_target'] in self.task, \ + "conscious_target (`" + opt['conscious_target'] + "`) must match task type (`" + self.task + "`)" + + SEED = 46 + random.seed(SEED) + np.random.seed(SEED) + os.environ['PYTHONHASHSEED'] = str(SEED) + torch.random.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + torch.manual_seed(SEED) + torch.cuda.manual_seed_all(SEED) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + # For public self-consciousness + self.target_persona = opt.get('target_persona', 0) + self.conscious_target = opt.get('conscious_target', 'self') + self.world_cardinality = opt.get('world_cardinality', 3) + self.alpha = 0.0 if self.conscious_target == 'none' else opt.get('alpha', 2.0) + self.beta = opt.get('beta', 1.0) + self.worldprior = opt.get('worldprior', 'L0') + + self.eval_type = opt.get('eval_type') + # self.rank_candidates = opt.get('rank_candidates', True) + self.multigpu = ( + opt.get('multigpu', False) and self.use_cuda and (opt.get('batchsize') > 1) + ) + + init_model, is_finetune = self._get_init_model(opt, shared) + super().__init__(opt, shared) + + # Implementation is based on beam_size 1 + self.beam_size = 1 + warn_once(f'This implementation is assumed to have beam-size 1.') + + # Always rank candidates for the ranking metrics + self.rank_candidates = True + warn_once(f'rank-candidates is always True for ranking metrics.') + + if opt['use_dnli']: + if not shared: + self.dnli_model = DnliBert(opt, use_cuda=self.use_cuda) + else: + self.dnli_model = shared['dnli_model'] + else: + self.dnli_model = None + + self.id = 'SelfConsciousBlender' + + self.reset() + + def build_model(self, states=None): + """ + Build and return model. 
+ """ + model = SelfConsciousTransformerModel(self.opt, self.dict) + if self.opt['embedding_type'] != 'random': + self._copy_embeddings( + model.encoder.embeddings.weight, self.opt['embedding_type'] + ) + return model + + def history_class(self): + return ContextConsciousHistory if 'context' in self.task else SelfConsciousHistory + + def _model_input(self, batch): + """ + Override from TorchGeneratorAgent + passes (batch.text_vec,) to TorchGeneratorAgent._encoder_input() + TGA._encoder_input() directly passes the result of TGA._model_input() + change batch.text_vec to batch.distractor_text_vec for pragmatic decoding + """ + bsz = batch.text_vec.size(0) + distractor_text_vec = batch.distractor_text_vec.view(bsz * self.world_cardinality, -1).contiguous() + return (distractor_text_vec,) + + def selfconscious_greedy_generate(self, batch, maxlen): + """ + Greedy decoding with Public Self-Consciousness + """ + + bsz = batch.text_vec.size(0) + world_cardinality = self.world_cardinality + embedding_size = self.opt.get('embedding_size') + encoder_states = self.model.encoder(*self._encoder_input(batch)) + + preds, scores = self.model.selfconscious_decode(encoder_states, maxlen) + + return preds, scores + + def rank(self, batch): + """ + Rank candidates by PPL score + """ + bsz = batch.text_vec.size(0) + world_cardinality = self.world_cardinality + embedding_size = self.opt.get('embedding_size') + ranked_candidates = [] + cand_ordering = [] + encoder_states = self.model.encoder(*self._encoder_input(batch)) + batch_dim = encoder_states[0].size(0) # two possibilities: batchsize or batchsize * world_cardinality + + if bsz != batch_dim: + enc_output = encoder_states[0].view(bsz, world_cardinality, -1, embedding_size).contiguous() + enc_output_mask = encoder_states[1].view(bsz, world_cardinality, -1).contiguous() + encoder_states = (enc_output, enc_output_mask) + + for i in range(bsz): + num_cands = len(batch.candidate_vecs[i]) + cands, _ = self._pad_tensor(batch.candidate_vecs[i]) + # get [i]th state from encoder_states #num_cands time. + # because we need same encoder_states for each candidate + enc = self.model.reorder_encoder_states(encoder_states, [i] * num_cands) + + # enc: (num_cands, world_cardinality, seqlen, emb_size) + # scores: (num_cands, max_len, vocab_size) + scores, _ = self.model.selfconscious_decode_forced(enc, cands) + + cand_losses = F.cross_entropy( + scores.view(num_cands * cands.size(1), -1), + cands.view(-1), + reduction='none', + ).view(num_cands, cands.size(1)) + # now cand_losses is cands x seqlen size, but we still need to + # check padding and such + mask = (cands != self.NULL_IDX) + mask = mask.half() if self.fp16 else mask.float() + cand_scores = (-cand_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9) + + if self.dnli_model is not None and self.eval_type == 'dnli': + cand_scores = torch.unsqueeze(cand_scores, 0) + cand_scores = self.dnli_model.rerank_candidates([batch.observations[i]], cand_scores) + cand_scores = torch.squeeze(cand_scores) + + _, ordering = cand_scores.sort(descending=True) + ranked_candidates.append([batch.candidates[i][o] for o in ordering]) + cand_ordering.append(ordering) + + return ranked_candidates, cand_ordering + + def compute_loss(self, batch, return_output=False): + """ + Override from TorchGeneratorAgent + Compute and return the loss for the given batch. + + Easily overridable for customized loss functions. + + If return_output is True, the full output from the call to self.model() + is also returned, via a (loss, model_output) pair. 
+ """ + if batch.label_vec is None: + raise ValueError('Cannot compute loss without a label.') + + bsz = batch.text_vec.size(0) + world_cardinality = self.world_cardinality + embedding_size = self.opt.get('embedding_size') + encoder_states = self.model.encoder(*self._encoder_input(batch)) + + enc_output = encoder_states[0].view(bsz, world_cardinality, -1, embedding_size).contiguous() + enc_output_mask = encoder_states[1].view(bsz, world_cardinality, -1).contiguous() + encoder_states = (enc_output, enc_output_mask) + + scores, preds = self.model.selfconscious_decode_forced(encoder_states, batch.label_vec) + model_output = (scores, preds, encoder_states) + + score_view = scores.view(-1, scores.size(-1)) + loss = self.criterion(score_view, batch.label_vec.view(-1)) + loss = loss.view(scores.shape[:-1]).sum(dim=1) + # save loss to metrics + notnull = batch.label_vec.ne(self.NULL_IDX) + target_tokens = notnull.long().sum(dim=-1) + correct = ((batch.label_vec == preds) * notnull).sum(dim=-1) + + self.record_local_metric('loss', AverageMetric.many(loss, target_tokens)) + self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens)) + self.record_local_metric( + 'token_acc', AverageMetric.many(correct, target_tokens) + ) + + # actually do backwards loss + loss = loss.sum() + loss /= target_tokens.sum() # average loss per token + + if return_output: + return (loss, model_output) + else: + return loss + + def _eval_convai2_step(self, batch): + """Evaluate a single batch of examples.""" + + assert self.alpha >= 0 + if batch.distractor_text_vec is None: + return None + + self.model.eval() + + # 1. Generation + assert self.beam_size is 1 + maxlen = self.label_truncate or 256 + if not self.skip_generation: + preds, scores = self.selfconscious_greedy_generate(batch, maxlen) + else: + preds = None + + # 2. Compute PPL with teacher-forced generation + # calculate loss on targets with teacher forcing + loss, model_output = self.compute_loss(batch, return_output=True) + token_losses = self._construct_token_losses( + batch.label_vec, model_output + ) + + # 3. Rank candidates by computing PPL for each candidates + if self.rank_candidates: + ranked_cands, ordering = self.rank(batch) + else: + ranked_cands = None + + # 4. 
Compute consistency score + additional_metrics = [{'c_score': 0.0} for _ in range(len(batch.observations))] + output_texts = [self._v2t(p) for p in preds] if preds is not None else None + if not self.skip_generation: + if self.opt['use_dnli']: + c_scores = [] + for text, obs in zip(output_texts, batch.observations): + if 'context' in self.task: + c_score = self.dnli_model.compute_consistency_scores(text, obs['my_context']) + else: + persona_strings = obs['my_persona'].split('\n') + c_score = self.dnli_model.compute_consistency_scores(text, persona_strings) + + c_scores.append(c_score) + + for idx, c_score in enumerate(c_scores): + additional_metrics[idx]['c_score'] = c_score + + return Output(output_texts, ranked_cands, token_losses=token_losses, metrics=additional_metrics) + + def _eval_dnli_step(self, batch): + """Evaluate a single batch of examples.""" + + assert self.alpha >= 0 + + self.model.eval() + ranked_cands, ordering = self.rank(batch) + + bsz = len(ranked_cands) + dnli_metrics = [] + for batch_idx in range(bsz): + dnli_score = {'contradict@1': 0, 'entail@1': 0, 'neutral@1': 0} + top1_idx = ordering[batch_idx][0].item() + if top1_idx == 0: + pass + # dnli_metrics['dnli_hit@1'] += 1 + elif top1_idx > 0 and top1_idx < 11: + dnli_score['contradict@1'] += 1 + elif top1_idx >= 11 and top1_idx < 21: + dnli_score['entail@1'] += 1 + else: + dnli_score['neutral@1'] += 1 + dnli_metrics.append(dnli_score) + + return Output(text_candidates=ranked_cands, metrics=dnli_metrics) + + def eval_step(self, batch): + + if self.opt['eval_type'] == 'convai2': + return self._eval_convai2_step(batch) + elif self.opt['eval_type'] == 'dnli': + return self._eval_dnli_step(batch) + else: + raise NotImplementedError + + def self_observe(self, self_message: Message): + """ + Override from TorchAgent + Update the model's reply or label to the history of distractor-fields in History class + """ + episode_done = self.observation['episode_done'] + use_reply = self.opt.get('use_reply', 'label') + + # actually ingest the label + if use_reply == 'none': + # we're not including our own responses anyway. 
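+            # (hedged editor note) the reply resolved in this if/elif block is
+            # later pushed into every distractor history via
+            # add_reply_to_distractors at the bottom of this method (and
+            # returned), keeping the candidate worlds in sync with the
+            # main history.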
+ reply = None + elif use_reply == 'label': + # first look for the true label + label_key = ( + 'labels' + if 'labels' in self.observation + else 'eval_labels' + if 'eval_labels' in self.observation + else None + ) + if label_key is not None: + lbls = self.observation[label_key] + reply = lbls[0] if len(lbls) == 1 else self.random.choice(lbls) + else: + # otherwise, we use the last output the model generated + if self_message is not None: + reply = self_message['text'] + else: + reply = None + + super().self_observe(self_message) + + if episode_done: + return None + + if reply is not None: + if 'context' in self.task: + self.history.add_reply_to_distractors(reply, self.observation) + else: + self.history.add_reply_to_distractors(reply) + + return reply + + def _ordered_cand_scores_to_cand_text(self, ordered_cand_preds, cand_inds, candidates): + cand_replies = [None] * len(candidates) + + for idx, order in enumerate(ordered_cand_preds): # batch_idx, sorted cand_idx + batch_idx = cand_inds[idx] + # get the original sentences from candidates by order + cand_replies[batch_idx] = [candidates[batch_idx][i] for i in order] + + return cand_replies + + def _build_candidates_tensor(self, batch): + if not batch.candidates: + return None, None + + cand_inds = [i for i in range(len(batch.candidates)) if batch.candidates[i]] + cands = [batch.candidate_vecs[i] for i in cand_inds] + + # get the length of the longest candidate in the batch + max_cand_len = max( + [max([cand.size(0) for cand in cands_i]) for cands_i in cands] + ) + + for i, c in enumerate(cands): # make each instance in batch.cands to a padded tensor + cands[i] = padded_tensor(c, use_cuda=self.use_cuda, + max_len=max_cand_len, + fp16friendly=self.fp16)[0].unsqueeze(0) + + # (batchsize, num_cands, max_len + a) +a due to fp16 + cands = torch.cat(cands, 0) + + return cands, cand_inds + + def vectorize(self, obs, history, **kwargs): + """ + Override from TorchAgent + Vectorize the texts in observation + """ + super().vectorize(obs, history, **kwargs) # candidate vecs are vectorized here + if not self.is_training: + self._set_distractor_text_vec(obs, history, kwargs['text_truncate']) + return obs + + def _set_text_vec(self, obs, history, truncate): + """ + Override from TorchAgent for DNLI evaluation + This will be called in super().vectorize() + """ + # WARNING: self.is_training is always False in here + is_training = False if 'eval_labels' in obs else True + + if is_training or self.opt['eval_type'] == 'convai2': + return super()._set_text_vec(obs, history, truncate) + elif self.opt['eval_type'] == 'dnli': + if 'text' not in obs: + return obs + + # Vectorize the text + if 'text_vec' not in obs: + obs['full_text'] = obs['text'] + vec = self.dict.txt2vec(obs['full_text']) + obs['text_vec'] = vec + + # check truncation + if obs.get('text_vec') is not None: + truncated_vec = self._check_truncate(obs['text_vec'], truncate, True) + obs.force_set('text_vec', torch.LongTensor(truncated_vec)) + return obs + else: + raise NotImplementedError + + def _set_distractor_text_vec(self, obs, history, truncate): + """ + Set 'distractor_text' and 'distractor_text_vec' field in the observation + """ + if 'distractor_text' not in obs: + return obs + + if 'distractor_text_vec' not in obs: + # distractor_text is in the SelfConsciousHistory class + distractor_string = history.get_history_distractor_str() + + if distractor_string is None: + return obs + + # Set 'full_distractor_text' + obs['full_distractor_text'] = distractor_string + # distractor_text_vec is also in 
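the history object:
+            # (hedged editor note) get_history_distractor_vec() returns a list
+            # of world_cardinality token-id lists, one history per candidate
+            # persona, which batchify() later pads into a
+            # (bsz, world_cardinality, seqlen) tensor. Like the strings, the
+            # vectors are maintained by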
the SelfConsciousHistory class + # they are already vectorized at SelfConsciousHistory.update_history() + if distractor_string: + obs['distractor_text_vec'] = history.get_history_distractor_vec() + + # Check truncation + if obs.get('distractor_text_vec') is not None: + truncated_vec = [ + torch.LongTensor(self._check_truncate(text_vec, truncate, True)) + for text_vec in obs['distractor_text_vec'] + ] + obs.force_set('distractor_text_vec', truncated_vec) + return obs + + def batchify(self, *args, **kwargs): + """ + Override from TorchAgent + Additionally batchify the distractor_text_vec and add it to batch + """ + kwargs['sort'] = True # need sort for pack_padded() + batch = super().batchify(*args, **kwargs) + sort = False # we must not sort after super().batchify() + + exs = batch.observations + d_text_vec, d_lens = None, None + if any('distractor_text_vec' in ex for ex in exs): + # Pad distractor vectors + _d_text_vec = [ex.get('distractor_text_vec', self.EMPTY) for ex in exs] + _d_text_vec_flattened = list(chain(*_d_text_vec)) + d_text_vec, d_lens = self._pad_tensor(_d_text_vec_flattened) + + # Reshape to (batch_size, world_cardinality, max_length) + bsz = len(exs) + d_text_vec = d_text_vec.view(bsz, self.world_cardinality, -1) + d_lens = list_to_matrix(d_lens, self.world_cardinality) + + batch = Batch( + distractor_text_vec=d_text_vec, + distractor_text_lengths=d_lens, + **dict(batch) + ) + + return batch + + def share(self): + shared = super().share() + if self.opt['use_dnli']: + shared['dnli_model'] = self.dnli_model + return shared diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..a49fe08 --- /dev/null +++ b/environment.yml @@ -0,0 +1,30 @@ +name: pragmatic-consistency +channels: + - defaults +dependencies: + - cudatoolkit=10.1 + - pytorch::pytorch=1.6.0 + - python=3.6.8 + - pip=20.1.1 + - pip: + # ParlAI (use commit of 1st of October, 2020) + - git+https://github.com/facebookresearch/ParlAI.git@9cd6b6c0e70c72a24e959e4a328cb4093eb7f3de + - torchtext==0.7.0 + - spacy==2.3.2 + - pytorch-transformers==1.2.0 + # For logging + - tqdm + - better-exceptions + # For linting + - pylint + - pycodestyle + - mypy + # For markdown preview + - grip + # etc + - ruamel.yaml + - more_itertools + - isort + - pudb + - jupyter + - orderedset diff --git a/eval_dnli.py b/eval_dnli.py new file mode 100644 index 0000000..292f96a --- /dev/null +++ b/eval_dnli.py @@ -0,0 +1,21 @@ +from parlai.scripts.eval_model import eval_model +from parlai.scripts.eval_model import setup_args as parlai_setupargs + + +def setup_args(): + parser = parlai_setupargs() + parser.set_defaults( + model_file='zoo:blender/blender_90M/model', + eval_type='dnli', + metrics='contradict@1,entail@1,neutral@1', + alpha=8, + beta=1, + use_dnli=False + ) + return parser + + +if __name__ == '__main__': + parser = setup_args() + opt = parser.parse_args() + eval_model(opt) diff --git a/eval_personachat.py b/eval_personachat.py new file mode 100644 index 0000000..698a46c --- /dev/null +++ b/eval_personachat.py @@ -0,0 +1,20 @@ +from parlai.scripts.eval_model import eval_model +from parlai.scripts.eval_model import setup_args as parlai_setupargs + + +def setup_args(): + parser = parlai_setupargs() + parser.set_defaults( + model_file='zoo:blender/blender_90M/model', + eval_type='convai2', + metrics='token_acc,ppl,loss,c_scores,f1', + alpha=2, + beta=0.5 + ) + return parser + + +if __name__ == '__main__': + parser = setup_args() + opt = parser.parse_args() + eval_model(opt) diff --git a/modules/__init__.py 
b/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/dnli_bert.py b/modules/dnli_bert.py new file mode 100644 index 0000000..a98c79e --- /dev/null +++ b/modules/dnli_bert.py @@ -0,0 +1,214 @@ +import os +from copy import deepcopy + +import torch +import torch.nn.functional as F +import numpy as np +from pytorch_transformers import ( + BertForSequenceClassification, + BertTokenizer +) +from parlai.core.build_data import download_from_google_drive + + +def _truncate_seq_pair(tokens_a, tokens_b, max_length): + """Truncates a sequence pair in place to the maximum length.""" + # This is a simple heuristic which will always truncate the longer sequence + # one token at a time. This makes more sense than truncating an equal percent + # of tokens from each, since if one sequence is very short then each token + # that's truncated likely contains more information than a longer sequence. + while True: + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_length: + break + if len(tokens_a) > len(tokens_b): + tokens_a.pop() + else: + tokens_b.pop() + + +class DnliBert(object): + def __init__(self, + opt, + dnli_lambda=1.0, + dnli_k=10, + max_seq_length=128, + use_cuda=True): + self.opt = opt + self.dnli_lambda = dnli_lambda + self.dnli_k = dnli_k + self.max_seq_length = max_seq_length + self.use_cuda = use_cuda + self.mapper = {0: "contradiction", + 1: "entailment", + 2: "neutral"} + + dnli_model, dnli_tokenizer = self._load_dnli_model() + self.dnli_model = dnli_model + self.dnli_tokenizer = dnli_tokenizer + + def _load_dnli_model(self): + # Download pretrained weight + dnli_model_fname = os.path.join(self.opt['datapath'], 'dnli_model.bin') + if not os.path.exists(dnli_model_fname): + print(f"[ Download pretrained dnli model params to {dnli_model_fname}]") + download_from_google_drive( + '1Qawz1pMcV0aGLVYzOgpHPgG5vLSKPOJ1', + dnli_model_fname + ) + + # Load pretrained weight + print(f"[ Load pretrained dnli model from {dnli_model_fname}]") + model_state_dict = torch.load(dnli_model_fname) + dnli_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', state_dict=model_state_dict, num_labels=3) + if self.use_cuda: + dnli_model.cuda() + dnli_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) + + return dnli_model, dnli_tokenizer + + def rerank_candidates(self, observations, cand_scores): + sorted_cand_values, sorted_cand_indices = cand_scores.sort(1, descending=True) + + for batch_idx, obs in enumerate(observations): + full_text = obs['full_text'] + personas = [] + for text in full_text.split('\n'): + if 'your persona:' in text: + personas.append(text.replace('your persona:', '')) + else: + break + candidates = obs['label_candidates'] + + tok_candidates = [self.dnli_tokenizer.tokenize(sent) for sent in candidates] + tok_personas = [self.dnli_tokenizer.tokenize(sent) for sent in personas] + + dnli_scores = self._compute_dnli_scores(tok_candidates, tok_personas) + s_1 = sorted_cand_values[batch_idx, 0] + s_k = sorted_cand_values[batch_idx, self.dnli_k - 1] + + _lambda = self.dnli_lambda + cand_scores[batch_idx] = cand_scores[batch_idx] - _lambda * (s_1 - s_k) * dnli_scores + + return cand_scores + + def compute_consistency_scores(self, pred, personas): + """ + preds, and personas must be list of string + """ + max_seq_length = self.max_seq_length + + pred_tokenized = self.dnli_tokenizer.tokenize(pred) + personas_tokenized = [self.dnli_tokenizer.tokenize(sent.replace('your persona:', '')) for sent in 
personas] + + all_input_ids = [] + all_input_mask = [] + all_segment_ids = [] + for idx, persona_tokenized in enumerate(personas_tokenized): + _pred_tokenized = deepcopy(pred_tokenized) + _persona_tokenized = deepcopy(persona_tokenized) + _truncate_seq_pair(_pred_tokenized, _persona_tokenized, max_seq_length - 3) + + tokens = ["[CLS]"] + _pred_tokenized + ["[SEP]"] + segment_ids = [0] * len(tokens) + tokens += _persona_tokenized + ["[SEP]"] + segment_ids += [1] * (len(_persona_tokenized) + 1) + + input_ids = self.dnli_tokenizer.convert_tokens_to_ids(tokens) + input_mask = [1] * len(input_ids) + padding = [0] * (max_seq_length - len(input_ids)) + input_ids += padding + input_mask += padding + segment_ids += padding + + all_input_ids.append(input_ids) + all_input_mask.append(input_mask) + all_segment_ids.append(segment_ids) + + # Convert inputs to tensors + all_input_ids = torch.tensor(all_input_ids, dtype=torch.long) + all_input_mask = torch.tensor(all_input_mask, dtype=torch.long) + all_segment_ids = torch.tensor(all_segment_ids, dtype=torch.long) + if self.use_cuda: + all_input_ids = all_input_ids.cuda() + all_input_mask = all_input_mask.cuda() + all_segment_ids = all_segment_ids.cuda() + + # Inference + self.dnli_model.eval() + with torch.no_grad(): + logits = self.dnli_model(all_input_ids, all_segment_ids, all_input_mask) + probs = F.softmax(logits[0], dim=1) + + probs = probs.detach().cpu().numpy() + idx_max = np.argmax(probs, axis=1) + val_max = np.max(probs, axis=1) + + consistency_score = 0.0 + for pred_idx in idx_max: + if pred_idx == 0: # contradict + consistency_score -= 1.0 + elif pred_idx == 1: # entailment + consistency_score += 1.0 + elif pred_idx == 2: # neutral + consistency_score += 0.0 + + return consistency_score + + def _compute_dnli_scores(self, tok_candidates, tok_personas): + max_seq_length = self.max_seq_length + + dnli_scores = [] + for cand_idx, tok_candidate in enumerate(tok_candidates): + all_input_ids = [] + all_input_mask = [] + all_segment_ids = [] + for tok_persona in tok_personas: + # Prepare inputs + # [CLS] candidates [SEP] persona [SEP] + _tok_candidate = deepcopy(tok_candidate) + _tok_persona = deepcopy(tok_persona) + # Account for [CLS], [SEP], [SEP] with "- 3" + _truncate_seq_pair(_tok_candidate, _tok_persona, max_seq_length - 3) + + # Make inputs + tokens = ["[CLS]"] + _tok_candidate + ["[SEP]"] + segment_ids = [0] * len(tokens) + tokens += _tok_persona + ["[SEP]"] + segment_ids += [1] * (len(_tok_persona) + 1) + + input_ids = self.dnli_tokenizer.convert_tokens_to_ids(tokens) + input_mask = [1] * len(input_ids) + padding = [0] * (max_seq_length - len(input_ids)) + input_ids += padding + input_mask += padding + segment_ids += padding + + all_input_ids.append(input_ids) + all_input_mask.append(input_mask) + all_segment_ids.append(segment_ids) + + # Convert inputs to tensors + all_input_ids = torch.tensor(all_input_ids, dtype=torch.long) + all_input_mask = torch.tensor(all_input_mask, dtype=torch.long) + all_segment_ids = torch.tensor(all_segment_ids, dtype=torch.long) + if self.use_cuda: + all_input_ids = all_input_ids.cuda() + all_input_mask = all_input_mask.cuda() + all_segment_ids = all_segment_ids.cuda() + + # Inference + self.dnli_model.eval() + with torch.no_grad(): + logits = self.dnli_model(all_input_ids, all_segment_ids, all_input_mask) + probs = F.softmax(logits[0], dim=1) + + probs = probs.detach().cpu().numpy() + idx_max = np.argmax(probs, axis=1) + val_max = np.max(probs, axis=1) + dnli_score = np.max((idx_max == 0) * val_max) + 
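# (hedged editor note) (idx_max == 0) masks persona sentences whose argmax
+            # label is class 0 ("contradiction"), so dnli_score is the single
+            # highest contradiction probability across the persona;
+            # rerank_candidates above then subtracts
+            # _lambda * (s_1 - s_k) * dnli_scores from the candidate scores.
+            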
dnli_scores.append(dnli_score) + dnli_scores = torch.tensor(dnli_scores, dtype=torch.float) + if self.use_cuda: + dnli_scores = dnli_scores.cuda() + return dnli_scores diff --git a/tasks/__init__.py b/tasks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tasks/build.py b/tasks/build.py new file mode 100755 index 0000000..f40e0c9 --- /dev/null +++ b/tasks/build.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import parlai.core.params as params +import parlai.core.build_data as build_data + + +FOLDER_NAME = 'self_conscious_dialogue' + + +def build(opt): + dpath = os.path.join(opt['datapath'], FOLDER_NAME) + # version 1.0: initial release + version = '1.0' + + # check whether data had been previously built + if not build_data.built(dpath, version_string=version): + print('[building data: ' + dpath + ']') + + # make a clean directory if needed + if build_data.built(dpath): + # if an older version exists, remove those outdated files. + build_data.remove_dir(dpath) + build_data.make_dir(dpath) + + ######################### + # ConvAI2 (PersonaChat) + ######################### + fname = 'data_v1.tar.gz' + url = 'https://parl.ai/downloads/controllable_dialogue/' + fname + build_data.download(url, dpath, fname) + build_data.untar(dpath, fname) + + fname = 'convai2_fix_723.tgz' + url = 'http://parl.ai/downloads/convai2/' + fname + build_data.download(url, dpath, fname) + build_data.untar(dpath, fname) + + ######################### + # Dialogue NLI + ######################### + fname = 'dialogue_nli.zip' + gd_id = '1WtbXCv3vPB5ql6w0FVDmAEMmWadbrCuG' + build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname)) + build_data.untar(dpath, fname) + + fname = 'dialogue_nli_evaluation.zip' + gd_id = '1sllq30KMJzEVQ4C0-a9ShSLSPIZc3iMi' + build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname)) + build_data.untar(dpath, fname) + + ######################### + # Distractor personas + ######################### + fname = 'train_sorted_50_personas.json' + gd_id = '1SGFdJqyNYeepKFqwMLv4Ym717QQTtpi8' + build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname)) + fname = 'valid_sorted_50_personas.json' + gd_id = '1A7oVKmjJ1EZTh6-3Gio4XQo81QgnTGGi' + build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname)) + fname = 'dnli_sorted_50_personas.json' + gd_id = '1wlIkVcBZoGQd3rbI7XWNhuq4rvw9FyoP' + build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname)) + + print("Data has been placed in " + dpath) + + build_data.mark_done(dpath, version) + + +def make_path(opt, fname): + return os.path.join(opt['datapath'], FOLDER_NAME, fname) + + +if __name__ == '__main__': + opt = params.ParlaiParser().parse_args(print_args=False) + build(opt) diff --git a/tasks/teachers.py b/tasks/teachers.py new file mode 100644 index 0000000..350046d --- /dev/null +++ b/tasks/teachers.py @@ -0,0 +1,497 @@ +import copy +import os +import math +import json +import random +from operator import itemgetter +from orderedset import OrderedSet +from collections import defaultdict +import pickle + +from parlai.utils.misc import warn_once, str_to_msg +from parlai.core.message import Message +from parlai.core.torch_agent import TorchAgent +from parlai.core.teachers import ( + ParlAIDialogTeacher, + FixedDialogTeacher, + FbDeprecatedDialogTeacher, + DialogData +) + +from .build 
import build, make_path + +__PATH__ = os.path.abspath(os.path.dirname(__file__)) + + +def _path(opt): + build(opt) + datatype = opt['datatype'].split(':')[0] + if datatype == 'test': + warn_once("WARNING: Test set not included. Setting datatype to valid.") + datatype = 'valid' + return make_path(opt, datatype + '.txt'), datatype + + +def _split_persona_and_context(text, eval_type='convai2'): + if 'your persona:' not in text: + return None, text + else: + if eval_type == 'convai2': + texts = text.split('\n') + return '\n'.join(texts[:-1]), texts[-1] + elif eval_type =='dnli': + texts = text.split('\n') + last_idx = 0 + for idx, text in enumerate(texts): + if 'your persona:' in text: + last_idx = idx + persona_texts = texts[:last_idx+1] + context_texts = texts[last_idx+1:] + return '\n'.join(persona_texts), '\n'.join(context_texts) + + +def _split_personas_and_context(text): + if 'your persona:' not in text: + return text, text, text + else: + your_personas = [] + partner_personas = [] + context = [] + texts = text.split('\n') + for text in texts: + if text.startswith('your persona:'): + your_personas.append(text) + elif text.startswith("partner's persona:"): + partner_personas.append(text) + else: + context.append(text) + + return '\n'.join(your_personas), '\n'.join(partner_personas), context + + +class SelfConsciousDialogueTeacher(FixedDialogTeacher): + """ + Teacher (i.e. input data supplier) for the Self-conscious Agent. + SelfConsciousDialogueTeacher (SCDT) supplies data input + along with the distractors to the Self-conscious Agent. + """ + def __init__(self, opt, shared=None): + super().__init__(opt, shared) + self.opt = opt + + datapath, datatype = _path(opt) + + if not shared: + self.episodes = [] + self.num_exs = 0 + self._setup_data(datapath, datatype) + else: + self.episodes = shared['episodes'] + self.num_exs = sum(len(e) for e in self.episodes) + self.id = 'self_conscious_dialogue' + self.reset() + + @staticmethod + def add_cmdline_args(argparser): + agent = argparser.add_argument_group('Self Conscious Dialogue Teacher arguments') + agent.add_argument( + '--eval-type', + type=str, + choices=['convai2', 'dnli'], + default='dnli', + help='Which validation data to use', + ) + + def _setup_data(self, path, datatype): + + random.seed(46) + + # Data loading with script of ParlAIDialogTeacher + print(f"[Loading ParlAI text data: {path}]") + + # Read data from ConvAI2 + convai2_datapath = make_path(self.opt, f'{datatype}_both_original.txt') + convai2_episodes = self._load_convai2_data(convai2_datapath) + + # Get persona pool + all_personas, persona_to_idx = self._get_persona_pool(self.opt) + sorted_personas = self._get_sorted_persona_pool(datatype) + + + if self.opt['eval_type'] == 'convai2': + self.episodes = [] + self.num_exs = 0 + eps = [] + with open(path) as read: + for line in read: + msg = str_to_msg(line.rstrip('\n')) + if msg: + self.num_exs += 1 + eps.append(msg) + if msg.get('episode_done', False): + self.episodes.append(eps) + eps = [] + if len(eps) > 0: + # add last episode + eps[-1].force_set('episode_done', True) + self.episodes.append(eps) + # Add label candidates and partner's persona + for episode_idx, episode in enumerate(self.episodes): + for turn_idx, turn in enumerate(episode): + convai2_turn = convai2_episodes[episode_idx][turn_idx] + convai2_text = convai2_turn[0] + label_candidates = convai2_turn[3] + + turn['label_candidates'] = label_candidates + if turn_idx == 0: + my_persona, partner_persona, _ = _split_personas_and_context(convai2_text) + 
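# (hedged editor note) both personas are stored on the
+                        # first turn and copied to the later turns below, so
+                        # DnliBert can read obs['my_persona'] when computing
+                        # consistency scores.
+                        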
turn['partner_persona'] = partner_persona + turn['my_persona'] = my_persona + else: + turn['partner_persona'] = episode[0]['partner_persona'] + turn['my_persona'] = episode[0]['my_persona'] + elif self.opt['eval_type'] == 'dnli': + self.episodes = [] + self.num_exs = 0 + for eval_set in ['attributes', 'havenot', 'likedislike']: + datapath = make_path(self.opt, f'{datatype}_{eval_set}.jsonl') + with open(datapath, 'r') as fp: + for line in fp: + msg = json.loads(line) + msg['eval_set'] = eval_set + msg['episode_done'] = True + + # Make 'text' + persona_lines = [f'your persona: {x[:-2]}.' for x in msg['persona']] + utts = msg['prefix'] + + p1_token, p2_token = TorchAgent.P1_TOKEN, TorchAgent.P2_TOKEN + lines = persona_lines + # Identify the dialogue lines. It's assumed that p1 goes first. + for i, utt in enumerate(utts): + if i % 2 == 0: + lines.append(f'{p1_token} {utt}') + else: + lines.append(f'{p2_token} {utt}') + text = '\n'.join(lines) + + msg['text'] = text + + # Make 'label_candidates' + cands = msg['candidates'] + msg['label_candidates'] = cands['label'] + cands['neg'][:10] \ + + cands['similar'][:10] + cands['rand'][:10] + + # Remove unused attributes + del msg['persona'] + del msg['prefix'] + del msg['triple'] + del msg['relevant_persona_sentence'] + del msg['candidates'] + + self.episodes.append([msg]) + self.num_exs += 1 + + # Add distractor personas + if self.opt['world_cardinality'] > 0: + num_all_personas = len(all_personas) + persona_indices = list(range(num_all_personas)) + world_cardinality = self.opt['world_cardinality'] + for episode in self.episodes: + gt_persona, first_context = _split_persona_and_context(episode[0]['text'], self.opt['eval_type']) + gt_persona_idx = persona_to_idx.get(gt_persona, -1) + + # Choose random distractor personas + distractor_indices = random.sample(persona_indices, world_cardinality - 1) + while gt_persona_idx in distractor_indices: + # Resample if gt_persona is sampled + distractor_indices = random.sample(persona_indices, world_cardinality - 1) + distractor_personas = itemgetter(*distractor_indices)(all_personas) + distractor_personas = list(distractor_personas) + + # Make it to 'distractor_text' + for turn_idx, turn in enumerate(episode): + if turn_idx == 0: + turn['distractor_text'] = [ + '\n'.join([persona, first_context]) + for persona in [gt_persona] + distractor_personas + ] + else: + turn['distractor_text'] = [turn['text']] * world_cardinality + + def _get_persona_pool(self, opt, remove_duplicate=True): + print("[loading persona pool from convai2 training data]") + # Get episodes from training dataset + datapath = make_path(opt, 'train.txt') + episodes = [] + eps = [] + with open(datapath) as read: + for line in read: + msg = str_to_msg(line.rstrip('\n')) + if msg: + # self.num_exs += 1 + eps.append(msg) + if msg.get('episode_done', False): + episodes.append(eps) + eps = [] + if len(eps) > 0: + # add last episode + eps[-1].force_set('episode_done', True) + episodes.append(eps) + + # Extract personas from episodes + persona_set = OrderedSet() + for episode in episodes: + first_turn = episode[0] + text = first_turn['text'] + persona, _ = _split_persona_and_context(text) + persona_set.add(persona) + + # Remove duplicate + if remove_duplicate: + train_persona_fname = os.path.join(__PATH__, 'train_persona_map.pkl') + with open(train_persona_fname, 'rb') as fp: + _train_personas = pickle.load(fp) + train_personas = [] + for personas in _train_personas.values(): + longest_idx = 0 + longest_length = -1 + for idx, persona in 
+
+    def _get_persona_pool(self, opt, remove_duplicate=True):
+        print("[loading persona pool from convai2 training data]")
+        # Get episodes from training dataset
+        datapath = make_path(opt, 'train.txt')
+        episodes = []
+        eps = []
+        with open(datapath) as read:
+            for line in read:
+                msg = str_to_msg(line.rstrip('\n'))
+                if msg:
+                    # self.num_exs += 1
+                    eps.append(msg)
+                    if msg.get('episode_done', False):
+                        episodes.append(eps)
+                        eps = []
+        if len(eps) > 0:
+            # add last episode
+            eps[-1].force_set('episode_done', True)
+            episodes.append(eps)
+
+        # Extract personas from episodes
+        persona_set = OrderedSet()
+        for episode in episodes:
+            first_turn = episode[0]
+            text = first_turn['text']
+            persona, _ = _split_persona_and_context(text)
+            persona_set.add(persona)
+
+        # Remove duplicates: keep only the longest variant of each persona
+        if remove_duplicate:
+            train_persona_fname = os.path.join(__PATH__, 'train_persona_map.pkl')
+            with open(train_persona_fname, 'rb') as fp:
+                _train_personas = pickle.load(fp)
+            train_personas = []
+            for personas in _train_personas.values():
+                longest_idx = 0
+                longest_length = -1
+                for idx, persona in enumerate(personas):
+                    if len(persona) > longest_length:
+                        longest_idx = idx
+                        longest_length = len(persona)
+                selected_persona = map(lambda x: f"your persona: {x}.", personas[longest_idx])
+                selected_persona = '\n'.join(selected_persona)
+                train_personas.append(selected_persona)
+            persona_set = OrderedSet()
+            for train_persona in train_personas:
+                persona_set.add(train_persona)
+
+        all_personas = []
+        persona_to_idx = {}
+        for i, persona in enumerate(persona_set):
+            all_personas.append(persona)
+            persona_to_idx[persona] = i
+
+        print(f"Total {len(all_personas)} personas in dataset")
+
+        return all_personas, persona_to_idx
+
+    def _get_sorted_persona_pool(self, datatype):
+        print("[loading sorted persona pool from convai2 training data]")
+        eval_type = self.opt['eval_type']
+        if eval_type == 'convai2':
+            datapath = make_path(self.opt, 'valid_sorted_50_personas.json')
+        elif eval_type == 'dnli':
+            datapath = make_path(self.opt, 'dnli_sorted_50_personas.json')
+        else:
+            raise ValueError("eval_type must be either 'convai2' or 'dnli'")
+
+        with open(datapath, 'r') as fp:
+            sorted_personas = json.load(fp)
+        sorted_personas['idx2persona'] = sorted_personas['train_personas']
+        sorted_personas['persona2idx'] = {}
+        for idx, persona in enumerate(sorted_personas['train_personas']):
+            sorted_personas['persona2idx'][persona] = idx
+
+        return sorted_personas
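+
+    # Illustrative sketch, not part of the original file: judging from the code
+    # above, the sorted-persona JSON is assumed to look roughly like
+    #
+    #   {"train_personas": ["your persona: ...\nyour persona: ...", ...], ...}
+    #
+    # so that sorted_personas['idx2persona'][i] is a persona string and
+    # sorted_personas['persona2idx'] inverts that mapping.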
+
+    def _load_convai2_data(self, datapath):
+        """
+        Read data in the fbdialog format.
+        Returns ``(x, y, r, c)`` tuples.
+        ``x`` represents a query, ``y`` represents the labels, ``r`` represents
+        any reward, and ``c`` represents any label_candidates.
+        For example, a two-turn bAbI-style episode is translated into the
+        following tuples:
+        ::
+            x: 'Sam went to the kitchen\nPat gave Sam the milk\nWhere is the milk?'
+            y: ['kitchen']
+            r: '1'
+            c: ['hallway', 'kitchen', 'bathroom']
+            new_episode = True (this is the first example in the episode)
+        ::
+            x: 'Sam went to the hallway\nPat went to the bathroom\nWhere is the milk?'
+            y: ['hallway']
+            r: '1'
+            c: ['hallway', 'kitchen', 'bathroom']
+            new_episode = False (this is the second example in the episode)
+        """
+        self.cloze = False  # set temporarily so FbDeprecatedDialogTeacher.setup_data can be reused
+        convai2_dataloader = FbDeprecatedDialogTeacher.setup_data(self, datapath)
+        convai2_episodes = []
+        for episode in DialogData._read_episode(self, convai2_dataloader):
+            convai2_episodes.append(episode)
+        del self.cloze
+        return convai2_episodes
+
+    def share(self):
+        shared = super().share()
+        shared['episodes'] = self.episodes
+        return shared
+
+    def num_examples(self):
+        return self.num_exs
+
+    def num_episodes(self):
+        return len(self.episodes)
+
+    def get(self, episode_idx, entry_idx=0):
+        return self.episodes[episode_idx][entry_idx]
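+
+
+# Illustrative note, not part of the original file: in the DNLI branches above
+# and below, each candidate pool is capped at 10, so with a single gold label an
+# example ends up with 31 label candidates:
+#
+#   >>> cands = {'label': ['i have a dog .'], 'neg': ['n'] * 25,
+#   ...          'similar': ['s'] * 25, 'rand': ['r'] * 25}
+#   >>> len(cands['label'] + cands['neg'][:10] + cands['similar'][:10] + cands['rand'][:10])
+#   31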
+
+
+class ContextConsciousDialogueTeacher(SelfConsciousDialogueTeacher):
+    def _setup_data(self, path, datatype):
+        # random.seed(self.opt['random_seed'])
+        # Fix the seed so the same distractor histories are picked across runs.
+        random.seed(46)
+        # Data loading, following ParlAIDialogTeacher's loading script
+        print(f"[Loading ParlAI text data: {path}]")
+
+        # Read data from ConvAI2
+        convai2_datapath = make_path(self.opt, f'{datatype}_both_original.txt')
+        convai2_episodes = self._load_convai2_data(convai2_datapath)
+
+        if self.opt['eval_type'] == 'convai2':
+            self.episodes = []
+            self.num_exs = 0
+            eps = []
+            with open(path) as read:
+                for line in read:
+                    msg = str_to_msg(line.rstrip('\n'))
+                    if msg:
+                        self.num_exs += 1
+                        eps.append(msg)
+                        if msg.get('episode_done', False):
+                            self.episodes.append(eps)
+                            eps = []
+            if len(eps) > 0:
+                # add last episode
+                eps[-1].force_set('episode_done', True)
+                self.episodes.append(eps)
+            # Add label candidates and partner's persona
+            for episode_idx, episode in enumerate(self.episodes):
+                for turn_idx, turn in enumerate(episode):
+                    convai2_turn = convai2_episodes[episode_idx][turn_idx]
+                    convai2_text = convai2_turn[0]
+                    label_candidates = convai2_turn[3]
+
+                    turn['label_candidates'] = label_candidates
+                    if turn_idx == 0:
+                        my_persona, partner_persona, _ = _split_personas_and_context(convai2_text)
+                        turn['partner_persona'] = partner_persona
+                        turn['my_persona'] = my_persona
+                    else:
+                        turn['partner_persona'] = episode[0]['partner_persona']
+                        turn['my_persona'] = episode[0]['my_persona']
+        elif self.opt['eval_type'] == 'dnli':
+            self.episodes = []
+            self.num_exs = 0
+            for eval_set in ['attributes', 'havenot', 'likedislike']:
+                datapath = make_path(self.opt, f'{datatype}_{eval_set}.jsonl')
+                with open(datapath, 'r') as fp:
+                    for line in fp:
+                        msg = json.loads(line)
+                        msg['eval_set'] = eval_set
+                        msg['episode_done'] = True
+
+                        # Make 'text': re-punctuate each persona sentence by
+                        # dropping its last two characters and appending a period.
+                        persona_lines = [f'your persona: {x[:-2]}.' for x in msg['persona']]
+                        utts = msg['prefix']
+
+                        p1_token, p2_token = TorchAgent.P1_TOKEN, TorchAgent.P2_TOKEN
+                        lines = persona_lines
+                        # Identify the dialogue lines. It's assumed that p1 goes first.
+                        for i, utt in enumerate(utts):
+                            if i % 2 == 0:
+                                lines.append(f'{p1_token} {utt}')
+                            else:
+                                lines.append(f'{p2_token} {utt}')
+                        text = '\n'.join(lines)
+
+                        msg['text'] = text
+
+                        # Make 'label_candidates'
+                        cands = msg['candidates']
+                        msg['label_candidates'] = cands['label'] + cands['neg'][:10] \
+                            + cands['similar'][:10] + cands['rand'][:10]
+
+                        # Remove unused attributes
+                        del msg['persona']
+                        del msg['prefix']
+                        del msg['triple']
+                        del msg['relevant_persona_sentence']
+                        del msg['candidates']
+
+                        self.episodes.append([msg])
+                        self.num_exs += 1
+
+        # Get dialogue history pool
+        context_pool = self._get_context_pool(self.opt)
+
+        # Add distractor history
+        if self.opt['world_cardinality'] > 0:
+            # Defined here as well so the DNLI block below does not rely on
+            # p2_token leaking out of the loading loop above.
+            p2_token = TorchAgent.P2_TOKEN
+            for episode in self.episodes:
+                gt_persona, first_context = _split_persona_and_context(episode[0]['text'], self.opt['eval_type'])
+
+                # Select distractor history
+                if self.opt['eval_type'] == 'convai2':
+                    num_turn = len(episode)
+                else:
+                    dialogue = first_context.split('\n')
+                    num_turn = math.ceil(len(dialogue) / 2)
+                    if num_turn < min(context_pool.keys()):
+                        # original_num_turn = num_turn
+                        num_turn = min(context_pool.keys())
+
+                context_indices = list(range(len(context_pool[num_turn])))
+
+                distractor_c_indices = random.sample(context_indices, self.opt['world_cardinality'] - 1)
+                # Index directly instead of itemgetter: itemgetter returns a bare
+                # context when only one index is drawn (world_cardinality == 2),
+                # which the loops below would then iterate utterance by utterance.
+                distractor_contexts = [context_pool[num_turn][i] for i in distractor_c_indices]
+
+                # Build 'distractor_text'
+                if self.opt['eval_type'] == 'convai2':
+                    for turn_idx, turn in enumerate(episode):
+                        turn['distractor_text'] = turn['labels'] + [c[turn_idx] for c in distractor_contexts]
+                        if turn_idx == 0:
+                            turn['my_context'] = turn['labels']
+                        else:
+                            turn['my_context'] = episode[turn_idx - 1]['my_context'] + turn['labels']
+                else:
+                    # DNLI
+                    distractor_text = [episode[0]['text']]
+                    for c in distractor_contexts:
+                        copied_dialogue = copy.deepcopy(dialogue)
+                        # Swap every partner (odd-indexed) utterance for the distractor's
+                        for turn_idx, utterance in enumerate(copied_dialogue):
+                            if turn_idx % 2 == 1:
+                                copied_dialogue[turn_idx] = p2_token + c[turn_idx // 2]
+                        distractor_context = '\n'.join([gt_persona] + copied_dialogue)
+                        distractor_text.append(distractor_context)
+                    episode[0]['distractor_text'] = distractor_text
+
+    def _get_context_pool(self, opt):
+        print("[loading history pool from convai2 training data]")
+        datapath = make_path(opt, 'train.txt')
+        episodes = []
+        eps = []
+        with open(datapath) as read:
+            for line in read:
+                msg = str_to_msg(line.rstrip('\n'))
+                if msg:
+                    eps.append(msg)
+                    if msg.get('episode_done', False):
+                        episodes.append(eps)
+                        eps = []
+        if len(eps) > 0:
+            # add last episode
+            eps[-1].force_set('episode_done', True)
+            episodes.append(eps)
+
+        # Group training dialogues by episode length (number of turns)
+        context_pool = defaultdict(list)
+        for ep in episodes:
+            context_pool[len(ep)].append([turn['labels'][0] for turn in ep])
+
+        return dict(context_pool)
+
+
+class DefaultTeacher(SelfConsciousDialogueTeacher):
+    pass
diff --git a/tasks/test_persona_map.pkl b/tasks/test_persona_map.pkl
new file mode 100644
index 0000000..042d592
Binary files /dev/null and b/tasks/test_persona_map.pkl differ
diff --git a/tasks/train_persona_map.pkl b/tasks/train_persona_map.pkl
new file mode 100644
index 0000000..21e69b4
Binary files /dev/null and b/tasks/train_persona_map.pkl differ
diff --git a/tasks/valid_persona_map.pkl b/tasks/valid_persona_map.pkl
new file mode 100644
index 0000000..30c3462
Binary files /dev/null and b/tasks/valid_persona_map.pkl differ