diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py
index 27ddccab31..a9abed9362 100644
--- a/composer/loggers/mlflow_logger.py
+++ b/composer/loggers/mlflow_logger.py
@@ -10,6 +10,7 @@
 import os
 import pathlib
 import posixpath
+import sys
 import textwrap
 import time
 import warnings
@@ -487,13 +488,25 @@ def log_images(
                 )
 
     def post_close(self):
-        if self._enabled:
-            import mlflow
+        """Flush pending MLflow logging and terminate the active run.
+
+        The run is marked ``FAILED`` when an exception is being handled at the
+        time of the call (per ``sys.exc_info()``) and ``FINISHED`` otherwise.
+        """
+        if not self._enabled or self._run_id is None:
+            return
 
-            assert isinstance(self._run_id, str)
-            mlflow.flush_async_logging()
-            self._mlflow_client.set_terminated(self._run_id)
-            mlflow.end_run()
+        import mlflow
+
+        # A non-None exception type means we are closing while an error is
+        # being handled, so the run must not be reported as successful.
+        status = 'FINISHED' if sys.exc_info()[0] is None else 'FAILED'
+
+        log.info(f'Finishing MLflow run with status {status}')
+        mlflow.flush_async_logging()
+
+        self._mlflow_client.set_terminated(self._run_id, status=status)
+        mlflow.end_run(status=status)
 
 
 def _convert_to_mlflow_image(image: Union[np.ndarray, torch.Tensor], channels_last: bool) -> np.ndarray: