Support for multi-gpu or cpu-only training #51

Open · wants to merge 6 commits into base: dev
4 changes: 4 additions & 0 deletions checkpoint.py
@@ -34,4 +34,8 @@ def load_checkpoint(name, key_name='state_dict'):
         Selected element from loaded checkpoint pickle file
     """
     checkpoint = torch.load(name)
+
+    if key_name not in checkpoint:
+        return checkpoint
+
     return checkpoint[key_name]
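
The guard above makes load_checkpoint backward compatible: a file saved as a bare state_dict (no wrapping dictionary) is returned whole, while a full training checkpoint still yields the requested entry. A minimal sketch of both cases, with hypothetical file names:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)

    # Full training checkpoint: weights stored under a key.
    torch.save({'state_dict': model.state_dict(), 'epoch': 10}, 'full.pth')
    # Bare state_dict saved directly (common for released weights).
    torch.save(model.state_dict(), 'bare.pth')

    for name in ('full.pth', 'bare.pth'):
        ckpt = torch.load(name)
        # Same fallback as the guard above: when the key is missing,
        # the loaded object is already the state_dict itself.
        state = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt
        model.load_state_dict(state)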
9 changes: 5 additions & 4 deletions datasets/loading_function.py
@@ -45,27 +45,28 @@ def data_loader(**kwargs):
     """
 
     load_type = kwargs['load_type']
+    num_nodes = max(kwargs['num_gpus'], 1)
     if load_type == 'train_val':
         kwargs['load_type'] = 'train'
         train_data = create_dataset_object(**kwargs)
         kwargs['load_type'] = 'val'
         val_data = create_dataset_object(**kwargs)
         kwargs['load_type'] = load_type
 
-        trainloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=kwargs['batch_size'], shuffle=True, num_workers=kwargs['num_workers'])
-        valloader = torch.utils.data.DataLoader(dataset=val_data, batch_size=kwargs['batch_size'], shuffle=False, num_workers=kwargs['num_workers'])
+        trainloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=kwargs['batch_size']*num_nodes, shuffle=True, num_workers=kwargs['num_workers'])
+        valloader = torch.utils.data.DataLoader(dataset=val_data, batch_size=kwargs['batch_size']*num_nodes, shuffle=False, num_workers=kwargs['num_workers'])
         ret_dict = dict(train=trainloader, valid=valloader)
 
     elif load_type == 'train':
         data = create_dataset_object(**kwargs)
 
-        loader = torch.utils.data.DataLoader(dataset=data, batch_size=kwargs['batch_size'], shuffle=True, num_workers=kwargs['num_workers'])
+        loader = torch.utils.data.DataLoader(dataset=data, batch_size=kwargs['batch_size']*num_nodes, shuffle=True, num_workers=kwargs['num_workers'])
         ret_dict = dict(train=loader)
 
     else:
         data = create_dataset_object(**kwargs)
 
-        loader = torch.utils.data.DataLoader(dataset=data, batch_size=kwargs['batch_size'], shuffle=False, num_workers=kwargs['num_workers'])
+        loader = torch.utils.data.DataLoader(dataset=data, batch_size=kwargs['batch_size']*num_nodes, shuffle=False, num_workers=kwargs['num_workers'])
         ret_dict = dict(test=loader)
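
Multiplying batch_size by num_nodes keeps the per-GPU batch constant: nn.DataParallel splits every batch the loader produces across the visible devices, so batch_size in the config stays a per-GPU quantity. A quick sketch of the arithmetic with hypothetical values:

    batch_size = 8                 # per-GPU batch size from the config
    num_gpus = 4                   # hypothetical device count
    num_nodes = max(num_gpus, 1)   # same clamp as data_loader: CPU run -> 1

    global_batch = batch_size * num_nodes      # what the DataLoader emits: 32
    per_gpu_batch = global_batch // num_nodes  # what each GPU sees after the split: 8

This way a config tuned on one GPU keeps its per-device memory footprint when more GPUs are added.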
22 changes: 19 additions & 3 deletions eval.py
@@ -56,13 +56,23 @@ def eval(**args):
     writer = SummaryWriter(log_dir)
 
     # Check if GPU is available (CUDA)
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    num_gpus = args['num_gpus']
+    device = torch.device("cuda:0" if num_gpus > 0 and torch.cuda.is_available() else "cpu")
     print('Using {}'.format(device.type))
 
     # Load Network
     model = create_model_object(**args).to(device)
+    model_obj = model
+
+    if device.type == 'cuda' and num_gpus > 1:
+        device_ids = list(range(num_gpus)) #number of GPUs specified
+        model = nn.DataParallel(model, device_ids=device_ids)
+        model_obj = model.module #Model from DataParallel object has to be accessed through module
+
+        print('GPUs Device IDs: {}'.format(device_ids))
 
     # Load Data
-    loader = data_loader(**args, model_obj=model)
+    loader = data_loader(**args, model_obj=model_obj)
 
     if args['load_type'] == 'train_val':
         eval_loader = loader['valid']

@@ -80,7 +90,13 @@
 
     if isinstance(args['pretrained'], str):
         ckpt = load_checkpoint(args['pretrained'])
-        model.load_state_dict(ckpt)
+
+        ckpt_keys = list(ckpt.keys())
+        if ckpt_keys[0].startswith('module.'): #if checkpoint weights are from DataParallel object
+            for key in ckpt_keys:
+                ckpt[key[7:]] = ckpt.pop(key)
+
+        model_obj.load_state_dict(ckpt)
 
     # Training Setup
     params = [p for p in model.parameters() if p.requires_grad]
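
Keeping model_obj alongside the wrapped model matters because nn.DataParallel hides the original network behind its .module attribute: anything that needs the model's own attributes (the dataloader, checkpoint loading) must use the unwrapped handle, while forward passes go through the wrapper. A self-contained sketch of the pattern, with a toy model standing in for create_model_object:

    import torch
    import torch.nn as nn

    num_gpus = torch.cuda.device_count()
    device = torch.device('cuda:0' if num_gpus > 0 else 'cpu')

    model = nn.Linear(16, 4).to(device)  # stand-in for create_model_object(**args)
    model_obj = model                    # handle to the raw model

    if device.type == 'cuda' and num_gpus > 1:
        model = nn.DataParallel(model, device_ids=list(range(num_gpus)))
        model_obj = model.module         # original attributes live here now

    out = model(torch.randn(8, 16).to(device))  # forward goes through the wrapper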
62 changes: 34 additions & 28 deletions parse_args.py
@@ -1,5 +1,6 @@
 import argparse
 import yaml
+import torch
 
 class Parse():
 

@@ -15,8 +16,9 @@ def __init__(self):
         #Command-line arguments will override any config file arguments
         parser.add_argument('--rerun', type=int, help='Number of trials to repeat an experiment')
         parser.add_argument('--dataset', type=str, help='Name of dataset')
-        parser.add_argument('--batch_size', type=int, help='Numbers of videos in a mini-batch')
+        parser.add_argument('--batch_size', type=int, help='Number of videos in a mini-batch (per GPU)')
         parser.add_argument('--pseudo_batch_loop', type=int, help='Number of loops for mini-batch')
+        parser.add_argument('--num_gpus', type=int, help='Number of GPUs to use: -1 (all available GPUs, default), 0 (CPU only), N >= 1 (use N GPUs)')
         parser.add_argument('--num_workers', type=int, help='Number of subprocesses for dataloading')
         parser.add_argument('--load_type', type=str, help='Environment selection, to include only training/training and validation/testing dataset (train, train_val, test)')
         parser.add_argument('--model', type=str, help='Name of model to be loaded')

@@ -56,32 +58,33 @@ def __init__(self):
 
         # Default dict, anything not present is required to exist as an argument or in yaml file
         self.defaults = dict(
-                rerun            = 5,
-                batch_size       = 1,
-                pseudo_batch_loop= 1,
-                num_workers      = 1,
-                acc_metric       = None,
-                opt              = 'sgd',
-                lr               = 0.001,
-                momentum         = 0.9,
-                weight_decay     = 0.0005,
-                milestones       = [5],
-                gamma            = 0.1,
-                epoch            = 10,
-                save_dir         = './results',
-                exp              = 'exp',
-                preprocess       = 'default',
-                pretrained       = 0,
-                subtract_mean    = '',
-                clip_offset      = 0,
-                random_offset    = 0,
-                clip_stride      = 0,
-                crop_type        = None,
-                num_clips        = 1,
-                debug            = 0,
-                seed             = 0,
-                scale            = [1,1],
-                resume           = 0)
+                rerun             = 5,
+                batch_size        = 1,
+                pseudo_batch_loop = 1,
+                num_gpus          = -1,
+                num_workers       = 1,
+                acc_metric        = None,
+                opt               = 'sgd',
+                lr                = 0.001,
+                momentum          = 0.9,
+                weight_decay      = 0.0005,
+                milestones        = [5],
+                gamma             = 0.1,
+                epoch             = 10,
+                save_dir          = './results',
+                exp               = 'exp',
+                preprocess        = 'default',
+                pretrained        = 0,
+                subtract_mean     = '',
+                clip_offset       = 0,
+                random_offset     = 0,
+                clip_stride       = 0,
+                crop_type         = None,
+                num_clips         = 1,
+                debug             = 0,
+                seed              = 0,
+                scale             = [1,1],
+                resume            = 0)
 

@@ -120,6 +123,9 @@ def get_args(self):
         if self.cfg_args['clip_stride'] < 1:
             self.cfg_args['clip_stride'] = 1
 
+        #Use all available GPUs if num_gpus = -1
+        #Else select the minimum between available GPUs and requested GPUs
+        num_gpus = torch.cuda.device_count() if self.cfg_args['num_gpus'] == -1 else min(torch.cuda.device_count(), self.cfg_args['num_gpus'])
+        self.cfg_args['num_gpus'] = num_gpus
+
         return self.cfg_args
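
The new get_args logic resolves the requested device count against what the machine actually has. Isolated as a small helper for illustration (the function name is mine, not part of the PR):

    import torch

    def resolve_num_gpus(requested):
        # -1 -> all visible GPUs; otherwise cap the request at what exists.
        # 0, or any request on a CPU-only machine, resolves to CPU mode.
        available = torch.cuda.device_count()
        return available if requested == -1 else min(available, requested)

    # On a 2-GPU machine: resolve_num_gpus(-1) -> 2, resolve_num_gpus(4) -> 2,
    # resolve_num_gpus(0) -> 0 (CPU only).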
24 changes: 20 additions & 4 deletions train.py
@@ -68,13 +68,23 @@ def train(**args):
     writer = SummaryWriter(log_dir)
 
     # Check if GPU is available (CUDA)
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
+    num_gpus = args['num_gpus']
+    device = torch.device("cuda:0" if num_gpus > 0 and torch.cuda.is_available() else "cpu")
     print('Using {}'.format(device.type))
 
     # Load Network
     model = create_model_object(**args).to(device)
+    model_obj = model
+
+    if device.type == 'cuda' and num_gpus > 1:
+        device_ids = list(range(num_gpus)) #number of GPUs specified
+        model = nn.DataParallel(model, device_ids=device_ids)
+        model_obj = model.module #Model from DataParallel object has to be accessed through module
+
+        print('GPUs Device IDs: {}'.format(device_ids))
 
     # Load Data
-    loader = data_loader(model_obj=model, **args)
+    loader = data_loader(model_obj=model_obj, **args)
 
     if args['load_type'] == 'train':
         train_loader = loader['train']

@@ -107,7 +117,13 @@
 
     if isinstance(args['pretrained'], str):
         ckpt = load_checkpoint(args['pretrained'])
-        model.load_state_dict(ckpt)
+
+        ckpt_keys = list(ckpt.keys())
+        if ckpt_keys[0].startswith('module.'): #if checkpoint weights are from DataParallel object
+            for key in ckpt_keys:
+                ckpt[key[7:]] = ckpt.pop(key)
+
+        model_obj.load_state_dict(ckpt)
 
     if args['resume']:
         start_epoch = load_checkpoint(args['pretrained'], key_name='epoch') + 1
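
The key-renaming loop added to both train.py and eval.py strips the 'module.' prefix that nn.DataParallel prepends to every parameter name at save time, so multi-GPU checkpoints still load into an unwrapped model. The same step as a standalone helper (the name and the dict-comprehension form are mine):

    def strip_module_prefix(state_dict):
        # Drop the 'module.' prefix nn.DataParallel adds to parameter
        # names, so the weights load into a plain (unwrapped) model.
        if not next(iter(state_dict)).startswith('module.'):
            return state_dict
        return {key[len('module.'):]: value for key, value in state_dict.items()}

    # Usage sketch:
    # ckpt = strip_module_prefix(load_checkpoint(args['pretrained']))
    # model_obj.load_state_dict(ckpt)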