Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[Agents] Create agent from model file #2826

Merged
merged 4 commits into from
Jul 10, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 79 additions & 63 deletions parlai/core/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@
import parlai.utils.logging as logging


# Option keys that are deleted from an opt loaded from a model's `.opt` file,
# so that the saved values are never carried over into the loaded model.
NOCOPY_ARGS = [
    # never use the datapath from an opt dump
    'datapath',
    # this saved variable can cause trouble if we switch to BS=1 at test time
    'batchindex',
]


class Agent(object):
"""
Base class for all other agents.
Expand Down Expand Up @@ -256,7 +262,7 @@ def compare_init_model_opts(opt: Opt, curr_opt: Opt):
)


def create_agent_from_model_file(model_file, opt_overides=None):
def create_agent_from_model_file(model_file, opt_overrides=None):
emilydinan marked this conversation as resolved.
Show resolved Hide resolved
"""
Load agent from model file if it exists.

Expand All @@ -266,10 +272,11 @@ def create_agent_from_model_file(model_file, opt_overides=None):
The agent
"""
opt = {}
opt['model_file'] = model_file
if opt_overides is None:
opt_overides = {}
opt['override'] = opt_overides
add_datapath_and_model_args(opt)
klshuster marked this conversation as resolved.
Show resolved Hide resolved
opt['model_file'] = modelzoo_path(opt.get('datapath'), model_file)
if opt_overrides is None:
opt_overrides = {}
opt['override'] = opt_overrides
return create_agent_from_opt_file(opt)


Expand All @@ -286,54 +293,75 @@ def create_agent_from_opt_file(opt: Opt):
"""
model_file = opt['model_file']
optfile = model_file + '.opt'
if os.path.isfile(optfile):
klshuster marked this conversation as resolved.
Show resolved Hide resolved
new_opt = Opt.load(optfile)
# TODO we need a better way to say these options are never copied...
if 'datapath' in new_opt:
# never use the datapath from an opt dump
del new_opt['datapath']
if 'batchindex' in new_opt:
# This saved variable can cause trouble if we switch to BS=1 at test time
del new_opt['batchindex']
# only override opts specified in 'override' dict
if opt.get('override'):
for k, v in opt['override'].items():
if k in new_opt and str(v) != str(new_opt.get(k)):
logging.warn(
f"overriding opt['{k}'] to {v} (previously: {new_opt.get(k)})"
)
new_opt[k] = v

model_class = load_agent_module(new_opt['model'])

if hasattr(model_class, 'upgrade_opt'):
new_opt = model_class.upgrade_opt(new_opt)

# add model arguments to new_opt if they aren't in new_opt already
for k, v in opt.items():
if k not in new_opt:
new_opt[k] = v
new_opt['model_file'] = model_file
if not new_opt.get('dict_file'):
new_opt['dict_file'] = model_file + '.dict'
elif new_opt.get('dict_file') and not os.path.isfile(new_opt['dict_file']):
old_dict_file = new_opt['dict_file']
new_opt['dict_file'] = model_file + '.dict'
if not os.path.isfile(new_opt['dict_file']):
warn_once(
'WARNING: Neither the specified dict file ({}) nor the '
'`model_file`.dict file ({}) exists, check to make sure either '
'is correct. This may manifest as a shape mismatch later '
'on.'.format(old_dict_file, new_opt['dict_file'])
)

# if we want to load weights from --init-model, compare opts with
# loaded ones
compare_init_model_opts(opt, new_opt)
return model_class(new_opt)
else:
if not os.path.isfile(optfile):
return None

opt_from_file = Opt.load(optfile)

# delete args that we do not want to copy over when loading the model
for arg in NOCOPY_ARGS:
if arg in opt_from_file:
del opt_from_file[arg]

# only override opts specified in 'override' dict
if opt.get('override'):
for k, v in opt['override'].items():
if k in opt_from_file and str(v) != str(opt_from_file.get(k)):
logging.warn(
f'Overriding opt["{k}"] to {v} (previously: {opt_from_file.get(k)})'
)
opt_from_file[k] = v

model_class = load_agent_module(opt_from_file['model'])

if hasattr(model_class, 'upgrade_opt'):
opt_from_file = model_class.upgrade_opt(opt_from_file)

# add model arguments to opt_from_file if they aren't in opt_from_file already
for k, v in opt.items():
if k not in opt_from_file:
opt_from_file[k] = v

opt_from_file['model_file'] = model_file # update model file path

# update dict file path
if not opt_from_file.get('dict_file'):
opt_from_file['dict_file'] = model_file + '.dict'
elif opt_from_file.get('dict_file') and not os.path.isfile(
opt_from_file['dict_file']
):
old_dict_file = opt_from_file['dict_file']
opt_from_file['dict_file'] = model_file + '.dict'
if not os.path.isfile(opt_from_file['dict_file']):
warn_once(
'WARNING: Neither the specified dict file ({}) nor the '
'`model_file`.dict file ({}) exists, check to make sure either '
'is correct. This may manifest as a shape mismatch later '
'on.'.format(old_dict_file, opt_from_file['dict_file'])
)

# if we want to load weights from --init-model, compare opts with
# loaded ones
compare_init_model_opts(opt, opt_from_file)
return model_class(opt_from_file)


def add_datapath_and_model_args(opt: Opt):
    """
    Fill in the datapath and model arguments that ``opt`` does not yet contain.

    A bare ``ParlaiParser`` is built with only the data path argument and, when
    ``get_model_name`` returns a model for ``opt``, that model's sub-arguments.
    The parser is run on an empty argument string and every resulting key that
    is not already in ``opt`` is copied over. Keys already present in ``opt``
    are left untouched.

    :param opt:
        options to complete; modified in place. Nothing is returned.
    """
    # imported inside the function, as in the code this helper was lifted from
    from parlai.core.params import ParlaiParser, get_model_name

    default_parser = ParlaiParser(add_parlai_args=False)
    default_parser.add_parlai_data_path()

    # model-specific arguments can only be registered when a model is known
    model_name = get_model_name(opt)
    if model_name is not None:
        default_parser.add_model_subargs(model_name)

    # parse an empty argument string (presumably yielding the parser defaults)
    parsed_defaults = default_parser.parse_args("", print_args=False)
    for key, value in parsed_defaults.items():
        if key in opt:
            # values the caller already supplied take precedence
            continue
        opt[key] = value


def create_agent(opt: Opt, requireModelExists=False):
"""
Expand All @@ -352,19 +380,7 @@ def create_agent(opt: Opt, requireModelExists=False):
containing the model's options).
"""
if opt.get('datapath', None) is None:
# add datapath, it is missing
from parlai.core.params import ParlaiParser, get_model_name

parser = ParlaiParser(add_parlai_args=False)
parser.add_parlai_data_path()
# add model args if they are missing
model = get_model_name(opt)
if model is not None:
parser.add_model_subargs(model)
opt_parser = parser.parse_args("", print_args=False)
for k, v in opt_parser.items():
if k not in opt:
opt[k] = v
add_datapath_and_model_args(opt)

if opt.get('model_file'):
opt['model_file'] = modelzoo_path(opt.get('datapath'), opt['model_file'])
Expand Down