From 9f100c73e0012150e7ee1ceec1feb47897c9ac89 Mon Sep 17 00:00:00 2001 From: Pravesh Agrawal Date: Thu, 19 Aug 2021 16:19:05 -0700 Subject: [PATCH 1/5] Make the code compatible for fbcode --- parlai/agents/rag/conversion_utils.py | 9 ++++++++- parlai/agents/rag/dpr.py | 13 ++++++++++--- parlai/agents/rag/retrievers.py | 20 +++++++++++++------- parlai/core/build_data.py | 1 - projects/blenderbot2/agents/modules.py | 1 + 5 files changed, 32 insertions(+), 12 deletions(-) diff --git a/parlai/agents/rag/conversion_utils.py b/parlai/agents/rag/conversion_utils.py index 5b0bbae27ec..9e295dba0ca 100644 --- a/parlai/agents/rag/conversion_utils.py +++ b/parlai/agents/rag/conversion_utils.py @@ -8,10 +8,12 @@ Conversion Scripts for RAG/DPR. """ from collections import OrderedDict +import os import torch from transformers import BertModel from typing import Dict +from parlai.utils.io import PathManager from parlai.utils.misc import recursive_getattr # Mapping from BERT key to ParlAI Key @@ -35,6 +37,7 @@ class BertConversionUtils: @staticmethod def load_bert_state( + datapath: str, state_dict: Dict[str, torch.Tensor], pretrained_dpr_path: str, encoder_type: str = 'query', @@ -52,7 +55,11 @@ def load_bert_state( :return new_state_dict: return a state_dict with loaded weights. """ - bert_model = BertModel.from_pretrained('bert-base-uncased') + + model_path = PathManager.get_local_path( + os.path.join(datapath, "bert_base_uncased") + ) + bert_model = BertModel.from_pretrained(model_path) if pretrained_dpr_path: BertConversionUtils.load_dpr_model( bert_model, pretrained_dpr_path, encoder_type diff --git a/parlai/agents/rag/dpr.py b/parlai/agents/rag/dpr.py index 32f9cd0fed8..63ac3d10ed7 100644 --- a/parlai/agents/rag/dpr.py +++ b/parlai/agents/rag/dpr.py @@ -17,6 +17,7 @@ from parlai.core.loader import register_agent from parlai.core.opt import Opt from parlai.core.torch_ranker_agent import TorchRankerAgent +from parlai.utils.io import PathManager import parlai.utils.logging as logging @@ -84,6 +85,8 @@ class DprEncoder(TransformerEncoder): models. """ + CONFIG_PATH = 'config.json' + def __init__( self, opt: Opt, @@ -92,7 +95,10 @@ def __init__( encoder_type: str = 'query', ): # Override options - config: BertConfig = BertConfig.from_pretrained('bert-base-uncased') + config_path = PathManager.get_local_path( + os.path.join(opt['datapath'], "bert_base_uncased", self.CONFIG_PATH) + ) + config: BertConfig = BertConfig.from_pretrained(config_path) pretrained_path = modelzoo_path( opt['datapath'], pretrained_path ) # type: ignore @@ -131,9 +137,9 @@ def __init__( reduction_type='first', ) - self._load_state(dpr_model, pretrained_path, encoder_type) + self._load_state(opt['datapath'], dpr_model, pretrained_path, encoder_type) - def _load_state(self, dpr_model: str, pretrained_path: str, encoder_type: str): + def _load_state(self, datapath: str, dpr_model: str, pretrained_path: str, encoder_type: str): """ Load pre-trained model states. @@ -146,6 +152,7 @@ def _load_state(self, dpr_model: str, pretrained_path: str, encoder_type: str): """ if dpr_model == 'bert': state_dict = BertConversionUtils.load_bert_state( + datapath, self.state_dict(), pretrained_dpr_path=pretrained_path, encoder_type=encoder_type, diff --git a/parlai/agents/rag/retrievers.py b/parlai/agents/rag/retrievers.py index ad6eb751da6..635ecf89c2e 100644 --- a/parlai/agents/rag/retrievers.py +++ b/parlai/agents/rag/retrievers.py @@ -17,10 +17,7 @@ import torch.nn from tqdm import tqdm -try: - from transformers import BertTokenizerFast as BertTokenizer -except ImportError: - from transformers import BertTokenizer +from transformers import BertTokenizer from typing import Tuple, List, Dict, Union, Optional, Any from typing_extensions import final from sklearn.feature_extraction.text import TfidfVectorizer @@ -36,6 +33,7 @@ import parlai.utils.logging as logging from parlai.utils.torch import padded_tensor from parlai.utils.typing import TShared +from parlai.utils.io import PathManager from parlai.agents.rag.dpr import DprQueryEncoder from parlai.agents.rag.polyfaiss import RagDropoutPolyWrapper @@ -225,8 +223,11 @@ class RagRetrieverTokenizer: Wrapper for various tokenizers used by RAG Query Model. """ + VOCAB_PATH = 'vocab.txt' + def __init__( self, + datapath: str, query_model: str, dictionary: DictionaryAgent, max_length: int = 256, @@ -242,6 +243,7 @@ def __init__( :param max_length: maximum length of encoding. """ + self.datapath = datapath self.query_model = query_model self.tokenizer = self._init_tokenizer(dictionary) self.max_length = max_length @@ -259,7 +261,10 @@ def _init_tokenizer( ParlAI dictionary agent """ if self.query_model in ['bert', 'bert_from_parlai_rag']: - return BertTokenizer.from_pretrained('bert-base-uncased') + vocab_path = PathManager.get_local_path( + os.path.join(self.datapath, "bert_base_uncased", self.VOCAB_PATH) + ) + return BertTokenizer.from_pretrained(vocab_path) else: return dictionary @@ -371,6 +376,7 @@ def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared: TShared = None self.max_query_len = opt['rag_query_truncate'] or 1024 self.end_idx = dictionary[dictionary.end_token] self._tokenizer = RagRetrieverTokenizer( + datapath=opt['datapath'], query_model=opt['query_model'], dictionary=dictionary, delimiter=opt.get('delimiter', '\n') or '\n', @@ -830,7 +836,7 @@ def _build_reranker( logging.enable() assert isinstance(agent, TorchRankerAgent) - return agent.model, RagRetrieverTokenizer('', agent.dict, max_length=360) + return agent.model, RagRetrieverTokenizer(opt['datapath'], '', agent.dict, max_length=360) def _retrieve_initial( self, query: torch.LongTensor @@ -967,7 +973,7 @@ def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared: TShared = None self.polyencoder = self.dropout_poly.model self.poly_tokenizer = RagRetrieverTokenizer( - opt['query_model'], self.dropout_poly.dict, max_length=360 + opt['datapath'], opt['query_model'], self.dropout_poly.dict, max_length=360 ) model = ( diff --git a/parlai/core/build_data.py b/parlai/core/build_data.py index 9a707905154..c598fd70bde 100644 --- a/parlai/core/build_data.py +++ b/parlai/core/build_data.py @@ -495,7 +495,6 @@ def modelzoo_path(datapath, path): animal_ = '.'.join(animal.split(".")[:-1]) + '.build' module_name_ = 'parlai.zoo.{}'.format(animal_) my_module = importlib.import_module(module_name_) - my_module.download(datapath) except (ImportError, AttributeError) as exc: # truly give up raise ImportError( diff --git a/projects/blenderbot2/agents/modules.py b/projects/blenderbot2/agents/modules.py index 7b813313d9c..de55f711c69 100644 --- a/projects/blenderbot2/agents/modules.py +++ b/projects/blenderbot2/agents/modules.py @@ -669,6 +669,7 @@ def __init__( pretrained_path=opt['memory_writer_model_file'], ).eval() self._tokenizer = RagRetrieverTokenizer( + datapath=opt['datapath'], query_model=opt['query_model'], dictionary=dictionary, delimiter='\n', From 19dc3f7d1941a12bdc7dae962ab8aa6a3e148b68 Mon Sep 17 00:00:00 2001 From: Pravesh Agrawal Date: Mon, 23 Aug 2021 18:47:50 -0700 Subject: [PATCH 2/5] Fix based on comments --- parlai/agents/rag/conversion_utils.py | 12 ++++++++---- parlai/agents/rag/dpr.py | 12 ++++++++---- parlai/agents/rag/retrievers.py | 17 ++++++++++++----- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/parlai/agents/rag/conversion_utils.py b/parlai/agents/rag/conversion_utils.py index 9e295dba0ca..3d86e4c4252 100644 --- a/parlai/agents/rag/conversion_utils.py +++ b/parlai/agents/rag/conversion_utils.py @@ -56,10 +56,14 @@ def load_bert_state( return a state_dict with loaded weights. """ - model_path = PathManager.get_local_path( - os.path.join(datapath, "bert_base_uncased") - ) - bert_model = BertModel.from_pretrained(model_path) + try: + bert_model = BertModel.from_pretrained('bert-base-uncased') + except OSError: + model_path = PathManager.get_local_path( + os.path.join(datapath, "bert_base_uncased") + ) + bert_model = BertModel.from_pretrained(model_path) + if pretrained_dpr_path: BertConversionUtils.load_dpr_model( bert_model, pretrained_dpr_path, encoder_type diff --git a/parlai/agents/rag/dpr.py b/parlai/agents/rag/dpr.py index 63ac3d10ed7..0153674cf9b 100644 --- a/parlai/agents/rag/dpr.py +++ b/parlai/agents/rag/dpr.py @@ -95,10 +95,14 @@ def __init__( encoder_type: str = 'query', ): # Override options - config_path = PathManager.get_local_path( - os.path.join(opt['datapath'], "bert_base_uncased", self.CONFIG_PATH) - ) - config: BertConfig = BertConfig.from_pretrained(config_path) + try: + config: BertConfig = BertConfig.from_pretrained('bert-base-uncased') + except OSError: + config_path = PathManager.get_local_path( + os.path.join(opt['datapath'], "bert_base_uncased", self.CONFIG_PATH) + ) + config: BertConfig = BertConfig.from_pretrained(config_path) + pretrained_path = modelzoo_path( opt['datapath'], pretrained_path ) # type: ignore diff --git a/parlai/agents/rag/retrievers.py b/parlai/agents/rag/retrievers.py index 635ecf89c2e..fedec7fa00c 100644 --- a/parlai/agents/rag/retrievers.py +++ b/parlai/agents/rag/retrievers.py @@ -15,9 +15,13 @@ import torch import torch.cuda import torch.nn +import transformers from tqdm import tqdm -from transformers import BertTokenizer +try: + from transformers import BertTokenizerFast as BertTokenizer +except ImportError: + from transformers import BertTokenizer from typing import Tuple, List, Dict, Union, Optional, Any from typing_extensions import final from sklearn.feature_extraction.text import TfidfVectorizer @@ -261,10 +265,13 @@ def _init_tokenizer( ParlAI dictionary agent """ if self.query_model in ['bert', 'bert_from_parlai_rag']: - vocab_path = PathManager.get_local_path( - os.path.join(self.datapath, "bert_base_uncased", self.VOCAB_PATH) - ) - return BertTokenizer.from_pretrained(vocab_path) + try: + return BertTokenizer.from_pretrained('bert-base-uncased') + except ImportError or OSError: + vocab_path = PathManager.get_local_path( + os.path.join(self.datapath, "bert_base_uncased", self.VOCAB_PATH) + ) + return transformers.BertTokenizer.from_pretrained(vocab_path) else: return dictionary From fc5824b0d2b3665ebd563b9765873575b289a01f Mon Sep 17 00:00:00 2001 From: Pravesh Agrawal Date: Mon, 23 Aug 2021 18:49:45 -0700 Subject: [PATCH 3/5] bring back download --- parlai/core/build_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/parlai/core/build_data.py b/parlai/core/build_data.py index c598fd70bde..9a707905154 100644 --- a/parlai/core/build_data.py +++ b/parlai/core/build_data.py @@ -495,6 +495,7 @@ def modelzoo_path(datapath, path): animal_ = '.'.join(animal.split(".")[:-1]) + '.build' module_name_ = 'parlai.zoo.{}'.format(animal_) my_module = importlib.import_module(module_name_) + my_module.download(datapath) except (ImportError, AttributeError) as exc: # truly give up raise ImportError( From 8833e61694b4077c85689c1b7de9c8c1913313a6 Mon Sep 17 00:00:00 2001 From: Pravesh Agrawal Date: Mon, 23 Aug 2021 19:55:27 -0700 Subject: [PATCH 4/5] Lint issues fix --- parlai/agents/rag/dpr.py | 4 +++- parlai/agents/rag/retrievers.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/parlai/agents/rag/dpr.py b/parlai/agents/rag/dpr.py index 0153674cf9b..5726444237e 100644 --- a/parlai/agents/rag/dpr.py +++ b/parlai/agents/rag/dpr.py @@ -143,7 +143,9 @@ def __init__( self._load_state(opt['datapath'], dpr_model, pretrained_path, encoder_type) - def _load_state(self, datapath: str, dpr_model: str, pretrained_path: str, encoder_type: str): + def _load_state( + self, datapath: str, dpr_model: str, pretrained_path: str, encoder_type: str + ): """ Load pre-trained model states. diff --git a/parlai/agents/rag/retrievers.py b/parlai/agents/rag/retrievers.py index fedec7fa00c..88474e47f7e 100644 --- a/parlai/agents/rag/retrievers.py +++ b/parlai/agents/rag/retrievers.py @@ -843,7 +843,10 @@ def _build_reranker( logging.enable() assert isinstance(agent, TorchRankerAgent) - return agent.model, RagRetrieverTokenizer(opt['datapath'], '', agent.dict, max_length=360) + return ( + agent.model, + RagRetrieverTokenizer(opt['datapath'], '', agent.dict, max_length=360), + ) def _retrieve_initial( self, query: torch.LongTensor From 7895ce89789e6a4b13f0ceb445b5a009d26b0d4e Mon Sep 17 00:00:00 2001 From: apravesh <83974938+apravesh@users.noreply.github.com> Date: Tue, 24 Aug 2021 13:29:33 -0700 Subject: [PATCH 5/5] Update parlai/agents/rag/retrievers.py Co-authored-by: Stephen Roller --- parlai/agents/rag/retrievers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parlai/agents/rag/retrievers.py b/parlai/agents/rag/retrievers.py index 88474e47f7e..4a5b10c04a2 100644 --- a/parlai/agents/rag/retrievers.py +++ b/parlai/agents/rag/retrievers.py @@ -267,7 +267,7 @@ def _init_tokenizer( if self.query_model in ['bert', 'bert_from_parlai_rag']: try: return BertTokenizer.from_pretrained('bert-base-uncased') - except ImportError or OSError: + except (ImportError, OSError): vocab_path = PathManager.get_local_path( os.path.join(self.datapath, "bert_base_uncased", self.VOCAB_PATH) )