Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ref: .tune() (temporary) #3293

Merged
merged 6 commits into from
Aug 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/training_tricks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ longer training time. Inspired by https://github.com/BlackHC/toma.
# Autoscale batch size
trainer = Trainer(auto_scale_batch_size=None|'power'|'binsearch')

# find the batch size
trainer.tune(model)

Currently, this feature supports two modes `'power'` scaling and `'binsearch'`
scaling. In `'power'` scaling, starting from a batch size of 1 keeps doubling
the batch size until an out-of-memory (OOM) error is encountered. Setting the
Expand Down
10 changes: 8 additions & 2 deletions pytorch_lightning/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,17 @@ def forward(self, x):
Automatically tries to find the largest batch size that fits into memory,
before any training.

.. testcode::
.. code-block::

# default used by the Trainer (no scaling of batch size)
trainer = Trainer(auto_scale_batch_size=None)

# run batch size scaling, result overrides hparams.batch_size
trainer = Trainer(auto_scale_batch_size='binsearch')

# call tune to find the batch size
trainer.tune(model)

auto_select_gpus
^^^^^^^^^^^^^^^^

Expand All @@ -241,11 +244,14 @@ def forward(self, x):
Runs a learning rate finder algorithm (see this `paper <https://arxiv.org/abs/1506.01186>`_)
before any training, to find optimal initial learning rate.

.. testcode::
.. code-block:: python

# default used by the Trainer (no learning rate finder)
trainer = Trainer(auto_lr_find=False)

# call tune to find the lr
trainer.tune(model)

Example::

# run learning rate finder, results override hparams.learning_rate
Expand Down
42 changes: 30 additions & 12 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,36 @@ def weights_save_path(self) -> str:
return self._weights_save_path
return os.path.normpath(self._weights_save_path)

def tune(
self,
model: LightningModule,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
):
# TODO: temporary, need to decide if tune or separate object

# setup data, etc...
self.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

# hook
self.call_hook('on_fit_start', model)

# hook
self.prepare_data(model)

# Run auto batch size scaling
if self.auto_scale_batch_size:
if isinstance(self.auto_scale_batch_size, bool):
self.auto_scale_batch_size = 'power'
self.scale_batch_size(model, mode=self.auto_scale_batch_size)
model.logger = self.logger # reset logger binding

# Run learning rate finder:
if self.auto_lr_find:
self._run_lr_finder_internally(model)
model.logger = self.logger # reset logger binding

# -----------------------------
# MODEL TRAINING
# -----------------------------
Expand Down Expand Up @@ -986,18 +1016,6 @@ def fit(
# hook
self.prepare_data(model)

# Run auto batch size scaling
if self.auto_scale_batch_size:
if isinstance(self.auto_scale_batch_size, bool):
self.auto_scale_batch_size = 'power'
self.scale_batch_size(model, mode=self.auto_scale_batch_size)
model.logger = self.logger # reset logger binding

# Run learning rate finder:
if self.auto_lr_find:
self._run_lr_finder_internally(model)
model.logger = self.logger # reset logger binding

# set testing if set in environ
self.testing = os.environ.get('PL_TESTING_MODE', self.testing)

Expand Down
2 changes: 1 addition & 1 deletion tests/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def test_logger_reset_correctly(tmpdir, extra_params):
**extra_params,
)
logger1 = trainer.logger
trainer.fit(model)
trainer.tune(model)
logger2 = trainer.logger
logger3 = model.logger

Expand Down
6 changes: 3 additions & 3 deletions tests/trainer/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_trainer_arg_bool(tmpdir, use_hparams):
auto_lr_find=True,
)

trainer.fit(model)
trainer.tune(model)
if use_hparams:
after_lr = model.hparams.learning_rate
else:
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_trainer_arg_str(tmpdir, use_hparams):
auto_lr_find='my_fancy_lr',
)

trainer.fit(model)
trainer.tune(model)
if use_hparams:
after_lr = model.hparams.my_fancy_lr
else:
Expand All @@ -146,7 +146,7 @@ def test_call_to_trainer_method(tmpdir):
lrfinder = trainer.lr_find(model, mode='linear')
after_lr = lrfinder.suggestion()
model.learning_rate = after_lr
trainer.fit(model)
trainer.tune(model)

assert before_lr != after_lr, \
'Learning rate was not altered after running learning rate finder'
Expand Down
10 changes: 5 additions & 5 deletions tests/trainer/test_trainer_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg):
model = EvalModelTemplate(**hparams)
before_batch_size = hparams.get('batch_size')
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=scale_arg)
trainer.fit(model)
trainer.tune(model)
after_batch_size = model.batch_size
assert before_batch_size != after_batch_size, \
'Batch size was not altered after running auto scaling of batch size'
Expand Down Expand Up @@ -232,7 +232,7 @@ def dataloader(self, *args, **kwargs):
model = model_class(**hparams)

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
trainer.fit(model)
trainer.tune(model)
after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
assert before_batch_size != after_batch_size

Expand All @@ -246,7 +246,7 @@ def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, auto_scale_batch_size=True)
expected_message = "Field `model.batch_size` and `model.hparams.batch_size` are mutually exclusive!"
with pytest.warns(UserWarning, match=expected_message):
trainer.fit(model)
trainer.tune(model)


@pytest.mark.parametrize('scale_method', ['power', 'binsearch'])
Expand Down Expand Up @@ -288,7 +288,7 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
fit_options = dict(train_dataloader=model.dataloader(train=True))

with pytest.raises(MisconfigurationException):
trainer.fit(model, **fit_options)
trainer.tune(model, **fit_options)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
Expand All @@ -303,7 +303,7 @@ def test_auto_scale_batch_size_with_amp(tmpdir):
gpus=1,
precision=16
)
trainer.fit(model)
trainer.tune(model)
batch_size_after = model.batch_size
assert trainer.amp_backend == AMPType.NATIVE
assert trainer.scaler is not None
Expand Down