bugfix/3185 transpose (#3252)
* change t() to transpose() as XLA devices do not support .t() on a 1-dim tensor

* detach tensor before copying

* Revert "detach tensor before copying"

This reverts commit 37cc7bb

* changed dims

* added test_result_obj_on_tpu

* detach before copying

* detach before copying

* detach before copying

* replace torch.cat with sum
lezwon authored Sep 1, 2020
1 parent 8354c4a commit 3910ad0
Showing 3 changed files with 38 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
@@ -229,7 +229,7 @@ def _stop_distributed_training(self, trainer, pl_module):

         if trainer.use_tpu:
             stop = torch.tensor(int(trainer.should_stop), device=pl_module.device, dtype=torch.int32)
-            stop = xm.mesh_reduce("stop_signal", stop, torch.cat)
+            stop = xm.mesh_reduce("stop_signal", stop, sum)
             torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
             trainer.should_stop = int(stop.item()) == trainer.world_size

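The reduction change works because xm.mesh_reduce gathers each core's value and applies the reduce function to the collected list; torch.cat raises on the zero-dimensional stop tensor, while the built-in sum does not. A minimal sketch of the reduction semantics, no TPU required (the eight-core world size and variable names are illustrative):

    import torch

    # Pretend each of 8 cores produced a 0-dim stop vote, as early_stopping.py does.
    per_core_stop = [torch.tensor(1, dtype=torch.int32) for _ in range(8)]
    stop = sum(per_core_stop)                     # what mesh_reduce("stop_signal", stop, sum) yields
    world_size = 8
    should_stop = int(stop.item()) == world_size  # True only when every core voted to stop
    print(should_stop)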
4 changes: 3 additions & 1 deletion pytorch_lightning/core/step_result.py
@@ -299,6 +299,8 @@ def __str__(self):
     def __copy__(self):
         newone = type(self)()
         for k, v in self.items():
+            if isinstance(v, torch.Tensor):
+                v = v.detach()
             newone[k] = copy(v)
         return newone

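The detach in __copy__ matters because a tensor produced inside a training step still carries its autograd history; detach() returns a tensor that shares data but has no grad_fn, so the copied Result does not keep the computation graph alive. A minimal sketch of the behaviour, with illustrative values:

    import torch
    from copy import copy

    v = torch.ones(3, requires_grad=True) * 2     # non-leaf tensor with a grad_fn
    copied = copy(v.detach())                      # mirrors `newone[k] = copy(v)` after the detach
    print(copied.requires_grad, copied.grad_fn)    # False None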
@@ -843,6 +845,6 @@ def write_dict(self, predictions_dict, filename='predictions.pt'):

 def weighted_mean(result, weights):
     weights = weights.to(result.device)
-    numerator = torch.dot(result.float(), weights.t().float())
+    numerator = torch.dot(result.float(), weights.transpose(-1, 0).float())
     result = numerator / weights.sum().float()
     return result
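For a 1-dim weights tensor, transpose(-1, 0) swaps the only dimension with itself and is therefore a no-op, but unlike .t() it is accepted on XLA devices (per the commit message). A small numeric check of weighted_mean with illustrative values:

    import torch

    result = torch.tensor([1.0, 2.0, 3.0])
    weights = torch.tensor([1.0, 1.0, 2.0])
    numerator = torch.dot(result.float(), weights.transpose(-1, 0).float())
    print(numerator / weights.sum().float())   # tensor(2.2500) == (1*1 + 2*1 + 3*2) / 4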
35 changes: 34 additions & 1 deletion tests/models/test_tpu.py
@@ -4,7 +4,7 @@
 from torch.utils.data import DataLoader
 
 import tests.base.develop_pipelines as tpipes
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 from tests.base.datasets import TrialMNIST
@@ -241,3 +241,36 @@ def test_exception_when_no_tpu_found(tmpdir):
 def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores):
     """Test if distributed_backend is set to `tpu` when tpu_cores is not None"""
     assert Trainer(tpu_cores=tpu_cores).distributed_backend == 'tpu'
+
+
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pl_multi_process_test
+def test_result_obj_on_tpu(tmpdir):
+    seed_everything(1234)
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    batches = 5
+    epochs = 2
+
+    model = EvalModelTemplate()
+    model.training_step = model.training_step_result_obj
+    model.training_step_end = None
+    model.training_epoch_end = None
+    model.validation_step = model.validation_step_result_obj
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+    model.test_step = model.test_step_result_obj
+    model.test_step_end = None
+    model.test_epoch_end = None
+
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        max_epochs=epochs,
+        early_stop_callback=True,
+        row_log_interval=2,
+        limit_train_batches=batches,
+        weights_summary=None,
+        tpu_cores=8
+    )
+
+    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
