Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing bug with restoring checkpoint with two gpus + cleaning CUDA parsing related code #928

Merged
merged 11 commits into from
Oct 16, 2019
2 changes: 1 addition & 1 deletion jiant/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,8 @@ def main(cl_arguments):
# Load tasks
log.info("Loading tasks...")
start_time = time.time()
pretrain_tasks, target_tasks, vocab, word_embs = build_tasks(args)
cuda_device = parse_cuda_list_arg(args.cuda)
pretrain_tasks, target_tasks, vocab, word_embs = build_tasks(args)
tasks = sorted(set(pretrain_tasks + target_tasks), key=lambda x: x.name)
log.info("\tFinished loading tasks in %.3fs", time.time() - start_time)
log.info("\t Tasks: {}".format([task.name for task in tasks]))
Expand Down
2 changes: 1 addition & 1 deletion jiant/config/defaults.conf
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

// Misc. Logistics //

cuda = -1 // GPU ID. Set to -1 for CPU, "auto" for all available GPUs on machine and
cuda = auto // GPU ID. Set to -1 for CPU, "auto" for all available GPUs on the machine, or
            // a comma-delimited list of GPU IDs for a subset of GPUs.
random_seed = 1234 // Global random seed, used in both Python and PyTorch random number generators.
track_batch_utilization = 0 // Track % of each batch that is padding tokens (for tasks with field
Expand Down
2 changes: 1 addition & 1 deletion jiant/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,7 +1204,7 @@ def _restore_checkpoint(self, phase, tasks=None):
self._serialization_dir, task_directory, "_".join(["metric", suffix])
)

model_state = torch.load(model_path, map_location=device_mapping(self._cuda_device))
model_state = torch.load(model_path)

for name, param in self._model.named_parameters():
if param.requires_grad and name not in model_state:
Expand Down
24 changes: 14 additions & 10 deletions jiant/utils/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Functions for parsing configs.
"""
import torch
import logging as log


def parse_task_list_arg(task_list):
Expand All @@ -19,22 +20,25 @@ def parse_task_list_arg(task_list):
return task_names


def parse_cuda_list_arg(cuda_arg):
    """Parse the ``cuda`` config setting into a device specification.

    Args:
        cuda_arg: one of
            - ``"auto"``: use every CUDA device visible to torch;
            - an int: a single GPU ID, or ``-1`` for CPU;
            - a comma-delimited string of GPU IDs (e.g. ``"0,2"``).

    Returns:
        ``-1`` (CPU), a single GPU ID as an int, or a list of GPU IDs.
        A would-be one-element list is collapsed to a plain int so downstream
        code can treat "single device" uniformly.

    Raises:
        ValueError: if ``cuda_arg`` matches none of the accepted forms, or if
            GPUs were requested but no CUDA device is available.
    """
    if cuda_arg == "auto":
        result_cuda = list(range(torch.cuda.device_count()))
        if len(result_cuda) == 1:
            result_cuda = result_cuda[0]
        elif not result_cuda:
            # No GPUs visible: silently fall back to CPU.
            result_cuda = -1
    elif isinstance(cuda_arg, int):
        # Bare numeric settings (e.g. ``cuda = 0`` or ``cuda = -1``) arrive as ints.
        result_cuda = cuda_arg
    elif "," in cuda_arg:
        # Skip empty fields so a stray trailing comma ("0,1,") does not crash int().
        result_cuda = [int(d) for d in cuda_arg.split(",") if d.strip()]
        # Collapse a one-element list here, inside the branch: after the chain
        # result_cuda may be a plain int, on which len() would raise TypeError.
        if len(result_cuda) == 1:
            result_cuda = result_cuda[0]
    else:
        raise ValueError(
            "Your cuda settings do not match any of the possibilities in defaults.conf"
        )
    if torch.cuda.device_count() == 0 and not (isinstance(result_cuda, int) and result_cuda == -1):
        raise ValueError("You specified usage of CUDA but CUDA devices not found.")
    return result_cuda
5 changes: 1 addition & 4 deletions jiant/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,10 +418,7 @@ def format_output(obj, cuda_devices):


def uses_cuda(cuda_devices):
    """Return whether ``cuda_devices`` designates at least one GPU.

    Args:
        cuda_devices: device spec as produced by ``parse_cuda_list_arg``:
            ``-1`` for CPU, a non-negative int for a single GPU, or a list
            of GPU IDs for multi-GPU.

    Returns:
        bool: True if one or more GPUs are selected, False for CPU (``-1``).
    """
    if isinstance(cuda_devices, list):
        return True
    return isinstance(cuda_devices, int) and cuda_devices >= 0


def get_batch_size(batch, cuda_devices, keyword="input"):
Expand Down