Commit 3173ad3
Fixes #902
williamFalcon committed Feb 25, 2020
1 parent c12cb92 commit 3173ad3
Showing 2 changed files with 13 additions and 13 deletions.
pytorch_lightning/core/lightning.py: 11 changes (10 additions, 1 deletion)
@@ -16,6 +16,13 @@
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 
+try:
+    import torch_xla.core.xla_model as xm
+    XLA_AVAILABLE = True
+
+except ImportError:
+    XLA_AVAILABLE = False
+
 
 class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

@@ -798,7 +805,9 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
                 optimizer.zero_grad()
 
         """
-        if isinstance(optimizer, torch.optim.LBFGS):
+        if self.trainer.use_tpu and XLA_AVAILABLE:
+            xm.optimizer_step(optimizer)
+        elif isinstance(optimizer, torch.optim.LBFGS):
             optimizer.step(second_order_closure)
         else:
             optimizer.step()
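For context, after this commit the default LightningModule.optimizer_step handles the TPU, LBFGS, and standard cases in a single hook, so a user override needs the same branching to keep TPU support. Below is a minimal sketch of such an override, assuming the torch_xla import guard added above and the hook signature shown in the hunk header; MyModel and its contents are illustrative, not part of this commit.

    import torch
    from pytorch_lightning.core.lightning import LightningModule

    try:
        import torch_xla.core.xla_model as xm
        XLA_AVAILABLE = True
    except ImportError:
        XLA_AVAILABLE = False


    class MyModel(LightningModule):
        # Hypothetical user override that mirrors the new default behaviour.
        def optimizer_step(self, current_epoch, batch_idx, optimizer,
                           optimizer_idx, second_order_closure=None):
            if self.trainer.use_tpu and XLA_AVAILABLE:
                # xm.optimizer_step reduces gradients across TPU cores before stepping.
                xm.optimizer_step(optimizer)
            elif isinstance(optimizer, torch.optim.LBFGS):
                # LBFGS re-evaluates the loss through the closure.
                optimizer.step(second_order_closure)
            else:
                optimizer.step()
            # Clear gradients afterwards, as in the hook's docstring example.
            optimizer.zero_grad()
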
pytorch_lightning/trainer/training_loop.py: 15 changes (3 additions, 12 deletions)
@@ -168,15 +168,9 @@ def training_step(self, batch, batch_idx):
 except ImportError:
     APEX_AVAILABLE = False
 
-try:
-    import torch_xla.core.xla_model as xm
-
-    XLA_AVAILABLE = True
-except ImportError:
-    XLA_AVAILABLE = False
-
 try:
     import torch_xla.distributed.parallel_loader as xla_pl
+    import torch_xla.core.xla_model as xm
 
     XLA_AVAILABLE = True
@@ -600,11 +594,8 @@ def optimizer_closure():
                 # override function to modify this behavior
                 model = self.get_model()
                 with self.profiler.profile('optimizer_step'):
-                    if self.use_tpu:
-                        xm.optimizer_step(optimizer)
-                    else:
-                        model.optimizer_step(self.current_epoch, batch_idx,
-                                             optimizer, opt_idx, optimizer_closure)
+                    model.optimizer_step(self.current_epoch, batch_idx,
+                                         optimizer, opt_idx, optimizer_closure)
 
                 # calculate running loss for display
                 self.running_loss.append(self.batch_loss_value)
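With the TPU branch removed from the loop, the trainer now always delegates to model.optimizer_step and forwards optimizer_closure, which matters for optimizers such as LBFGS that re-evaluate the loss during a step. Below is a standalone sketch of that closure pattern in plain PyTorch; the toy model, data, and learning rate are illustrative assumptions, not Lightning internals.

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
    x, y = torch.randn(8, 4), torch.randn(8, 1)

    def optimizer_closure():
        # LBFGS may call this several times per step to re-evaluate the loss.
        optimizer.zero_grad()
        loss = F.mse_loss(model(x), y)
        loss.backward()
        return loss

    # The closure is required for LBFGS; first-order optimizers simply call it once.
    optimizer.step(optimizer_closure)
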
