diff --git a/jiant/__main__.py b/jiant/__main__.py
index df52645ab..67f948bb5 100644
--- a/jiant/__main__.py
+++ b/jiant/__main__.py
@@ -510,8 +510,8 @@ def main(cl_arguments):
     # Load tasks
     log.info("Loading tasks...")
     start_time = time.time()
-    pretrain_tasks, target_tasks, vocab, word_embs = build_tasks(args)
     cuda_device = parse_cuda_list_arg(args.cuda)
+    pretrain_tasks, target_tasks, vocab, word_embs = build_tasks(args)
     tasks = sorted(set(pretrain_tasks + target_tasks), key=lambda x: x.name)
     log.info("\tFinished loading tasks in %.3fs", time.time() - start_time)
     log.info("\t Tasks: {}".format([task.name for task in tasks]))
diff --git a/jiant/config/defaults.conf b/jiant/config/defaults.conf
index a7f49a9ab..e57cd48a7 100644
--- a/jiant/config/defaults.conf
+++ b/jiant/config/defaults.conf
@@ -22,7 +22,7 @@
 
 // Misc. Logistics //
 
-cuda = -1 // GPU ID. Set to -1 for CPU, "auto" for all available GPUs on machine and
+cuda = auto // GPU ID. Set to -1 for CPU, "auto" for all available GPUs on machine and
 // a comma-delimited list of GPU IDs for a subset of GPUs.
 random_seed = 1234 // Global random seed, used in both Python and PyTorch random number generators.
 track_batch_utilization = 0 // Track % of each batch that is padding tokens (for tasks with field
diff --git a/jiant/trainer.py b/jiant/trainer.py
index ebcca08f5..52d649e3e 100644
--- a/jiant/trainer.py
+++ b/jiant/trainer.py
@@ -1204,7 +1204,7 @@ def _restore_checkpoint(self, phase, tasks=None):
             self._serialization_dir, task_directory, "_".join(["metric", suffix])
         )
 
-        model_state = torch.load(model_path, map_location=device_mapping(self._cuda_device))
+        model_state = torch.load(model_path)
 
         for name, param in self._model.named_parameters():
             if param.requires_grad and name not in model_state:
diff --git a/jiant/utils/options.py b/jiant/utils/options.py
index 56b9677e3..73894dc56 100644
--- a/jiant/utils/options.py
+++ b/jiant/utils/options.py
@@ -19,22 +19,25 @@ def parse_task_list_arg(task_list):
     return task_names
 
 
-def parse_cuda_list_arg(cuda_list):
+def parse_cuda_list_arg(cuda_arg):
     """
     Parse cuda_list settings
     """
     result_cuda = []
-    if cuda_list == "auto":
-        cuda_list = list(range(torch.cuda.device_count()))
-        return cuda_list
-    elif isinstance(cuda_list, int):
-        return cuda_list
-    elif "," in cuda_list:
-        return [int(d) for d in cuda_list.split(",")]
+    if cuda_arg == "auto":
+        result_cuda = list(range(torch.cuda.device_count()))
+        if len(result_cuda) == 1:
+            result_cuda = result_cuda[0]
+        elif len(result_cuda) == 0:
+            result_cuda = -1
+    elif isinstance(cuda_arg, int):
+        result_cuda = cuda_arg
+    elif "," in cuda_arg:
+        result_cuda = [int(d) for d in cuda_arg.split(",")]
     else:
         raise ValueError(
             "Your cuda settings do not match any of the possibilities in defaults.conf"
         )
-    if len(result_cuda) == 1:
-        result_cuda = result_cuda[0]
+    if torch.cuda.device_count() == 0 and not (isinstance(result_cuda, int) and result_cuda == -1):
+        raise ValueError("You specified usage of CUDA but CUDA devices not found.")
     return result_cuda
diff --git a/jiant/utils/utils.py b/jiant/utils/utils.py
index b2f315144..82cd208da 100644
--- a/jiant/utils/utils.py
+++ b/jiant/utils/utils.py
@@ -418,10 +418,7 @@ def format_output(obj, cuda_devices):
 
 
 def uses_cuda(cuda_devices):
-    use_cuda = 1
-    if isinstance(cuda_devices, list) or isinstance(cuda_devices, int) and cuda_devices >= 0:
-        return use_cuda
-    return 0
+    return isinstance(cuda_devices, list) or (isinstance(cuda_devices, int) and cuda_devices >= 0)
 
 
 def get_batch_size(batch, cuda_devices, keyword="input"):