Skip to content

Commit

Permalink
fix dirpath in log_dir for CSVLogger (#16401)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
  • Loading branch information
Borda and carmocca authored Jan 17, 2023
1 parent cfe87a0 commit d46091c
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 13 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/ci-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,10 @@ jobs:
flags: ${COVERAGE_SCOPE},cpu,pytest-full,python${{ matrix.python-version }},pytorch${{ matrix.pytorch-version }}
name: CPU-coverage
fail_ci_if_error: false

# TODO
# - name: Testing legacy creation
# working-directory: tests/
# run: |
# export PYTHONPATH=$(dirname $LEGACY_PATH);$PYTHONPATH # for `import tests_pytorch`
# python legacy/simple_classif_training.py
11 changes: 4 additions & 7 deletions src/pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,14 +582,11 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH:
return self.dirpath

if len(trainer.loggers) > 0:
if trainer.loggers[0].save_dir is not None:
save_dir = trainer.loggers[0].save_dir
else:
save_dir = trainer.default_root_dir
name = trainer.loggers[0].name
version = trainer.loggers[0].version
logger_ = trainer.loggers[0]
save_dir = getattr(logger_, "save_dir", None) or trainer.default_root_dir
version = logger_.version
version = version if isinstance(version, str) else f"version_{version}"
ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
ckpt_path = os.path.join(save_dir, str(logger_.name), version, "checkpoints")
else:
# if no loggers, use default_root_dir
ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/loggers/csv_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ExperimentWriter(_FabricExperimentWriter):
r"""
Experiment writer for CSVLogger.
Currently supports to log hyperparameters and metrics in YAML and CSV
Currently, supports to log hyperparameters and metrics in YAML and CSV
format, respectively.
Args:
Expand Down
7 changes: 2 additions & 5 deletions src/pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.loggers import Logger
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loops import PredictionLoop, TrainingEpochLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.fit_loop import FitLoop
Expand Down Expand Up @@ -1807,10 +1806,8 @@ def model(self, model: torch.nn.Module) -> None:
@property
def log_dir(self) -> Optional[str]:
if len(self.loggers) > 0:
if not isinstance(self.loggers[0], TensorBoardLogger):
dirpath = self.loggers[0].save_dir
else:
dirpath = self.loggers[0].log_dir
logger_ = self.loggers[0]
dirpath = getattr(logger_, "log_dir", None) or getattr(logger_, "save_dir", None)
else:
dirpath = self.default_root_dir

Expand Down

0 comments on commit d46091c

Please sign in to comment.