Skip to content

Commit

Permalink
flake8 fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
peterAsapp committed Jul 30, 2020
1 parent 055bab4 commit 224ec3f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 24 deletions.
14 changes: 2 additions & 12 deletions tests/callbacks/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):

checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k)

trainer = Trainer(
default_root_dir=tmpdir,
checkpoint_callback=checkpoint,
overfit_batches=0.20,
max_epochs=2,
)
trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=checkpoint, overfit_batches=0.20, max_epochs=2)
trainer.fit(model)
assert checkpoint.dirpath == tmpdir / trainer.logger.name / 'version_0' / 'checkpoints'

Expand All @@ -40,12 +35,7 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):
model = EvalModelTemplate()
logger = TensorBoardLogger(str(tmpdir), version=logger_version)

trainer = Trainer(
default_root_dir=tmpdir,
overfit_batches=0.2,
max_epochs=2,
logger=logger,
)
trainer = Trainer(default_root_dir=tmpdir, overfit_batches=0.2, max_epochs=2, logger=logger)
trainer.fit(model)

ckpt_version = Path(trainer.checkpoint_callback.dirpath).parent.name
Expand Down
15 changes: 3 additions & 12 deletions tests/models/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,22 +221,14 @@ def test_dataloaders_passed_to_fit(tmpdir):

model = EvalModelTemplate()

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
distributed_backend='tpu',
tpu_cores=8,
)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, distributed_backend='tpu', tpu_cores=8)
result = trainer.fit(model, train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader())
assert result, "TPU doesn't work with dataloaders passed to fit()."


@pytest.mark.parametrize(
['tpu_cores', 'expected_tpu_id'],
[pytest.param(1, None),
pytest.param(8, None),
pytest.param([1], 1),
pytest.param([8], 8)],
[pytest.param(1, None), pytest.param(8, None), pytest.param([1], 1), pytest.param([8], 8)],
)
def test_tpu_id_to_be_as_expected(tpu_cores, expected_tpu_id):
"""Test if trainer.tpu_id is set as expected"""
Expand All @@ -247,8 +239,7 @@ def test_tpu_misconfiguration():
"""Test if trainer.tpu_id is set as expected"""
with pytest.raises(MisconfigurationException, match="`tpu_cores` can only be"):
Trainer(
tpu_cores=[1, 8],
distributed_backend='tpu',
tpu_cores=[1, 8], distributed_backend='tpu',
)


Expand Down

0 comments on commit 224ec3f

Please sign in to comment.