Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ module = [
"pytorch_lightning.demos.boring_classes",
"pytorch_lightning.demos.mnist_datamodule",
"pytorch_lightning.loggers.comet",
"pytorch_lightning.loggers.mlflow",
"pytorch_lightning.loggers.neptune",
"pytorch_lightning.loggers.tensorboard",
"pytorch_lightning.loggers.wandb",
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_lightning/loggers/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,12 @@ def group_separator(self) -> str:

@property
@abstractmethod
def name(self) -> str:
def name(self) -> Optional[str]:
"""Return the experiment name."""

@property
@abstractmethod
def version(self) -> Union[int, str]:
def version(self) -> Union[int, str, None]:
"""Return the experiment version."""


Expand Down
24 changes: 17 additions & 7 deletions src/pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,17 @@
from mlflow.tracking.context.registry import resolve_tags
else:

def resolve_tags(tags=None):
def resolve_tags(tags: Optional[Dict] = None) -> Optional[Dict]:
"""
Args:
tags: A dictionary of tags to override. If specified, tags passed in this argument will
override those inferred from the context.

Returns: A dictionary of resolved tags.

Note:
See ``mlflow.tracking.context.registry`` for more details.
"""
return tags


Expand Down Expand Up @@ -129,7 +139,7 @@ def __init__(
tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"

self._experiment_name = experiment_name
self._experiment_id = None
self._experiment_id: Optional[str] = None
self._tracking_uri = tracking_uri
self._run_name = run_name
self._run_id = run_id
Expand All @@ -141,7 +151,7 @@ def __init__(

self._mlflow_client = MlflowClient(tracking_uri)

@property
@property # type: ignore[misc]
@rank_zero_experiment
def experiment(self) -> MlflowClient:
r"""
Expand Down Expand Up @@ -187,7 +197,7 @@ def experiment(self) -> MlflowClient:
return self._mlflow_client

@property
def run_id(self) -> str:
def run_id(self) -> Optional[str]:
"""Create the experiment if it does not exist to get the run id.

Returns:
Expand All @@ -197,7 +207,7 @@ def run_id(self) -> str:
return self._run_id

@property
def experiment_id(self) -> str:
def experiment_id(self) -> Optional[str]:
"""Create the experiment if it does not exist to get the experiment id.

Returns:
Expand Down Expand Up @@ -261,7 +271,7 @@ def save_dir(self) -> Optional[str]:
return self._tracking_uri.lstrip(LOCAL_FILE_URI_PREFIX)

@property
def name(self) -> str:
def name(self) -> Optional[str]:
"""Get the experiment id.

Returns:
Expand All @@ -270,7 +280,7 @@ def name(self) -> str:
return self.experiment_id

@property
def version(self) -> str:
def version(self) -> Optional[str]:
"""Get the run id.

Returns:
Expand Down