Skip to content

Commit

Permalink
Fix LightningCLI not saving correctly seed_everything for run=True (
Browse files Browse the repository at this point in the history
  • Loading branch information
mauvilsa authored and carmocca committed Jul 13, 2023
1 parent c875e55 commit efea6c1
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed FSDP re-applying activation checkpointing when the user had manually applied it already ([#18006](https://github.com/Lightning-AI/lightning/pull/18006))


- `LightningCLI` not saving correctly `seed_everything` when `run=True` and `seed_everything=True` ([#18056](https://github.com/Lightning-AI/lightning/pull/18056))


## [2.0.3] - 2023-06-07

### Changed
Expand Down
5 changes: 4 additions & 1 deletion src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,10 @@ def _set_seed(self) -> None:
config_seed = seed_everything(workers=True)
else:
config_seed = seed_everything(config_seed, workers=True)
self.config["seed_everything"] = config_seed
if self.subcommand:
self.config[self.subcommand]["seed_everything"] = config_seed
else:
self.config["seed_everything"] = config_seed


def _class_path_from_class(class_type: Type) -> str:
Expand Down
17 changes: 17 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,23 @@ def test_lightning_cli_save_config_only_once(cleandir):
cli.trainer.test(cli.model) # Should not fail because config already saved


def test_lightning_cli_save_config_seed_everything(cleandir):
config_path = Path("config.yaml")
cli_args = ["fit", "--seed_everything=true", "--trainer.logger=false", "--trainer.max_epochs=1"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = LightningCLI(BoringModel)
config = yaml.safe_load(config_path.read_text())
assert isinstance(config["seed_everything"], int)
assert config["seed_everything"] == cli.config.fit.seed_everything

cli_args = ["--seed_everything=true", "--trainer.logger=false"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = LightningCLI(BoringModel, run=False)
config = yaml.safe_load(config_path.read_text())
assert isinstance(config["seed_everything"], int)
assert config["seed_everything"] == cli.config.seed_everything


def test_save_to_log_dir_false_error():
with pytest.raises(ValueError):
SaveConfigCallback(
Expand Down

0 comments on commit efea6c1

Please sign in to comment.