Skip to content

Commit

Permalink
Fix typehint for WandbLogger's log_model argument (#18458)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
  • Loading branch information
3 people authored Sep 5, 2023
1 parent c807bbb commit bff2e42
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed redundant `iter()` call to dataloader when checking dataloading configuration ([#18415](https://github.com/Lightning-AI/lightning/pull/18415))


- Fixed an issue that wouldn't prevent the user to set the `log_model` parameter in `WandbLogger` via the LightningCLI ([#18458](https://github.com/Lightning-AI/lightning/pull/18458))


## [2.0.7] - 2023-08-14

### Added
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, List, Literal, Mapping, Optional, Union

import torch.nn as nn
from lightning_utilities.core.imports import RequirementCache
Expand Down Expand Up @@ -295,7 +295,7 @@ def __init__(
id: Optional[str] = None,
anonymous: Optional[bool] = None,
project: Optional[str] = None,
log_model: Union[str, bool] = False,
log_model: Union[Literal["all"], bool] = False,
experiment: Union[Run, RunDisabled, None] = None,
prefix: str = "",
checkpoint_name: Optional[str] = None,
Expand Down
38 changes: 38 additions & 0 deletions tests/tests_pytorch/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
from unittest import mock

import pytest
import yaml
from lightning_utilities.test.warning import no_warning_call

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -497,3 +499,39 @@ def test_wandb_logger_download_artifact(wandb, tmpdir):
WandbLogger.download_artifact("test_artifact", str(tmpdir), "model", True)

wandb.Api().artifact.assert_called_once_with("test_artifact", type="model")


@mock.patch("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE", True)
@mock.patch("lightning.pytorch.loggers.wandb.wandb", mock.Mock())
@pytest.mark.parametrize(("log_model", "expected"), [("True", True), ("False", False), ("all", "all")])
def test_wandb_logger_cli_integration(log_model, expected, monkeypatch, tmp_path):
"""Test that the WandbLogger can be used with the LightningCLI."""
monkeypatch.chdir(tmp_path)

class InspectParsedCLI(LightningCLI):
def before_instantiate_classes(self):
assert self.config.trainer.logger.init_args.log_model == expected

# Create a config file with the log_model parameter set. This seems necessary to be able
# to set the init_args parameter of the logger on the CLI later on.
input_config = {
"trainer": {
"logger": {
"class_path": "lightning.pytorch.loggers.wandb.WandbLogger",
"init_args": {"log_model": log_model},
},
}
}
config_path = "config.yaml"
with open(config_path, "w") as f:
f.write(yaml.dump(input_config))

# Test case 1: Set the log_model parameter only via the config file.
with mock.patch("sys.argv", ["any.py", "--config", config_path]):
InspectParsedCLI(BoringModel, run=False, save_config_callback=None)

# Test case 2: Overwrite the log_model parameter via the command line.
wandb_cli_arg = f"--trainer.logger.init_args.log_model={log_model}"

with mock.patch("sys.argv", ["any.py", "--config", config_path, wandb_cli_arg]):
InspectParsedCLI(BoringModel, run=False, save_config_callback=None)

0 comments on commit bff2e42

Please sign in to comment.