diff --git a/parlai/agents/rag/conversion_utils.py b/parlai/agents/rag/conversion_utils.py index 5b0bbae27ec..3d86e4c4252 100644 --- a/parlai/agents/rag/conversion_utils.py +++ b/parlai/agents/rag/conversion_utils.py @@ -8,10 +8,12 @@ Conversion Scripts for RAG/DPR. """ from collections import OrderedDict +import os import torch from transformers import BertModel from typing import Dict +from parlai.utils.io import PathManager from parlai.utils.misc import recursive_getattr # Mapping from BERT key to ParlAI Key @@ -35,6 +37,7 @@ class BertConversionUtils: @staticmethod def load_bert_state( + datapath: str, state_dict: Dict[str, torch.Tensor], pretrained_dpr_path: str, encoder_type: str = 'query', @@ -52,7 +55,15 @@ def load_bert_state( :return new_state_dict: return a state_dict with loaded weights. """ - bert_model = BertModel.from_pretrained('bert-base-uncased') + + try: + bert_model = BertModel.from_pretrained('bert-base-uncased') + except OSError: + model_path = PathManager.get_local_path( + os.path.join(datapath, "bert_base_uncased") + ) + bert_model = BertModel.from_pretrained(model_path) + if pretrained_dpr_path: BertConversionUtils.load_dpr_model( bert_model, pretrained_dpr_path, encoder_type diff --git a/parlai/agents/rag/dpr.py b/parlai/agents/rag/dpr.py index 32f9cd0fed8..5726444237e 100644 --- a/parlai/agents/rag/dpr.py +++ b/parlai/agents/rag/dpr.py @@ -17,6 +17,7 @@ from parlai.core.loader import register_agent from parlai.core.opt import Opt from parlai.core.torch_ranker_agent import TorchRankerAgent +from parlai.utils.io import PathManager import parlai.utils.logging as logging @@ -84,6 +85,8 @@ class DprEncoder(TransformerEncoder): models. """ + CONFIG_PATH = 'config.json' + def __init__( self, opt: Opt, @@ -92,7 +95,14 @@ def __init__( encoder_type: str = 'query', ): # Override options - config: BertConfig = BertConfig.from_pretrained('bert-base-uncased') + try: + config: BertConfig = BertConfig.from_pretrained('bert-base-uncased') + except OSError: + config_path = PathManager.get_local_path( + os.path.join(opt['datapath'], "bert_base_uncased", self.CONFIG_PATH) + ) + config: BertConfig = BertConfig.from_pretrained(config_path) + pretrained_path = modelzoo_path( opt['datapath'], pretrained_path ) # type: ignore @@ -131,9 +141,11 @@ def __init__( reduction_type='first', ) - self._load_state(dpr_model, pretrained_path, encoder_type) + self._load_state(opt['datapath'], dpr_model, pretrained_path, encoder_type) - def _load_state(self, dpr_model: str, pretrained_path: str, encoder_type: str): + def _load_state( + self, datapath: str, dpr_model: str, pretrained_path: str, encoder_type: str + ): """ Load pre-trained model states. @@ -146,6 +158,7 @@ def _load_state(self, dpr_model: str, pretrained_path: str, encoder_type: str): """ if dpr_model == 'bert': state_dict = BertConversionUtils.load_bert_state( + datapath, self.state_dict(), pretrained_dpr_path=pretrained_path, encoder_type=encoder_type, diff --git a/parlai/agents/rag/retrievers.py b/parlai/agents/rag/retrievers.py index ad6eb751da6..4a5b10c04a2 100644 --- a/parlai/agents/rag/retrievers.py +++ b/parlai/agents/rag/retrievers.py @@ -15,6 +15,7 @@ import torch import torch.cuda import torch.nn +import transformers from tqdm import tqdm try: @@ -36,6 +37,7 @@ import parlai.utils.logging as logging from parlai.utils.torch import padded_tensor from parlai.utils.typing import TShared +from parlai.utils.io import PathManager from parlai.agents.rag.dpr import DprQueryEncoder from parlai.agents.rag.polyfaiss import RagDropoutPolyWrapper @@ -225,8 +227,11 @@ class RagRetrieverTokenizer: Wrapper for various tokenizers used by RAG Query Model. """ + VOCAB_PATH = 'vocab.txt' + def __init__( self, + datapath: str, query_model: str, dictionary: DictionaryAgent, max_length: int = 256, @@ -242,6 +247,7 @@ def __init__( :param max_length: maximum length of encoding. """ + self.datapath = datapath self.query_model = query_model self.tokenizer = self._init_tokenizer(dictionary) self.max_length = max_length @@ -259,7 +265,13 @@ def _init_tokenizer( ParlAI dictionary agent """ if self.query_model in ['bert', 'bert_from_parlai_rag']: - return BertTokenizer.from_pretrained('bert-base-uncased') + try: + return BertTokenizer.from_pretrained('bert-base-uncased') + except (ImportError, OSError): + vocab_path = PathManager.get_local_path( + os.path.join(self.datapath, "bert_base_uncased", self.VOCAB_PATH) + ) + return transformers.BertTokenizer.from_pretrained(vocab_path) else: return dictionary @@ -371,6 +383,7 @@ def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared: TShared = None self.max_query_len = opt['rag_query_truncate'] or 1024 self.end_idx = dictionary[dictionary.end_token] self._tokenizer = RagRetrieverTokenizer( + datapath=opt['datapath'], query_model=opt['query_model'], dictionary=dictionary, delimiter=opt.get('delimiter', '\n') or '\n', @@ -830,7 +843,10 @@ def _build_reranker( logging.enable() assert isinstance(agent, TorchRankerAgent) - return agent.model, RagRetrieverTokenizer('', agent.dict, max_length=360) + return ( + agent.model, + RagRetrieverTokenizer(opt['datapath'], '', agent.dict, max_length=360), + ) def _retrieve_initial( self, query: torch.LongTensor @@ -967,7 +983,7 @@ def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared: TShared = None self.polyencoder = self.dropout_poly.model self.poly_tokenizer = RagRetrieverTokenizer( - opt['query_model'], self.dropout_poly.dict, max_length=360 + opt['datapath'], opt['query_model'], self.dropout_poly.dict, max_length=360 ) model = ( diff --git a/projects/blenderbot2/agents/modules.py b/projects/blenderbot2/agents/modules.py index 7b813313d9c..de55f711c69 100644 --- a/projects/blenderbot2/agents/modules.py +++ b/projects/blenderbot2/agents/modules.py @@ -669,6 +669,7 @@ def __init__( pretrained_path=opt['memory_writer_model_file'], ).eval() self._tokenizer = RagRetrieverTokenizer( + datapath=opt['datapath'], query_model=opt['query_model'], dictionary=dictionary, delimiter='\n',