diff --git a/checkpoint.py b/checkpoint.py
index 455ef69..edaa16b 100644
--- a/checkpoint.py
+++ b/checkpoint.py
@@ -34,4 +34,8 @@ def load_checkpoint(name, key_name='state_dict'):
         Selected element from loaded checkpoint pickle file
     """
     checkpoint = torch.load(name)
+
+    if key_name not in checkpoint:
+        return checkpoint
+
     return checkpoint[key_name]
diff --git a/datasets/loading_function.py b/datasets/loading_function.py
index e24e7ae..3d0a3d1 100644
--- a/datasets/loading_function.py
+++ b/datasets/loading_function.py
@@ -45,6 +45,7 @@ def data_loader(**kwargs):
     """
     load_type = kwargs['load_type']
+    num_nodes = max(kwargs['num_gpus'], 1)
 
     if load_type == 'train_val':
         kwargs['load_type'] = 'train'
         train_data = create_dataset_object(**kwargs)
@@ -52,20 +53,20 @@ def data_loader(**kwargs):
         val_data = create_dataset_object(**kwargs)
         kwargs['load_type'] = load_type
 
-        trainloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=kwargs['batch_size'], shuffle=True, num_workers=kwargs['num_workers'])
-        valloader = torch.utils.data.DataLoader(dataset=val_data, batch_size=kwargs['batch_size'], shuffle=False, num_workers=kwargs['num_workers'])
+        trainloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=kwargs['batch_size']*num_nodes, shuffle=True, num_workers=kwargs['num_workers'])
+        valloader = torch.utils.data.DataLoader(dataset=val_data, batch_size=kwargs['batch_size']*num_nodes, shuffle=False, num_workers=kwargs['num_workers'])
 
         ret_dict = dict(train=trainloader, valid=valloader)
 
     elif load_type == 'train':
         data = create_dataset_object(**kwargs)
-        loader = torch.utils.data.DataLoader(dataset=data, batch_size=kwargs['batch_size'], shuffle=True, num_workers=kwargs['num_workers'])
+        loader = torch.utils.data.DataLoader(dataset=data, batch_size=kwargs['batch_size']*num_nodes, shuffle=True, num_workers=kwargs['num_workers'])
 
         ret_dict = dict(train=loader)
 
     else:
         data = create_dataset_object(**kwargs)
-        loader = torch.utils.data.DataLoader(dataset=data, batch_size=kwargs['batch_size'], shuffle=False, num_workers=kwargs['num_workers'])
+        loader = torch.utils.data.DataLoader(dataset=data, batch_size=kwargs['batch_size']*num_nodes, shuffle=False, num_workers=kwargs['num_workers'])
 
         ret_dict = dict(test=loader)
diff --git a/eval.py b/eval.py
index 43148c7..6b6bc25 100644
--- a/eval.py
+++ b/eval.py
@@ -56,13 +56,23 @@
     writer = SummaryWriter(log_dir)
 
     # Check if GPU is available (CUDA)
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    num_gpus = args['num_gpus']
+    device = torch.device("cuda:0" if num_gpus > 0 and torch.cuda.is_available() else "cpu")
+    print('Using {}'.format(device.type))
 
     # Load Network
     model = create_model_object(**args).to(device)
+    model_obj = model
+
+    if device.type == 'cuda' and num_gpus > 1:
+        device_ids = list(range(num_gpus)) #IDs of the GPUs to use
+        model = nn.DataParallel(model, device_ids=device_ids)
+        model_obj = model.module #underlying model must be accessed through DataParallel's .module
+
+        print('GPU device IDs: {}'.format(device_ids))
 
     # Load Data
-    loader = data_loader(**args, model_obj=model)
+    loader = data_loader(**args, model_obj=model_obj)
 
     if args['load_type'] == 'train_val':
         eval_loader = loader['valid']
@@ -80,7 +90,13 @@
 
     if isinstance(args['pretrained'], str):
         ckpt = load_checkpoint(args['pretrained'])
-        model.load_state_dict(ckpt)
+
+        ckpt_keys = list(ckpt.keys())
+        if ckpt_keys[0].startswith('module.'): #checkpoint weights were saved from a DataParallel object
+            for key in ckpt_keys:
+                ckpt[key[7:]] = ckpt.pop(key) #strip the 7-character 'module.' prefix
+
+        model_obj.load_state_dict(ckpt)
 
     # Training Setup
     params = [p for p in model.parameters() if p.requires_grad]
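Aside on the checkpoint handling in eval.py above: nn.DataParallel registers the wrapped network under a `module` attribute, so every key in its state_dict carries a 'module.' prefix (e.g. 'module.fc.weight' instead of 'fc.weight'). Below is a minimal, CPU-only sketch of the stripping loop; TinyNet and the hand-built checkpoint are invented for illustration and are not part of this patch.

    import torch
    import torch.nn as nn
    from collections import OrderedDict

    class TinyNet(nn.Module):  # toy stand-in for create_model_object(**args)
        def __init__(self):
            super(TinyNet, self).__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    # Fabricate a checkpoint with the prefixed keys a DataParallel model saves.
    ckpt = OrderedDict(('module.' + k, v) for k, v in TinyNet().state_dict().items())

    # Same loop as in the patch: rename the keys in place.
    ckpt_keys = list(ckpt.keys())
    if ckpt_keys[0].startswith('module.'):
        for key in ckpt_keys:
            ckpt[key[7:]] = ckpt.pop(key)  # len('module.') == 7

    TinyNet().load_state_dict(ckpt)  # now loads cleanly into the unwrapped model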
diff --git a/parse_args.py b/parse_args.py
index 3e7c8fe..3330ac5 100644
--- a/parse_args.py
+++ b/parse_args.py
@@ -1,5 +1,6 @@
 import argparse
 import yaml
+import torch
 
 
 class Parse():
@@ -15,8 +16,9 @@ def __init__(self):
         #Command-line arguments will override any config file arguments
         parser.add_argument('--rerun', type=int, help='Number of trials to repeat an experiment')
         parser.add_argument('--dataset', type=str, help='Name of dataset')
-        parser.add_argument('--batch_size', type=int, help='Numbers of videos in a mini-batch')
+        parser.add_argument('--batch_size', type=int, help='Number of videos in a mini-batch (per GPU)')
         parser.add_argument('--pseudo_batch_loop', type=int, help='Number of loops for mini-batch')
+        parser.add_argument('--num_gpus', type=int, help='Number of GPUs to use. -1 (default): use all available GPUs, 0: use CPU only, N: use N GPUs')
        parser.add_argument('--num_workers', type=int, help='Number of subprocesses for dataloading')
         parser.add_argument('--load_type', type=str, help='Environment selection, to include only training/training and validation/testing dataset (train, train_val, test)')
         parser.add_argument('--model', type=str, help='Name of model to be loaded')
@@ -56,32 +58,33 @@
 
         # Default dict, anything not present is required to exist as an argument or in yaml file
         self.defaults = dict(
-            rerun            = 5,
-            batch_size       = 1,
-            pseudo_batch_loop= 1,
-            num_workers      = 1,
-            acc_metric       = None,
-            opt              = 'sgd',
-            lr               = 0.001,
-            momentum         = 0.9,
-            weight_decay     = 0.0005,
-            milestones       = [5],
-            gamma            = 0.1,
-            epoch            = 10,
-            save_dir         = './results',
-            exp              = 'exp',
-            preprocess       = 'default',
-            pretrained       = 0,
-            subtract_mean    = '',
-            clip_offset      = 0,
-            random_offset    = 0,
-            clip_stride      = 0,
-            crop_type        = None,
-            num_clips        = 1,
-            debug            = 0,
-            seed             = 0,
-            scale            = [1,1],
-            resume           = 0)
+            rerun             = 5,
+            batch_size        = 1,
+            pseudo_batch_loop = 1,
+            num_gpus          = -1,
+            num_workers       = 1,
+            acc_metric        = None,
+            opt               = 'sgd',
+            lr                = 0.001,
+            momentum          = 0.9,
+            weight_decay      = 0.0005,
+            milestones        = [5],
+            gamma             = 0.1,
+            epoch             = 10,
+            save_dir          = './results',
+            exp               = 'exp',
+            preprocess        = 'default',
+            pretrained        = 0,
+            subtract_mean     = '',
+            clip_offset       = 0,
+            random_offset     = 0,
+            clip_stride       = 0,
+            crop_type         = None,
+            num_clips         = 1,
+            debug             = 0,
+            seed              = 0,
+            scale             = [1,1],
+            resume            = 0)
@@ -120,6 +123,9 @@ def get_args(self):
 
         if self.cfg_args['clip_stride'] < 1:
             self.cfg_args['clip_stride'] = 1
-
+        #Use all available GPUs if num_gpus == -1,
+        #else select the minimum between available and requested GPUs
+        num_gpus = torch.cuda.device_count() if self.cfg_args['num_gpus'] == -1 else min(torch.cuda.device_count(), self.cfg_args['num_gpus'])
+        self.cfg_args['num_gpus'] = num_gpus
 
         return self.cfg_args
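Aside on the parse_args.py change above: the num_gpus resolution reduces to a single expression. The helper below spells out the three cases without needing a CUDA machine; resolve_num_gpus is a hypothetical name used only for this sketch, and in the patch the available count comes from torch.cuda.device_count().

    def resolve_num_gpus(requested, available):
        # -1 means "use everything"; otherwise never claim more GPUs than exist
        return available if requested == -1 else min(available, requested)

    assert resolve_num_gpus(-1, 4) == 4  # default: all available GPUs
    assert resolve_num_gpus(8, 4) == 4   # over-request is capped at what exists
    assert resolve_num_gpus(2, 4) == 2   # explicit request honored
    assert resolve_num_gpus(0, 4) == 0   # 0 forces the CPU path in train.py/eval.py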
diff --git a/train.py b/train.py
index d889d84..8c5dcc5 100644
--- a/train.py
+++ b/train.py
@@ -68,13 +68,23 @@ def train(**args):
     writer = SummaryWriter(log_dir)
 
     # Check if GPU is available (CUDA)
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
+    num_gpus = args['num_gpus']
+    device = torch.device("cuda:0" if num_gpus > 0 and torch.cuda.is_available() else "cpu")
+    print('Using {}'.format(device.type))
+
     # Load Network
     model = create_model_object(**args).to(device)
+    model_obj = model
+
+    if device.type == 'cuda' and num_gpus > 1:
+        device_ids = list(range(num_gpus)) #IDs of the GPUs to use
+        model = nn.DataParallel(model, device_ids=device_ids)
+        model_obj = model.module #underlying model must be accessed through DataParallel's .module
+        print('GPU device IDs: {}'.format(device_ids))
+
     # Load Data
-    loader = data_loader(model_obj=model, **args)
+    loader = data_loader(model_obj=model_obj, **args)
 
     if args['load_type'] == 'train':
         train_loader = loader['train']
@@ -107,7 +117,13 @@
 
     if isinstance(args['pretrained'], str):
         ckpt = load_checkpoint(args['pretrained'])
-        model.load_state_dict(ckpt)
+
+        ckpt_keys = list(ckpt.keys())
+        if ckpt_keys[0].startswith('module.'): #checkpoint weights were saved from a DataParallel object
+            for key in ckpt_keys:
+                ckpt[key[7:]] = ckpt.pop(key) #strip the 7-character 'module.' prefix
+
+        model_obj.load_state_dict(ckpt)
 
     if args['resume']:
         start_epoch = load_checkpoint(args['pretrained'], key_name='epoch') + 1
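Taken together, these changes keep batch_size a per-GPU quantity: data_loader() multiplies it by the resolved GPU count, and nn.DataParallel then scatters each loader batch across the devices, so each GPU still processes batch_size samples per step. A CPU-runnable sketch of the arithmetic, with invented sizes:

    import torch

    per_gpu_batch = 4             # corresponds to kwargs['batch_size']
    num_gpus = 2                  # corresponds to kwargs['num_gpus'] after parse_args
    num_nodes = max(num_gpus, 1)  # same guard as data_loader(); 0 GPUs -> factor of 1

    data = torch.utils.data.TensorDataset(torch.randn(32, 3), torch.randn(32, 1))
    loader = torch.utils.data.DataLoader(dataset=data, batch_size=per_gpu_batch*num_nodes, shuffle=True)

    x, y = next(iter(loader))
    print(x.shape)                # torch.Size([8, 3]); DataParallel would hand 4 to each GPU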