remove obsolete self._device in Trainer #1849

Merged: 3 commits, May 17, 2020
15 changes: 10 additions & 5 deletions pytorch_lightning/core/lightning.py
@@ -53,10 +53,6 @@ def __init__(self, *args, **kwargs):
         self.logger = None
         self.example_input_array = None
 
-        #: True if your model is currently running on GPUs.
-        #: Useful to set flags around the LightningModule for different CPU vs GPU behavior.
-        self.on_gpu = False
-
         #: True if using dp
         self.use_dp = False
 
@@ -72,10 +68,19 @@ def __init__(self, *args, **kwargs):
         self.hparams = None
 
         #: Current dtype
-        self._dtype = torch.FloatTensor
+        self._dtype = torch.float
 
+        #: device reference
+        self._device = torch.device('cpu')
+
+    @property
+    def on_gpu(self):
+        """
+        True if your model is currently running on GPUs.
+        Useful to set flags around the LightningModule for different CPU vs GPU behavior.
+        """
+        return self.device.type == 'cuda'
+
     def print(self, *args, **kwargs) -> None:
         r"""
         Prints only from process 0. Use this in any distributed mode to log only once.
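For context, a minimal sketch of the pattern this hunk introduces: `on_gpu` becomes a property derived from a tracked device reference rather than a boolean the Trainer has to keep in sync. This is not Lightning's actual code; `ToyModule` and its `_device` handling are hypothetical stand-ins.

```python
# Sketch only: a device-derived `on_gpu` flag, mirroring the idea in the hunk above.
import torch
import torch.nn as nn


class ToyModule(nn.Module):  # hypothetical stand-in for a LightningModule
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)
        # device reference, defaulting to CPU (illustrative attribute name)
        self._device = torch.device('cpu')

    @property
    def on_gpu(self):
        # derived from the tracked device instead of a manually synced boolean
        return self._device.type == 'cuda'


model = ToyModule()
assert not model.on_gpu              # a fresh module starts on CPU

if torch.cuda.is_available():
    model.cuda(0)
    model._device = torch.device('cuda', 0)
    assert model.on_gpu              # the flag now follows the device reference
```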
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/distrib_data_parallel.py
@@ -360,7 +360,6 @@ def ddp_train(self, process_idx, model):
         # copy model to each gpu
         if self.on_gpu:
             self.root_gpu = process_idx
-            self._device = torch.device('cuda', self.root_gpu)
             torch.cuda.set_device(self.root_gpu)
             model.cuda(self.root_gpu)
 
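An illustrative sketch of why the deleted assignment is redundant: once `model.cuda(...)` has run, the device can be read back from the model's own parameters, so the trainer does not need to record it separately. The `place_on_gpu` helper below is hypothetical, not part of this diff.

```python
# Sketch only: after moving the model, its placement is queryable from its parameters.
import torch
import torch.nn as nn


def place_on_gpu(model: nn.Module, gpu_idx: int) -> torch.device:
    # hypothetical helper mirroring the placement steps in ddp_train
    torch.cuda.set_device(gpu_idx)
    model.cuda(gpu_idx)
    return next(model.parameters()).device


if torch.cuda.is_available():
    device = place_on_gpu(nn.Linear(4, 2), 0)
    print(device)  # cuda:0 -- no trainer-side bookkeeping required
```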
6 changes: 0 additions & 6 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -422,7 +422,6 @@ def copy_trainer_model_properties(self, model):
 
         for m in [model, ref_model]:
             m.trainer = self
-            m.on_gpu = self.on_gpu
             m.use_dp = self.use_dp
             m.use_ddp2 = self.use_ddp2
             m.use_ddp = self.use_ddp
@@ -432,7 +431,6 @@ def copy_trainer_model_properties(self, model):
             m.use_tpu = self.use_tpu
             m.tpu_local_core_rank = self.tpu_local_core_rank
             m.tpu_global_core_rank = self.tpu_global_core_rank
-            m._device = self._device
 
     def transfer_batch_to_tpu(self, batch):
         return self.__transfer_data_to_device(batch, device='tpu')
@@ -488,7 +486,6 @@ def __transfer_data_to_device(self, batch, device, gpu_id=None):
 
     def single_gpu_train(self, model):
         model.cuda(self.root_gpu)
-        self._device = torch.device('cuda', self.root_gpu)
 
         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
@@ -505,7 +502,6 @@ def single_gpu_train(self, model):
     def tpu_train(self, tpu_core_idx, model):
         # put model on tpu
         model.to(xm.xla_device())
-        self._device = xm.xla_device()
 
         # get the appropriate tpu ranks
         self.tpu_local_core_rank = xm.get_local_ordinal()
@@ -545,7 +541,6 @@ def dp_train(self, model):
         self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)
 
         model.cuda(self.root_gpu)
-        self._device = torch.device('cuda', self.root_gpu)
 
         # hack forward to do autocast for the user
         model_autocast_original_forward = model.forward
@@ -585,7 +580,6 @@ def horovod_train(self, model):
             assert self.root_gpu == hvd.local_rank()
             torch.cuda.set_device(self.root_gpu)
             model.cuda(self.root_gpu)
-            self._device = torch.device('cuda', self.root_gpu)
 
         # avoid duplicating progress bar
         if hvd.rank() != 0 and self.progress_bar_callback is not None:
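The removals in this file all follow the same reasoning: a device flag copied onto the model goes stale the moment the model is moved, while a property derived from the weights' actual placement cannot. A small sketch with hypothetical classes `Mirrored` and `Derived` (not Lightning code):

```python
# Sketch only: copied device state vs. derived device state.
import torch
import torch.nn as nn


class Mirrored(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 2)
        self.on_gpu = False          # copied once by some external trainer


class Derived(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 2)

    @property
    def on_gpu(self):
        # read the placement of the actual weights
        return next(self.parameters()).device.type == 'cuda'


if torch.cuda.is_available():
    m, d = Mirrored().cuda(), Derived().cuda()
    print(m.on_gpu)  # False -- stale, nothing updated the copy
    print(d.on_gpu)  # True  -- derived from where the weights actually live
```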
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -473,7 +473,6 @@ def __init__(
         # distributed backend choice
         self.distributed_backend = distributed_backend
         self.set_distributed_mode(distributed_backend)
-        self._device = torch.device('cpu')
 
         # override dist backend when using tpus
         if self.on_tpu:
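End to end, user-facing behaviour is intended to be unchanged: the module reports its own placement after the Trainer moves it, and the Trainer no longer tracks a device of its own. An illustrative example, assuming a CUDA machine and the Lightning API of this release; `LitModel` is a hypothetical minimal module.

```python
# Illustrative usage sketch, not a test from this PR.
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': torch.nn.functional.cross_entropy(self(x), y)}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        ds = torch.utils.data.TensorDataset(torch.randn(64, 32),
                                            torch.randint(0, 2, (64,)))
        return torch.utils.data.DataLoader(ds, batch_size=16)


model = LitModel()
print(model.on_gpu)                      # False: the module starts on CPU
trainer = pl.Trainer(gpus=1, max_epochs=1)
trainer.fit(model)
print(model.on_gpu)                      # expected True once the Trainer has moved it to CUDA
```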