
Commit

Add pathmanager support to bert agent (#3439)
Co-authored-by: deankita <deankita@devgpu179.prn2.facebook.com>
ankitade and deankita authored Feb 8, 2021
1 parent 7f03a61 commit 989577c
Showing 2 changed files with 57 additions and 53 deletions.
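Both files make the same change: the path assembled with os.path.join is now passed through PathManager.get_local_path before being handed to the downstream loaders, so each agent ends up working with a plain local file path (with iopath-style path managers, this presumably also lets the datapath point at a non-local backend). Below is a minimal sketch of the before/after pattern; the opt dict and its datapath value are hypothetical stand-ins for the agent's real options.

import os

from parlai.agents.bert_ranker.helpers import MODEL_PATH
from parlai.utils.io import PathManager

opt = {"datapath": "/tmp/ParlAI/data"}  # hypothetical datapath, for illustration only

# Before this commit: the joined path was used as-is.
pretrained_path = os.path.join(opt["datapath"], "models", "bert_models", MODEL_PATH)

# After this commit: the same path is resolved through PathManager first,
# which returns a path usable with ordinary local file I/O.
pretrained_path = PathManager.get_local_path(
    os.path.join(opt["datapath"], "models", "bert_models", MODEL_PATH)
)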
77 changes: 39 additions & 38 deletions parlai/agents/bert_classifier/bert_classifier.py
@@ -8,20 +8,21 @@
BERT classifier agent uses bert embeddings to make an utterance-level classification.
"""

import os
from collections import deque
from typing import Optional
from parlai.core.params import ParlaiParser
from parlai.core.opt import Opt

import torch
from parlai.agents.bert_ranker.bert_dictionary import BertDictionaryAgent
from parlai.agents.bert_ranker.helpers import BertWrapper, MODEL_PATH
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.torch_agent import History
from parlai.core.torch_classifier_agent import TorchClassifierAgent
from parlai.utils.io import PathManager
from parlai.utils.misc import warn_once
from parlai.zoo.bert.build import download

from collections import deque
import os
import torch

try:
    from pytorch_pretrained_bert import BertModel
except ImportError:
@@ -37,7 +38,7 @@ class BertClassifierHistory(History):
     """

     def __init__(self, opt, **kwargs):
-        self.sep_last_utt = opt.get('sep_last_utt', False)
+        self.sep_last_utt = opt.get("sep_last_utt", False)
         super().__init__(opt, **kwargs)

     def get_history_vec(self):
@@ -64,13 +65,13 @@ class BertClassifierAgent(TorchClassifierAgent):

     def __init__(self, opt, shared=None):
         # download pretrained models
-        download(opt['datapath'])
-        self.pretrained_path = os.path.join(
-            opt['datapath'], 'models', 'bert_models', MODEL_PATH
+        download(opt["datapath"])
+        self.pretrained_path = PathManager.get_local_path(
+            os.path.join(opt["datapath"], "models", "bert_models", MODEL_PATH)
         )
-        opt['pretrained_path'] = self.pretrained_path
-        self.add_cls_token = opt.get('add_cls_token', True)
-        self.sep_last_utt = opt.get('sep_last_utt', False)
+        opt["pretrained_path"] = self.pretrained_path
+        self.add_cls_token = opt.get("add_cls_token", True)
+        self.sep_last_utt = opt.get("sep_last_utt", False)
         super().__init__(opt, shared)

     @classmethod
@@ -88,33 +89,33 @@ def add_cmdline_args(
         Add CLI args.
         """
         super().add_cmdline_args(parser, partial_opt=partial_opt)
-        parser = parser.add_argument_group('BERT Classifier Arguments')
+        parser = parser.add_argument_group("BERT Classifier Arguments")
         parser.add_argument(
-            '--type-optimization',
+            "--type-optimization",
             type=str,
-            default='all_encoder_layers',
+            default="all_encoder_layers",
             choices=[
-                'additional_layers',
-                'top_layer',
-                'top4_layers',
-                'all_encoder_layers',
-                'all',
+                "additional_layers",
+                "top_layer",
+                "top4_layers",
+                "all_encoder_layers",
+                "all",
             ],
-            help='which part of the encoders do we optimize '
-            '(defaults to all layers)',
+            help="which part of the encoders do we optimize "
+            "(defaults to all layers)",
         )
         parser.add_argument(
-            '--add-cls-token',
-            type='bool',
+            "--add-cls-token",
+            type="bool",
             default=True,
-            help='add [CLS] token to text vec',
+            help="add [CLS] token to text vec",
         )
         parser.add_argument(
-            '--sep-last-utt',
-            type='bool',
+            "--sep-last-utt",
+            type="bool",
             default=False,
-            help='separate the last utterance into a different'
-            'segment with [SEP] token in between',
+            help="separate the last utterance into a different"
+            "segment with [SEP] token in between",
         )
         parser.set_defaults(dict_maxexs=0)  # skip building dictionary
         return parser
@@ -135,9 +136,9 @@ def upgrade_opt(cls, opt_on_disk):

         # 2019-06-25: previous versions of the model did not add a CLS token
         # to the beginning of text_vec.
-        if 'add_cls_token' not in opt_on_disk:
-            warn_once('Old model: overriding `add_cls_token` to False.')
-            opt_on_disk['add_cls_token'] = False
+        if "add_cls_token" not in opt_on_disk:
+            warn_once("Old model: overriding `add_cls_token` to False.")
+            opt_on_disk["add_cls_token"] = False

         return opt_on_disk

@@ -150,16 +151,16 @@ def build_model(self):

     def _set_text_vec(self, *args, **kwargs):
         obs = super()._set_text_vec(*args, **kwargs)
-        if 'text_vec' in obs and self.add_cls_token:
+        if "text_vec" in obs and self.add_cls_token:
             # insert [CLS] token
-            if 'added_start_end_tokens' not in obs:
+            if "added_start_end_tokens" not in obs:
                 # Sometimes the obs is cached (meaning its the same object
                 # passed the next time) and if so, we would continually re-add
                 # the start/end tokens. So, we need to test if already done
                 start_tensor = torch.LongTensor([self.dict.start_idx])
-                new_text_vec = torch.cat([start_tensor, obs['text_vec']], 0)
-                obs.force_set('text_vec', new_text_vec)
-                obs['added_start_end_tokens'] = True
+                new_text_vec = torch.cat([start_tensor, obs["text_vec"]], 0)
+                obs.force_set("text_vec", new_text_vec)
+                obs["added_start_end_tokens"] = True
         return obs

     def score(self, batch):
33 changes: 18 additions & 15 deletions parlai/agents/bert_ranker/bert_dictionary.py
@@ -4,20 +4,21 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from parlai.core.dict import DictionaryAgent
-from parlai.zoo.bert.build import download
+from parlai.utils.io import PathManager
 from parlai.utils.misc import warn_once
+from parlai.zoo.bert.build import download

 try:
     from pytorch_pretrained_bert import BertTokenizer
 except ImportError:
     raise ImportError(
-        'BERT rankers needs pytorch-pretrained-BERT installed. \n '
-        'pip install pytorch-pretrained-bert'
+        "BERT rankers needs pytorch-pretrained-BERT installed. \n "
+        "pip install pytorch-pretrained-bert"
     )
-from .helpers import VOCAB_PATH
-
 import os

+from .helpers import VOCAB_PATH
+

 class BertDictionaryAgent(DictionaryAgent):
     """
@@ -31,22 +32,24 @@ def __init__(self, opt):
         super().__init__(opt)
         # initialize from vocab path
         warn_once(
-            'WARNING: BERT uses a Hugging Face tokenizer; ParlAI dictionary args are ignored'
+            "WARNING: BERT uses a Hugging Face tokenizer; ParlAI dictionary args are ignored"
         )
-        download(opt['datapath'])
-        vocab_path = os.path.join(opt['datapath'], 'models', 'bert_models', VOCAB_PATH)
+        download(opt["datapath"])
+        vocab_path = PathManager.get_local_path(
+            os.path.join(opt["datapath"], "models", "bert_models", VOCAB_PATH)
+        )
         self.tokenizer = BertTokenizer.from_pretrained(vocab_path)

-        self.start_token = '[CLS]'
-        self.end_token = '[SEP]'
-        self.null_token = '[PAD]'
-        self.start_idx = self.tokenizer.convert_tokens_to_ids(['[CLS]'])[
+        self.start_token = "[CLS]"
+        self.end_token = "[SEP]"
+        self.null_token = "[PAD]"
+        self.start_idx = self.tokenizer.convert_tokens_to_ids(["[CLS]"])[
             0
         ]  # should be 101
-        self.end_idx = self.tokenizer.convert_tokens_to_ids(['[SEP]'])[
+        self.end_idx = self.tokenizer.convert_tokens_to_ids(["[SEP]"])[
             0
         ]  # should be 102
-        self.pad_idx = self.tokenizer.convert_tokens_to_ids(['[PAD]'])[0]  # should be 0
+        self.pad_idx = self.tokenizer.convert_tokens_to_ids(["[PAD]"])[0]  # should be 0
         # set tok2ind for special tokens
         self.tok2ind[self.start_token] = self.start_idx
         self.tok2ind[self.end_token] = self.end_idx
@@ -68,7 +71,7 @@ def vec2txt(self, vec):
         else:
             idxs = vec
         toks = self.tokenizer.convert_ids_to_tokens(idxs)
-        return ' '.join(toks)
+        return " ".join(toks)

     def act(self):
         return {}
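Taken together, the dictionary agent's vocabulary loading after this commit amounts to the sequence sketched below. This is assembled from the changed lines above rather than quoted verbatim, and opt is again a hypothetical options dict.

import os

from parlai.agents.bert_ranker.helpers import VOCAB_PATH
from parlai.utils.io import PathManager
from parlai.zoo.bert.build import download
from pytorch_pretrained_bert import BertTokenizer

opt = {"datapath": "/tmp/ParlAI/data"}  # hypothetical

# Fetch the bert_models files into the datapath if they are not already there.
download(opt["datapath"])

# Resolve the vocab file to a local path and build the tokenizer from it.
vocab_path = PathManager.get_local_path(
    os.path.join(opt["datapath"], "models", "bert_models", VOCAB_PATH)
)
tokenizer = BertTokenizer.from_pretrained(vocab_path)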
