From 989577c857dd76e38aa7352bdb39481158cbee8d Mon Sep 17 00:00:00 2001 From: ankitade Date: Sun, 7 Feb 2021 18:43:27 -0800 Subject: [PATCH] Add pathmanager support to bert agent (#3439) Co-authored-by: deankita --- .../agents/bert_classifier/bert_classifier.py | 77 ++++++++++--------- parlai/agents/bert_ranker/bert_dictionary.py | 33 ++++---- 2 files changed, 57 insertions(+), 53 deletions(-) diff --git a/parlai/agents/bert_classifier/bert_classifier.py b/parlai/agents/bert_classifier/bert_classifier.py index 8d45beb0b0e..b25e60d7910 100644 --- a/parlai/agents/bert_classifier/bert_classifier.py +++ b/parlai/agents/bert_classifier/bert_classifier.py @@ -8,20 +8,21 @@ BERT classifier agent uses bert embeddings to make an utterance-level classification. """ +import os +from collections import deque from typing import Optional -from parlai.core.params import ParlaiParser -from parlai.core.opt import Opt + +import torch from parlai.agents.bert_ranker.bert_dictionary import BertDictionaryAgent from parlai.agents.bert_ranker.helpers import BertWrapper, MODEL_PATH +from parlai.core.opt import Opt +from parlai.core.params import ParlaiParser from parlai.core.torch_agent import History from parlai.core.torch_classifier_agent import TorchClassifierAgent +from parlai.utils.io import PathManager from parlai.utils.misc import warn_once from parlai.zoo.bert.build import download -from collections import deque -import os -import torch - try: from pytorch_pretrained_bert import BertModel except ImportError: @@ -37,7 +38,7 @@ class BertClassifierHistory(History): """ def __init__(self, opt, **kwargs): - self.sep_last_utt = opt.get('sep_last_utt', False) + self.sep_last_utt = opt.get("sep_last_utt", False) super().__init__(opt, **kwargs) def get_history_vec(self): @@ -64,13 +65,13 @@ class BertClassifierAgent(TorchClassifierAgent): def __init__(self, opt, shared=None): # download pretrained models - download(opt['datapath']) - self.pretrained_path = os.path.join( - opt['datapath'], 'models', 'bert_models', MODEL_PATH + download(opt["datapath"]) + self.pretrained_path = PathManager.get_local_path( + os.path.join(opt["datapath"], "models", "bert_models", MODEL_PATH) ) - opt['pretrained_path'] = self.pretrained_path - self.add_cls_token = opt.get('add_cls_token', True) - self.sep_last_utt = opt.get('sep_last_utt', False) + opt["pretrained_path"] = self.pretrained_path + self.add_cls_token = opt.get("add_cls_token", True) + self.sep_last_utt = opt.get("sep_last_utt", False) super().__init__(opt, shared) @classmethod @@ -88,33 +89,33 @@ def add_cmdline_args( Add CLI args. """ super().add_cmdline_args(parser, partial_opt=partial_opt) - parser = parser.add_argument_group('BERT Classifier Arguments') + parser = parser.add_argument_group("BERT Classifier Arguments") parser.add_argument( - '--type-optimization', + "--type-optimization", type=str, - default='all_encoder_layers', + default="all_encoder_layers", choices=[ - 'additional_layers', - 'top_layer', - 'top4_layers', - 'all_encoder_layers', - 'all', + "additional_layers", + "top_layer", + "top4_layers", + "all_encoder_layers", + "all", ], - help='which part of the encoders do we optimize ' - '(defaults to all layers)', + help="which part of the encoders do we optimize " + "(defaults to all layers)", ) parser.add_argument( - '--add-cls-token', - type='bool', + "--add-cls-token", + type="bool", default=True, - help='add [CLS] token to text vec', + help="add [CLS] token to text vec", ) parser.add_argument( - '--sep-last-utt', - type='bool', + "--sep-last-utt", + type="bool", default=False, - help='separate the last utterance into a different' - 'segment with [SEP] token in between', + help="separate the last utterance into a different" + "segment with [SEP] token in between", ) parser.set_defaults(dict_maxexs=0) # skip building dictionary return parser @@ -135,9 +136,9 @@ def upgrade_opt(cls, opt_on_disk): # 2019-06-25: previous versions of the model did not add a CLS token # to the beginning of text_vec. - if 'add_cls_token' not in opt_on_disk: - warn_once('Old model: overriding `add_cls_token` to False.') - opt_on_disk['add_cls_token'] = False + if "add_cls_token" not in opt_on_disk: + warn_once("Old model: overriding `add_cls_token` to False.") + opt_on_disk["add_cls_token"] = False return opt_on_disk @@ -150,16 +151,16 @@ def build_model(self): def _set_text_vec(self, *args, **kwargs): obs = super()._set_text_vec(*args, **kwargs) - if 'text_vec' in obs and self.add_cls_token: + if "text_vec" in obs and self.add_cls_token: # insert [CLS] token - if 'added_start_end_tokens' not in obs: + if "added_start_end_tokens" not in obs: # Sometimes the obs is cached (meaning its the same object # passed the next time) and if so, we would continually re-add # the start/end tokens. So, we need to test if already done start_tensor = torch.LongTensor([self.dict.start_idx]) - new_text_vec = torch.cat([start_tensor, obs['text_vec']], 0) - obs.force_set('text_vec', new_text_vec) - obs['added_start_end_tokens'] = True + new_text_vec = torch.cat([start_tensor, obs["text_vec"]], 0) + obs.force_set("text_vec", new_text_vec) + obs["added_start_end_tokens"] = True return obs def score(self, batch): diff --git a/parlai/agents/bert_ranker/bert_dictionary.py b/parlai/agents/bert_ranker/bert_dictionary.py index 268a12fd490..e6cab6d684f 100644 --- a/parlai/agents/bert_ranker/bert_dictionary.py +++ b/parlai/agents/bert_ranker/bert_dictionary.py @@ -4,20 +4,21 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from parlai.core.dict import DictionaryAgent -from parlai.zoo.bert.build import download +from parlai.utils.io import PathManager from parlai.utils.misc import warn_once +from parlai.zoo.bert.build import download try: from pytorch_pretrained_bert import BertTokenizer except ImportError: raise ImportError( - 'BERT rankers needs pytorch-pretrained-BERT installed. \n ' - 'pip install pytorch-pretrained-bert' + "BERT rankers needs pytorch-pretrained-BERT installed. \n " + "pip install pytorch-pretrained-bert" ) -from .helpers import VOCAB_PATH - import os +from .helpers import VOCAB_PATH + class BertDictionaryAgent(DictionaryAgent): """ @@ -31,22 +32,24 @@ def __init__(self, opt): super().__init__(opt) # initialize from vocab path warn_once( - 'WARNING: BERT uses a Hugging Face tokenizer; ParlAI dictionary args are ignored' + "WARNING: BERT uses a Hugging Face tokenizer; ParlAI dictionary args are ignored" + ) + download(opt["datapath"]) + vocab_path = PathManager.get_local_path( + os.path.join(opt["datapath"], "models", "bert_models", VOCAB_PATH) ) - download(opt['datapath']) - vocab_path = os.path.join(opt['datapath'], 'models', 'bert_models', VOCAB_PATH) self.tokenizer = BertTokenizer.from_pretrained(vocab_path) - self.start_token = '[CLS]' - self.end_token = '[SEP]' - self.null_token = '[PAD]' - self.start_idx = self.tokenizer.convert_tokens_to_ids(['[CLS]'])[ + self.start_token = "[CLS]" + self.end_token = "[SEP]" + self.null_token = "[PAD]" + self.start_idx = self.tokenizer.convert_tokens_to_ids(["[CLS]"])[ 0 ] # should be 101 - self.end_idx = self.tokenizer.convert_tokens_to_ids(['[SEP]'])[ + self.end_idx = self.tokenizer.convert_tokens_to_ids(["[SEP]"])[ 0 ] # should be 102 - self.pad_idx = self.tokenizer.convert_tokens_to_ids(['[PAD]'])[0] # should be 0 + self.pad_idx = self.tokenizer.convert_tokens_to_ids(["[PAD]"])[0] # should be 0 # set tok2ind for special tokens self.tok2ind[self.start_token] = self.start_idx self.tok2ind[self.end_token] = self.end_idx @@ -68,7 +71,7 @@ def vec2txt(self, vec): else: idxs = vec toks = self.tokenizer.convert_ids_to_tokens(idxs) - return ' '.join(toks) + return " ".join(toks) def act(self): return {}