This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Add support for PathManager. #3011

Merged
merged 7 commits on Aug 28, 2020
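The change is mechanical: direct calls to open(), os.path.isfile(), and os.makedirs() on model and data paths are routed through parlai.utils.io.PathManager instead. A minimal sketch of the resulting pattern, using only the PathManager.open and PathManager.exists calls visible in the diffs below (the helper names and checkpoint layout here are hypothetical, not part of this PR):

import json
import torch
from parlai.utils.io import PathManager


def save_agent(path, model_state, opt):
    # Writes, binary or text, go through PathManager.open() rather than open().
    with PathManager.open(path, 'wb') as handle:
        torch.save({'model': model_state, 'opt': opt}, handle)
    with PathManager.open(path + '.opt', 'w') as handle:
        json.dump(opt, handle)


def load_opt_if_present(path):
    # Existence checks use PathManager.exists() instead of os.path.isfile().
    if PathManager.exists(path + '.opt'):
        with PathManager.open(path + '.opt') as f:
            return json.load(f)
    return None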
6 changes: 4 additions & 2 deletions parlai/agents/bart/convert_fairseq_to_parlai.py
@@ -27,6 +27,8 @@
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.script import ParlaiScript
+from parlai.utils.io import PathManager


TRANSFORMER_PARAMETER_MAPPING = {
'attention_heads': 'n_heads',
@@ -240,7 +242,7 @@ def _load_single_fairseq_checkpoint(self, path: str) -> Dict[str, Any]:
:return state:
loaded fairseq state
"""
-with open(path, "rb") as f:
+with PathManager.open(path, "rb") as f:
try:
state = torch.load(
f, map_location=lambda s, l: default_restore_location(s, "cpu")
@@ -397,7 +399,7 @@ def convert_model_weight(self, opt: Opt) -> Dict[str, Any]:
# 6. Shuffle embedding matrix given dictionary.
enc_emb_key = 'encoder.embeddings.weight'
bart_dict = os.path.join(opt['datapath'], 'models/bart/bart.large/dict.txt')
-with open(bart_dict) as f:
+with PathManager.open(bart_dict) as f:
offset_dict = {i: l.split()[0] for i, l in enumerate(f.readlines())}
new_embs = return_dict[enc_emb_key].clone()
for idx, new_idx in offset_dict.items():
3 changes: 2 additions & 1 deletion parlai/agents/bert_ranker/bi_encoder_ranker.py
@@ -6,6 +6,7 @@
from parlai.core.torch_ranker_agent import TorchRankerAgent
from parlai.utils.torch import padded_3d
from parlai.zoo.bert.build import download
+from parlai.utils.io import PathManager

from .bert_dictionary import BertDictionaryAgent
from .helpers import (
@@ -101,7 +102,7 @@ def set_vocab_candidates(self, shared):
"".format(len(self.vocab_candidates))
)
enc_path = self.opt.get('model_file') + '.vocab.encs'
-if os.path.isfile(enc_path):
+if PathManager.exists(enc_path):
self.vocab_candidate_encs = self.load_candidates(
enc_path, cand_type='vocab encodings'
)
10 changes: 5 additions & 5 deletions parlai/agents/drqa/config.py
@@ -3,8 +3,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import os
from parlai.core.build_data import modelzoo_path
+from parlai.utils.io import PathManager


def add_cmdline_args(parser):
@@ -161,10 +161,10 @@ def add_cmdline_args(parser):
def set_defaults(opt):
init_model = None
# check first for 'init_model' for loading model from file
-if opt.get('init_model') and os.path.isfile(opt['init_model']):
+if opt.get('init_model') and PathManager.exists(opt['init_model']):
init_model = opt['init_model']
# next check for 'model_file', this would override init_model
-if opt.get('model_file') and os.path.isfile(opt['model_file']):
+if opt.get('model_file') and PathManager.exists(opt['model_file']):
init_model = opt['model_file']

if init_model is None:
@@ -173,9 +173,9 @@ def set_defaults(opt):
opt.get('datapath'), opt['embedding_file']
)
if opt.get('embedding_file'):
-if not os.path.isfile(opt['embedding_file']):
+if not PathManager.exists(opt['embedding_file']):
raise IOError('No such file: %s' % opt['embedding_file'])
-with open(opt['embedding_file']) as f:
+with PathManager.open(opt['embedding_file']) as f:
dim = len(f.readline().strip().split(' ')) - 1
if dim == 1:
# first line was a dud
8 changes: 4 additions & 4 deletions parlai/agents/drqa/drqa.py
@@ -25,14 +25,14 @@
raise ImportError('Need to install pytorch: go to pytorch.org')

import bisect
-import os
import numpy as np
import json
import random

from parlai.core.agents import Agent
from parlai.core.dict import DictionaryAgent
from parlai.core.build_data import modelzoo_path
+from parlai.utils.io import PathManager
from . import config
from .utils import build_feature_dict, vectorize, batchify, normalize_text
from .model import DocReaderModel
@@ -73,7 +73,7 @@ def __init__(self, *args, **kwargs):
self.opt['embedding_file'] = modelzoo_path(
self.opt.get('datapath'), self.opt['embedding_file']
)
-with open(self.opt['embedding_file']) as f:
+with PathManager.open(self.opt['embedding_file']) as f:
for line in f:
w = normalize_text(line.rstrip().split(' ')[0])
self.embedding_words.add(w)
@@ -128,7 +128,7 @@ def __init__(self, opt, shared=None):
else:
# set up model
self.word_dict = DrqaAgent.dictionary_class()(opt)
-if self.opt.get('model_file') and os.path.isfile(opt['model_file']):
+if self.opt.get('model_file') and PathManager.exists(opt['model_file']):
self._init_from_saved(opt['model_file'])
else:
if self.opt.get('init_model'):
@@ -274,7 +274,7 @@ def save(self, fname=None):
self.opt['trained'] = True
self.model.save(fname)
# save opt file
-with open(fname + '.opt', 'w') as handle:
+with PathManager.open(fname + '.opt', 'w') as handle:
json.dump(self.opt, handle)

# --------------------------------------------------------------------------
3 changes: 2 additions & 1 deletion parlai/agents/drqa/utils.py
@@ -7,6 +7,7 @@
import unicodedata
from collections import Counter

+from parlai.utils.io import PathManager
from parlai.core.build_data import modelzoo_path


@@ -29,7 +30,7 @@ def load_embeddings(opt, word_dict):
# Fill in embeddings
if not opt.get('embedding_file'):
raise RuntimeError('Tried to load embeddings with no embedding file.')
-with open(opt['embedding_file']) as f:
+with PathManager.open(opt['embedding_file']) as f:
for line in f:
parsed = line.rstrip().split(' ')
if len(parsed) > 2:
5 changes: 3 additions & 2 deletions parlai/agents/ir_baseline/ir_baseline.py
@@ -28,6 +28,7 @@

from parlai.core.agents import Agent
from parlai.core.dict import DictionaryAgent
+from parlai.utils.io import PathManager


class MaxPriorityQueue(Sequence):
@@ -330,9 +331,9 @@ def save(self, path=None):
self.dictionary.save(path + '.dict')
data = {}
data['opt'] = self.opt
-with open(path, 'wb') as handle:
+with PathManager.open(path, 'wb') as handle:
torch.save(data, handle)
-with open(path + '.opt', 'w') as handle:
+with PathManager.open(path + '.opt', 'w') as handle:
json.dump(self.opt, handle)

def load(self, fname):
12 changes: 6 additions & 6 deletions parlai/agents/starspace/starspace.py
@@ -12,6 +12,7 @@
from parlai.core.dict import DictionaryAgent
from parlai.utils.misc import maintain_dialog_history, load_cands
from parlai.core.torch_agent import TorchAgent
+from parlai.utils.io import PathManager
from .modules import Starspace

import torch
@@ -20,7 +21,6 @@
from collections import deque

import copy
-import os
import random
import json

@@ -198,7 +198,7 @@ def __init__(self, opt, shared=None):
print("[ creating StarspaceAgent ]")
# this is not a shared instance of this class, so do full init
if opt.get('model_file') and (
-os.path.isfile(opt.get('model_file') + '.dict')
+PathManager.exists(opt.get('model_file') + '.dict')
or (opt['dict_file'] is None)
):
# set default dict-file if not set
@@ -207,7 +207,7 @@
self.dict = DictionaryAgent(opt)

self.model = Starspace(opt, len(self.dict), self.dict)
-if opt.get('model_file') and os.path.isfile(opt['model_file']):
+if opt.get('model_file') and PathManager.exists(opt['model_file']):
self.load(opt['model_file'])
else:
self._init_embeddings()
@@ -434,7 +434,7 @@ def predict(self, xs, ys=None, cands=None, cands_txt=None, obs=None):
for c in negs:
print("neg: " + self.v2t(c.squeeze()))
print("---")
-y = -torch.ones(xe.size(0))
+y = -(torch.ones(xe.size(0)))
y[0] = 1
loss = self.criterion(xe, ye, y)
loss.backward()
@@ -585,9 +585,9 @@ def save(self, path=None):
data['model'] = self.model.state_dict()
data['optimizer'] = self.optimizer.state_dict()
data['opt'] = self.opt
-with open(path, 'wb') as handle:
+with PathManager.open(path, 'wb') as handle:
torch.save(data, handle)
-with open(path + '.opt', 'w') as handle:
+with PathManager.open(path + '.opt', 'w') as handle:
json.dump(self.opt, handle)

def load(self, path):
4 changes: 2 additions & 2 deletions parlai/agents/tfidf_retriever/build_db.py
@@ -10,14 +10,14 @@
"""

import sqlite3
-import os

from tqdm import tqdm

from collections import deque
import random
from parlai.core.teachers import create_task_agent_from_taskname
import parlai.utils.logging as logging
+from parlai.utils.io import PathManager

# ------------------------------------------------------------------------------
# Store corpus.
@@ -33,7 +33,7 @@ def store_contents(opt, task, save_path, context_length=-1, include_labels=True)
save_path: Path to output sqlite db.
num_workers: Number of parallel processes to use when reading docs.
"""
-if os.path.isfile(save_path):
+if PathManager.exists(save_path):
raise RuntimeError('%s already exists! Not overwriting.' % save_path)

logging.info('Reading into database...')
5 changes: 3 additions & 2 deletions parlai/agents/tfidf_retriever/tfidf_retriever.py
@@ -17,6 +17,7 @@
)

from parlai.core.agents import Agent
+from parlai.utils.io import PathManager
from parlai.utils.misc import AttrDict
from .doc_db import DocDB
from .tfidf_doc_ranker import TfidfDocRanker
@@ -218,9 +219,9 @@ def rebuild(self):

def save(self, path=None):
self.rebuild()
-with open(self.opt['model_file'] + '.opt', 'w') as handle:
+with PathManager.open(self.opt['model_file'] + '.opt', 'w') as handle:
json.dump(self.opt, handle)
-with open(self.opt['model_file'], 'w') as f:
+with PathManager.open(self.opt['model_file'], 'w') as f:
f.write('\n')

def train_act(self):
5 changes: 3 additions & 2 deletions parlai/agents/unigram/unigram.py
@@ -20,6 +20,7 @@
from parlai.core.agents import Agent
from parlai.core.dict import DictionaryAgent
from itertools import islice
+from parlai.utils.io import PathManager


class UnigramAgent(Agent):
@@ -109,10 +110,10 @@ def save(self, path=None):
if not path:
return

-with open(path, 'w') as f:
+with PathManager.open(path, 'w') as f:
f.write(self.get_prediction() + '\n')

-with open(path + '.opt', 'w') as f:
+with PathManager.open(path + '.opt', 'w') as f:
json.dump(self.opt, f)

def load(self, path):
3 changes: 2 additions & 1 deletion parlai/chat_service/services/messenger/messenger_manager.py
@@ -15,6 +15,7 @@
from parlai.core.agents import create_agent
import parlai.chat_service.utils.logging as log_utils
import parlai.chat_service.utils.server as server_utils
+from parlai.utils.io import PathManager
from parlai.chat_service.services.messenger.agents import MessengerAgent
from parlai.chat_service.core.socket import ChatServiceMessageSocket
from parlai.chat_service.services.messenger.message_sender import MessageSender
@@ -222,7 +223,7 @@ def get_app_token(self):
"""
if not self.opt.get('force_page_token'):
if not os.path.exists(os.path.expanduser('~/.parlai/')):
-os.makedirs(os.path.expanduser('~/.parlai/'))
+PathManager.mkdirs(os.path.expanduser('~/.parlai/'))
access_token_file_path = '~/.parlai/messenger_token'
expanded_file_path = os.path.expanduser(access_token_file_path)
if os.path.exists(expanded_file_path):
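Directory creation follows the same substitution: os.makedirs becomes PathManager.mkdirs, kept behind the existing os.path.exists guard as in the get_app_token() change above. An illustrative sketch of that pattern (whether PathManager.mkdirs tolerates an already-existing directory is not stated in this diff, so the guard is preserved):

import os
from parlai.utils.io import PathManager

token_dir = os.path.expanduser('~/.parlai/')
if not os.path.exists(token_dir):
    # Create the local config directory through PathManager.
    PathManager.mkdirs(token_dir)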
15 changes: 8 additions & 7 deletions parlai/core/agents.py
@@ -40,14 +40,15 @@
``MultiTaskTeacher``.
"""

+import copy

from parlai.core.build_data import modelzoo_path
from parlai.core.loader import load_agent_module
from parlai.core.loader import register_agent # noqa: F401
from parlai.core.opt import Opt
from parlai.utils.misc import warn_once
-import copy
-import os
import parlai.utils.logging as logging
+from parlai.utils.io import PathManager


NOCOPY_ARGS = [
@@ -205,7 +206,7 @@ def compare_init_model_opts(opt: Opt, curr_opt: Opt):
return
opt['init_model'] = modelzoo_path(opt['datapath'], opt['init_model'])
optfile = opt['init_model'] + '.opt'
-if not os.path.isfile(optfile):
+if not PathManager.exists(optfile):
return
init_model_opt = Opt.load(optfile)

@@ -294,7 +295,7 @@ def create_agent_from_opt_file(opt: Opt):
model_file = opt['model_file']
optfile = model_file + '.opt'

-if not os.path.isfile(optfile):
+if not PathManager.exists(optfile):
return None

opt_from_file = Opt.load(optfile)
@@ -328,12 +329,12 @@ def create_agent_from_opt_file(opt: Opt):
# update dict file path
if not opt_from_file.get('dict_file'):
opt_from_file['dict_file'] = model_file + '.dict'
-elif opt_from_file.get('dict_file') and not os.path.isfile(
+elif opt_from_file.get('dict_file') and not PathManager.exists(
opt_from_file['dict_file']
):
old_dict_file = opt_from_file['dict_file']
opt_from_file['dict_file'] = model_file + '.dict'
-if not os.path.isfile(opt_from_file['dict_file']):
+if not PathManager.exists(opt_from_file['dict_file']):
warn_once(
'WARNING: Neither the specified dict file ({}) nor the '
'`model_file`.dict file ({}) exists, check to make sure either '
@@ -384,7 +385,7 @@ def create_agent(opt: Opt, requireModelExists=False):

if opt.get('model_file'):
opt['model_file'] = modelzoo_path(opt.get('datapath'), opt['model_file'])
-if requireModelExists and not os.path.isfile(opt['model_file']):
+if requireModelExists and not PathManager.exists(opt['model_file']):
raise RuntimeError(
'WARNING: Model file does not exist, check to make '
'sure it is correct: {}'.format(opt['model_file'])