Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Make the code compatible for fbcode #3964

Merged
merged 5 commits into from
Aug 25, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion parlai/agents/rag/conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
Conversion Scripts for RAG/DPR.
"""
from collections import OrderedDict
import os
import torch
from transformers import BertModel
from typing import Dict

from parlai.utils.io import PathManager
from parlai.utils.misc import recursive_getattr

# Mapping from BERT key to ParlAI Key
Expand All @@ -35,6 +37,7 @@ class BertConversionUtils:

@staticmethod
def load_bert_state(
datapath: str,
state_dict: Dict[str, torch.Tensor],
pretrained_dpr_path: str,
encoder_type: str = 'query',
Expand All @@ -52,7 +55,15 @@ def load_bert_state(
:return new_state_dict:
return a state_dict with loaded weights.
"""
bert_model = BertModel.from_pretrained('bert-base-uncased')

try:
bert_model = BertModel.from_pretrained('bert-base-uncased')
except OSError:
model_path = PathManager.get_local_path(
os.path.join(datapath, "bert_base_uncased")
)
bert_model = BertModel.from_pretrained(model_path)

if pretrained_dpr_path:
BertConversionUtils.load_dpr_model(
bert_model, pretrained_dpr_path, encoder_type
Expand Down
19 changes: 16 additions & 3 deletions parlai/agents/rag/dpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from parlai.core.loader import register_agent
from parlai.core.opt import Opt
from parlai.core.torch_ranker_agent import TorchRankerAgent
from parlai.utils.io import PathManager
import parlai.utils.logging as logging


Expand Down Expand Up @@ -84,6 +85,8 @@ class DprEncoder(TransformerEncoder):
models.
"""

CONFIG_PATH = 'config.json'

def __init__(
self,
opt: Opt,
Expand All @@ -92,7 +95,14 @@ def __init__(
encoder_type: str = 'query',
):
# Override options
config: BertConfig = BertConfig.from_pretrained('bert-base-uncased')
try:
config: BertConfig = BertConfig.from_pretrained('bert-base-uncased')
except OSError:
config_path = PathManager.get_local_path(
os.path.join(opt['datapath'], "bert_base_uncased", self.CONFIG_PATH)
)
config: BertConfig = BertConfig.from_pretrained(config_path)

pretrained_path = modelzoo_path(
opt['datapath'], pretrained_path
) # type: ignore
Expand Down Expand Up @@ -131,9 +141,11 @@ def __init__(
reduction_type='first',
)

self._load_state(dpr_model, pretrained_path, encoder_type)
self._load_state(opt['datapath'], dpr_model, pretrained_path, encoder_type)

def _load_state(self, dpr_model: str, pretrained_path: str, encoder_type: str):
def _load_state(
self, datapath: str, dpr_model: str, pretrained_path: str, encoder_type: str
):
"""
Load pre-trained model states.

Expand All @@ -146,6 +158,7 @@ def _load_state(self, dpr_model: str, pretrained_path: str, encoder_type: str):
"""
if dpr_model == 'bert':
state_dict = BertConversionUtils.load_bert_state(
datapath,
self.state_dict(),
pretrained_dpr_path=pretrained_path,
encoder_type=encoder_type,
Expand Down
22 changes: 19 additions & 3 deletions parlai/agents/rag/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch
import torch.cuda
import torch.nn
import transformers
from tqdm import tqdm

try:
Expand All @@ -36,6 +37,7 @@
import parlai.utils.logging as logging
from parlai.utils.torch import padded_tensor
from parlai.utils.typing import TShared
from parlai.utils.io import PathManager

from parlai.agents.rag.dpr import DprQueryEncoder
from parlai.agents.rag.polyfaiss import RagDropoutPolyWrapper
Expand Down Expand Up @@ -225,8 +227,11 @@ class RagRetrieverTokenizer:
Wrapper for various tokenizers used by RAG Query Model.
"""

VOCAB_PATH = 'vocab.txt'

def __init__(
self,
datapath: str,
query_model: str,
dictionary: DictionaryAgent,
max_length: int = 256,
Expand All @@ -242,6 +247,7 @@ def __init__(
:param max_length:
maximum length of encoding.
"""
self.datapath = datapath
self.query_model = query_model
self.tokenizer = self._init_tokenizer(dictionary)
self.max_length = max_length
Expand All @@ -259,7 +265,13 @@ def _init_tokenizer(
ParlAI dictionary agent
"""
if self.query_model in ['bert', 'bert_from_parlai_rag']:
return BertTokenizer.from_pretrained('bert-base-uncased')
try:
return BertTokenizer.from_pretrained('bert-base-uncased')
except ImportError or OSError:
apravesh marked this conversation as resolved.
Show resolved Hide resolved
vocab_path = PathManager.get_local_path(
os.path.join(self.datapath, "bert_base_uncased", self.VOCAB_PATH)
)
return transformers.BertTokenizer.from_pretrained(vocab_path)
else:
return dictionary

Expand Down Expand Up @@ -371,6 +383,7 @@ def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared: TShared = None
self.max_query_len = opt['rag_query_truncate'] or 1024
self.end_idx = dictionary[dictionary.end_token]
self._tokenizer = RagRetrieverTokenizer(
datapath=opt['datapath'],
query_model=opt['query_model'],
dictionary=dictionary,
delimiter=opt.get('delimiter', '\n') or '\n',
Expand Down Expand Up @@ -830,7 +843,10 @@ def _build_reranker(
logging.enable()
assert isinstance(agent, TorchRankerAgent)

return agent.model, RagRetrieverTokenizer('', agent.dict, max_length=360)
return (
agent.model,
RagRetrieverTokenizer(opt['datapath'], '', agent.dict, max_length=360),
)

def _retrieve_initial(
self, query: torch.LongTensor
Expand Down Expand Up @@ -967,7 +983,7 @@ def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared: TShared = None
self.polyencoder = self.dropout_poly.model

self.poly_tokenizer = RagRetrieverTokenizer(
opt['query_model'], self.dropout_poly.dict, max_length=360
opt['datapath'], opt['query_model'], self.dropout_poly.dict, max_length=360
)

model = (
Expand Down
1 change: 1 addition & 0 deletions projects/blenderbot2/agents/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ def __init__(
pretrained_path=opt['memory_writer_model_file'],
).eval()
self._tokenizer = RagRetrieverTokenizer(
datapath=opt['datapath'],
query_model=opt['query_model'],
dictionary=dictionary,
delimiter='\n',
Expand Down