Support Windows single-GPU training with DP (#4).
bennyguo committed Nov 2, 2022
1 parent 5e66596 commit dabae5e
Showing 5 changed files with 45 additions and 12 deletions.
12 changes: 10 additions & 2 deletions launch.py
@@ -1,3 +1,4 @@
+import sys
 import argparse
 import os
 import time
@@ -87,13 +88,20 @@ def main():
         TensorBoardLogger(args.runs_dir, name=config.name, version=config.trial_name),
         CSVLogger(config.exp_dir, name=config.trial_name, version='csv_logs')
     ]

+    if sys.platform == 'win32':
+        # DDP is not supported on Windows; fall back to DP (single-GPU only).
+        strategy = 'dp'
+        assert n_gpus == 1
+    else:
+        strategy = 'ddp_find_unused_parameters_false'
+
     trainer = Trainer(
         devices=n_gpus,
         accelerator='gpu',
         callbacks=callbacks,
         logger=loggers,
-        strategy='ddp_find_unused_parameters_false',
+        strategy=strategy,
         **config.trainer
     )
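For reference, the strategy selection above can be read as a standalone helper. This is a minimal sketch, not repository code: pick_strategy is a hypothetical name, and it assumes the 'dp' and 'ddp_find_unused_parameters_false' strategy strings accepted by the PyTorch Lightning Trainer used here.

import sys

def pick_strategy(n_gpus):
    # Hypothetical helper mirroring the logic added to launch.py.
    # torch.distributed-based DDP is unavailable on Windows in this setup,
    # so the commit falls back to DataParallel, covering only the 1-GPU case.
    if sys.platform == 'win32':
        assert n_gpus == 1, 'DP fallback on Windows is limited to a single GPU'
        return 'dp'
    return 'ddp_find_unused_parameters_false'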
6 changes: 5 additions & 1 deletion models/network_utils.py
@@ -8,6 +8,7 @@
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug, _get_rank

 from utils.misc import config_to_primitive
+from models.utils import get_activation

@@ -69,9 +70,12 @@ def __init__(self, dim_in, dim_out, config):
         self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()]
         self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)]
         self.layers = nn.Sequential(*self.layers)
+        self.output_activation = get_activation(config['output_activation'])

     def forward(self, x):
-        return self.layers(x.float())
+        x = self.layers(x.float())
+        x = self.output_activation(x)
+        return x

     def make_linear(self, dim_in, dim_out, is_first, is_last):
         layer = nn.Linear(dim_in, dim_out, bias=False)
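A minimal sketch of the new forward path, assuming get_activation('none') resolves to nn.Identity() as in models/utils.py; the layer sizes below are made up for illustration:

import torch
import torch.nn as nn

# Stand-in MLP; the real self.layers is built by make_linear/make_activation.
layers = nn.Sequential(nn.Linear(3, 16), nn.ReLU(inplace=True), nn.Linear(16, 1))
output_activation = nn.Identity()  # what get_activation would return for 'none'

x = torch.randn(8, 3)
y = output_activation(layers(x.float()))  # same flow as the rewritten forward()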
3 changes: 2 additions & 1 deletion models/utils.py
@@ -67,7 +67,11 @@ def backward(ctx, g): # pylint: disable=arguments-differ


 def get_activation(name):
-    if name is None or name in ['none', 'None']:
+    # Guard None before lowercasing; activation names are case-insensitive.
+    if name is None:
+        return nn.Identity()
+    name = name.lower()
+    if name == 'none':
         return nn.Identity()
     elif name.startswith('scale'):
         scale_factor = float(name[5:])
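A hedged usage sketch of the case-insensitive lookup, assuming the None guard above runs before lowercasing; only the branches visible in this diff are exercised:

import torch.nn as nn
from models.utils import get_activation

# With the None guard ordered first, all of these resolve to the identity.
assert isinstance(get_activation(None), nn.Identity)
assert isinstance(get_activation('none'), nn.Identity)
assert isinstance(get_activation('None'), nn.Identity)  # lowered to 'none'

# 'scale...' names parse their suffix as a float, e.g. float('2.0') here;
# the module returned by that branch is elided from this diff.
act = get_activation('scale2.0')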
18 changes: 14 additions & 4 deletions systems/nerf.py
@@ -152,8 +152,13 @@ def validation_epoch_end(self, out):
         if self.trainer.is_global_zero:
             out_set = {}
             for step_out in out:
-                for oi, index in enumerate(step_out['index']):
-                    out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
+                # DP: a single image index per step output
+                if step_out['index'].ndim == 1:
+                    out_set[step_out['index'].item()] = {'psnr': step_out['psnr']}
+                # DDP: a batch of indices per step output
+                else:
+                    for oi, index in enumerate(step_out['index']):
+                        out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
             psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))
             self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True)

@@ -180,8 +185,13 @@ def test_epoch_end(self, out):
         if self.trainer.is_global_zero:
             out_set = {}
             for step_out in out:
-                for oi, index in enumerate(step_out['index']):
-                    out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
+                # DP: a single image index per step output
+                if step_out['index'].ndim == 1:
+                    out_set[step_out['index'].item()] = {'psnr': step_out['psnr']}
+                # DDP: a batch of indices per step output
+                else:
+                    for oi, index in enumerate(step_out['index']):
+                        out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
             psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))
             self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True)

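The ndim branch distinguishes the two layouts the step outputs can take: under DP each step_out['index'] is a 1-D tensor holding a single image index, while under DDP it is 2-D with one row per image. systems/neus.py below applies the same pattern. A self-contained sketch, where collect_psnr is a hypothetical free function and the mock tensors are illustrative only:

import torch

def collect_psnr(outputs):
    # Standalone version of the epoch-end aggregation shown above.
    out_set = {}
    for step_out in outputs:
        if step_out['index'].ndim == 1:  # DP: one image per step output
            out_set[step_out['index'].item()] = {'psnr': step_out['psnr']}
        else:  # DDP: a batch of (index, psnr) rows per step output
            for oi, index in enumerate(step_out['index']):
                out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
    return torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))

# Mock outputs in both layouts:
dp_out = [{'index': torch.tensor([0]), 'psnr': torch.tensor(30.0)}]
ddp_out = [{'index': torch.tensor([[1], [2]]), 'psnr': torch.tensor([31.0, 29.0])}]
print(collect_psnr(dp_out), collect_psnr(ddp_out))  # tensor(30.) tensor(30.)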
18 changes: 14 additions & 4 deletions systems/neus.py
@@ -163,8 +163,13 @@ def validation_epoch_end(self, out):
         if self.trainer.is_global_zero:
             out_set = {}
             for step_out in out:
-                for oi, index in enumerate(step_out['index']):
-                    out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
+                # DP: a single image index per step output
+                if step_out['index'].ndim == 1:
+                    out_set[step_out['index'].item()] = {'psnr': step_out['psnr']}
+                # DDP: a batch of indices per step output
+                else:
+                    for oi, index in enumerate(step_out['index']):
+                        out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
             psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))
             self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True)

@@ -195,8 +200,13 @@ def test_epoch_end(self, out):
         if self.trainer.is_global_zero:
             out_set = {}
             for step_out in out:
-                for oi, index in enumerate(step_out['index']):
-                    out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
+                # DP: a single image index per step output
+                if step_out['index'].ndim == 1:
+                    out_set[step_out['index'].item()] = {'psnr': step_out['psnr']}
+                # DDP: a batch of indices per step output
+                else:
+                    for oi, index in enumerate(step_out['index']):
+                        out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
             psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))
             self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True)

