Commit 9a6131e

revert changes to the tests, enforce str contract in code
f4hy committed Jun 24, 2020
1 parent f2d9c97 · commit 9a6131e
Showing 8 changed files with 73 additions and 66 deletions.
pytorch_lightning/trainer/trainer.py (8 changes: 5 additions, 3 deletions)
@@ -371,10 +371,12 @@ def __init__(
                      ' val and test loop using a single batch')
 
         # set default save path if user didn't provide one
-        self.default_root_dir = default_root_dir
-
-        if self.default_root_dir is None:
+        if default_root_dir is None:
             self.default_root_dir = os.getcwd()
+        else:
+            # we have to do str() because the unit tests violate the type annotation and pass path objects
+            self.default_root_dir = str(default_root_dir)
+
 
         # training bookkeeping
         self.total_batch_idx = 0
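The constructor is now the single place where the str contract is enforced. A minimal sketch of the resulting behavior, assuming only the standard library (the helper name is illustrative, not part of the commit):

    import os
    from pathlib import Path

    def resolve_root_dir(default_root_dir=None) -> str:
        # Mirrors the constructor logic above: default to the CWD, and
        # flatten any path object (pathlib.Path, py.path.local) to a plain str.
        if default_root_dir is None:
            return os.getcwd()
        return str(default_root_dir)

    assert resolve_root_dir() == os.getcwd()
    assert isinstance(resolve_root_dir(Path('/tmp/run')), str)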
pytorch_lightning/trainer/training_io.py (5 changes: 3 additions, 2 deletions)
@@ -375,7 +375,7 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):
         did_restore = False
 
         # look for hpc weights
-        folderpath = self.weights_save_path
+        folderpath = str(self.weights_save_path)
         if gfile.exists(folderpath):
             files = gfile.listdir(folderpath)
             hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]
@@ -452,6 +452,7 @@ def restore_training_state(self, checkpoint):
     # ----------------------------------
     def hpc_save(self, folderpath: str, logger):
         # make sure the checkpoint folder exists
+        folderpath = str(folderpath)  # because the tests pass a path object
         if not gfile.exists(folderpath):
             gfile.makedirs(folderpath)
 
@@ -511,7 +512,7 @@ def hpc_load(self, folderpath, on_gpu):
         log.info(f'restored hpc model from: {filepath}')
 
     def max_ckpt_in_folder(self, path, name_key='ckpt_'):
-        files = gfile.listdir(path)
+        files = gfile.listdir(str(path))
         files = [x for x in files if name_key in x]
         if len(files) == 0:
             return 0
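All three call sites guard the same boundary: the gfile wrapper is assumed to expect plain str paths, while callers (notably the tests) may hand in path objects. A hedged sketch of the pattern, using a local stand-in for gfile.listdir (the helper names are illustrative, not from the commit):

    import os

    def ensure_str_path(path) -> str:
        # Coerce once at the function boundary; str() covers pathlib.Path
        # and py.path.local alike, and is a no-op for plain strings.
        return path if isinstance(path, str) else str(path)

    def hpc_checkpoint_names(folderpath) -> list:
        # Stand-in for the gfile.listdir() filtering in restore_hpc_weights_if_needed
        folderpath = ensure_str_path(folderpath)
        if not os.path.isdir(folderpath):
            return []
        return [x for x in os.listdir(folderpath) if 'hpc_ckpt' in x]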
tests/loggers/test_all.py (6 changes: 3 additions, 3 deletions)
@@ -46,7 +46,7 @@ def log_metrics(self, metrics, step):
             super().log_metrics(metrics, step)
             self.history.append((step, metrics))
 
-    logger_args = _get_logger_args(logger_class, str(tmpdir))
+    logger_args = _get_logger_args(logger_class, tmpdir)
     logger = StoreHistoryLogger(**logger_args)
 
     trainer = Trainer(
@@ -82,7 +82,7 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
     import atexit
     monkeypatch.setattr(atexit, 'register', lambda _: None)
 
-    logger_args = _get_logger_args(logger_class, str(tmpdir))
+    logger_args = _get_logger_args(logger_class, tmpdir)
     logger = logger_class(**logger_args)
 
     # test pickling loggers
@@ -109,7 +109,7 @@ def test_logger_reset_correctly(tmpdir, extra_params):
     model = EvalModelTemplate()
 
     trainer = Trainer(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         **extra_params
     )
     logger1 = trainer.logger
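The test-side revert is the mirror image of the code-side coercion: pytest's tmpdir fixture is a py.path.local object, not a str, so passing it through unconverted is exactly what exercises the new contract. A minimal illustrative test (not part of the commit):

    def test_tmpdir_is_a_path_object(tmpdir):
        # tmpdir is a py.path.local; str() yields the filesystem path
        assert not isinstance(tmpdir, str)
        assert isinstance(str(tmpdir), str)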
tests/models/test_cpu.py (36 changes: 20 additions, 16 deletions)
@@ -19,7 +19,7 @@ def test_cpu_slurm_save_load(tmpdir):
     model = EvalModelTemplate(**hparams)
 
     # logger file to get meta
-    logger = tutils.get_default_logger(str(tmpdir))
+    logger = tutils.get_default_logger(tmpdir)
     version = logger.version
 
     # fit model
@@ -28,7 +28,7 @@ def test_cpu_slurm_save_load(tmpdir):
         logger=logger,
         train_percent_check=0.2,
         val_percent_check=0.2,
-        checkpoint_callback=ModelCheckpoint(str(tmpdir)),
+        checkpoint_callback=ModelCheckpoint(tmpdir)
     )
     result = trainer.fit(model)
     real_global_step = trainer.global_step
@@ -54,13 +54,17 @@ def test_cpu_slurm_save_load(tmpdir):
 
     # test HPC saving
     # simulate snapshot on slurm
-    saved_filepath = trainer.hpc_save(str(tmpdir), logger)
+    saved_filepath = trainer.hpc_save(tmpdir, logger)
     assert os.path.exists(saved_filepath)
 
     # new logger file to get meta
-    logger = tutils.get_default_logger(str(tmpdir), version=version)
+    logger = tutils.get_default_logger(tmpdir, version=version)
 
-    trainer = Trainer(max_epochs=1, logger=logger, checkpoint_callback=ModelCheckpoint(str(tmpdir)),)
+    trainer = Trainer(
+        max_epochs=1,
+        logger=logger,
+        checkpoint_callback=ModelCheckpoint(tmpdir),
+    )
     model = EvalModelTemplate(**hparams)
 
     # set the epoch start hook so we can predict before the model does the full training
@@ -83,7 +87,7 @@ def test_early_stopping_cpu_model(tmpdir):
     """Test each of the trainer options."""
     stopping = EarlyStopping(monitor='val_loss', min_delta=0.1)
     trainer_options = dict(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         early_stop_callback=stopping,
         max_epochs=2,
         gradient_clip_val=1.0,
@@ -112,7 +116,7 @@ def test_multi_cpu_model_ddp(tmpdir):
     tutils.set_random_master_port()
 
     trainer_options = dict(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
         max_epochs=1,
         train_percent_check=0.4,
@@ -129,7 +133,7 @@ def test_lbfgs_cpu_model(tmpdir):
 def test_lbfgs_cpu_model(tmpdir):
     """Test each of the trainer options."""
     trainer_options = dict(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         max_epochs=1,
         progress_bar_refresh_rate=0,
         weights_summary='top',
@@ -148,7 +152,7 @@ def test_default_logger_callbacks_cpu_model(tmpdir):
 def test_default_logger_callbacks_cpu_model(tmpdir):
     """Test each of the trainer options."""
     trainer_options = dict(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         max_epochs=1,
         gradient_clip_val=1.0,
         overfit_pct=0.20,
@@ -170,14 +174,14 @@ def test_running_test_after_fitting(tmpdir):
     model = EvalModelTemplate()
 
     # logger file to get meta
-    logger = tutils.get_default_logger(str(tmpdir))
+    logger = tutils.get_default_logger(tmpdir)
 
     # logger file to get weights
     checkpoint = tutils.init_checkpoint_callback(logger)
 
     # fit model
     trainer = Trainer(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
         max_epochs=2,
         train_percent_check=0.4,
@@ -201,7 +205,7 @@ def test_running_test_no_val(tmpdir):
     model = EvalModelTemplate()
 
     # logger file to get meta
-    logger = tutils.get_default_logger(str(tmpdir))
+    logger = tutils.get_default_logger(tmpdir)
 
     # logger file to get weights
     checkpoint = tutils.init_checkpoint_callback(logger)
@@ -280,7 +284,7 @@ def test_simple_cpu(tmpdir):
 
     # fit model
     trainer = Trainer(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         max_epochs=1,
         val_percent_check=0.1,
         train_percent_check=0.1,
@@ -294,7 +298,7 @@ def test_cpu_model(tmpdir):
 def test_cpu_model(tmpdir):
     """Make sure model trains on CPU."""
     trainer_options = dict(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
         max_epochs=1,
         train_percent_check=0.4,
@@ -309,7 +313,7 @@ def test_all_features_cpu_model(tmpdir):
 def test_all_features_cpu_model(tmpdir):
     """Test each of the trainer options."""
     trainer_options = dict(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         gradient_clip_val=1.0,
         overfit_pct=0.20,
         track_grad_norm=2,
@@ -383,7 +387,7 @@ def train_dataloader(self):
 
     # fit model
     trainer = Trainer(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         max_epochs=1,
         truncated_bptt_steps=truncated_bptt_steps,
         val_percent_check=0,
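In test_cpu_slurm_save_load the round trip now exercises both sides of the contract: hpc_save receives the raw tmpdir and is expected to hand back a usable path. A hedged distillation of what the test's assertions rely on (the helper is illustrative, and the str return type is an assumption based on the training_io.py diff above):

    import os

    def check_hpc_roundtrip(trainer, tmpdir, logger):
        # hpc_save coerces its folderpath argument internally, so a
        # py.path.local is acceptable here; the path it returns should be
        # a plain str usable with os.path.exists().
        saved_filepath = trainer.hpc_save(tmpdir, logger)
        assert isinstance(saved_filepath, str)
        assert os.path.exists(saved_filepath)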
tests/models/test_gpu.py (6 changes: 3 additions, 3 deletions)
@@ -18,7 +18,7 @@
 def test_single_gpu_model(tmpdir, gpus):
     """Make sure single GPU works (DP mode)."""
     trainer_options = dict(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
         max_epochs=1,
         train_percent_check=0.1,
@@ -38,7 +38,7 @@ def test_multi_gpu_model(tmpdir, backend):
     tutils.set_random_master_port()
 
     trainer_options = dict(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         max_epochs=1,
         train_percent_check=0.4,
         val_percent_check=0.2,
@@ -84,7 +84,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
 def test_multi_gpu_none_backend(tmpdir):
     """Make sure when using multiple GPUs the user can't use `distributed_backend = None`."""
     trainer_options = dict(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
         max_epochs=1,
         train_percent_check=0.1,
tests/trainer/test_lr_finder.py (18 changes: 9 additions, 9 deletions)
@@ -14,7 +14,7 @@ def test_error_on_more_than_1_optimizer(tmpdir):
 
     # logger file to get meta
     trainer = Trainer(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         max_epochs=1
     )
 
@@ -29,7 +29,7 @@ def test_model_reset_correctly(tmpdir):
 
     # logger file to get meta
     trainer = Trainer(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         max_epochs=1
     )
 
@@ -51,7 +51,7 @@ def test_trainer_reset_correctly(tmpdir):
 
     # logger file to get meta
     trainer = Trainer(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         max_epochs=1
     )
 
@@ -81,7 +81,7 @@ def test_trainer_arg_bool(tmpdir):
 
     # logger file to get meta
     trainer = Trainer(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         max_epochs=2,
         auto_lr_find=True
     )
@@ -100,7 +100,7 @@ def test_trainer_arg_str(tmpdir):
     before_lr = model.my_fancy_lr
     # logger file to get meta
     trainer = Trainer(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         max_epochs=2,
         auto_lr_find='my_fancy_lr'
     )
@@ -120,7 +120,7 @@ def test_call_to_trainer_method(tmpdir):
     before_lr = hparams.get('learning_rate')
     # logger file to get meta
     trainer = Trainer(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         max_epochs=2,
     )
 
@@ -144,7 +144,7 @@ def test_accumulation_and_early_stopping(tmpdir):
     before_lr = hparams.get('learning_rate')
     # logger file to get meta
     trainer = Trainer(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         accumulate_grad_batches=2,
     )
 
@@ -167,7 +167,7 @@ def test_suggestion_parameters_work(tmpdir):
 
     # logger file to get meta
     trainer = Trainer(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         max_epochs=3,
     )
 
@@ -187,7 +187,7 @@ def test_suggestion_with_non_finite_values(tmpdir):
 
     # logger file to get meta
     trainer = Trainer(
-        default_root_dir=str(tmpdir),
+        default_root_dir=tmpdir,
         max_epochs=3
     )
 
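Taken together, the tests revert to passing path objects straight through, while the Trainer upholds the str contract itself. A hedged end-to-end sketch of the behavior this commit guarantees (the directory value is illustrative):

    from pathlib import Path
    from pytorch_lightning import Trainer

    # A path object passed in should come back out as a plain str.
    trainer = Trainer(default_root_dir=Path('/tmp/example_run'))
    assert isinstance(trainer.default_root_dir, str)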
(Two of the eight changed files did not load in this view.)