diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 87b0db295e11d..fc34085972e03 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -132,7 +132,7 @@ def __init__(self, self.gradient_clip_val = gradient_clip_val self.check_val_every_n_epoch = check_val_every_n_epoch self.track_grad_norm = track_grad_norm - self.on_gpu = gpus is not None and torch.cuda.is_available() + self.on_gpu = True if (gpus and torch.cuda.is_available()) else False self.process_position = process_position self.weights_summary = weights_summary self.max_nb_epochs = max_nb_epochs