MLflow integration callback (#8016)
* Add MLflow integration class

Add integration code for MLflow in integrations.py along with the code
that checks that MLflow is installed.

* Add MLflowCallback import

Add import of MLflowCallback in trainer.py

* Handle model argument

Allow the callback to handle the model argument and store model config items as hyperparameters.

* Log parameters to MLflow in batches

MLflow cannot log more than one hundred parameters in a single call, so the combined parameters are split into batches of 100 items and logged batch by batch (see the sketch below).
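
A minimal standalone sketch of that batching scheme (the parameter dictionary here is made up; mlflow.log_params is the real API call):

    import mlflow

    MAX_LOG_SIZE = 100  # MLflow rejects calls with more than 100 parameters

    params = {f"param_{i}": i for i in range(250)}  # stand-in for config + training args
    items = list(params.items())
    for i in range(0, len(items), MAX_LOG_SIZE):
        # log at most 100 parameters per call
        mlflow.log_params(dict(items[i : i + MAX_LOG_SIZE]))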

* Fix style

* Add docs on MLflow callback

* Fix issue with unfinished runs

The "fluent" api used in MLflow integration allows only one run to be active at any given moment. If the Trainer is disposed off and a new one is created, but the training is not finished, it will refuse to log the results when the next trainer is created.

noise-field authored Oct 26, 2020
1 parent 8be9cb0 commit c48b16b
Showing 3 changed files with 99 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/source/main_classes/callback.rst
@@ -20,6 +20,7 @@ By default a :class:`~transformers.Trainer` will use the following callbacks:
or tensorboardX).
- :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
- :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.
- :class:`~transformers.integrations.MLflowCallback` if `mlflow <https://www.mlflow.org/>`__ is installed.

The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
@@ -46,6 +47,9 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the
.. autoclass:: transformers.integrations.WandbCallback
:members: setup

.. autoclass:: transformers.integrations.MLflowCallback
:members: setup


TrainerCallback
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
89 changes: 89 additions & 0 deletions src/transformers/integrations.py
@@ -53,6 +53,14 @@
except ImportError:
_has_tensorboard = False

try:
import mlflow # noqa: F401

_has_mlflow = True
except ImportError:
_has_mlflow = False


# No transformer imports above this point

from .file_utils import is_torch_tpu_available
@@ -85,6 +93,10 @@ def is_ray_available():
return _has_ray


def is_mlflow_available():
return _has_mlflow


def hp_params(trial):
if is_optuna_available():
if isinstance(trial, optuna.Trial):
@@ -408,3 +420,80 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs):
experiment = comet_ml.config.get_global_experiment()
if experiment is not None:
experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")


class MLflowCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow
<https://www.mlflow.org/>`__.
"""

MAX_LOG_SIZE = 100

def __init__(self):
assert _has_mlflow, "MLflow requires mlflow to be installed. Run `pip install mlflow`."
self._initialized = False
self._log_artifacts = False

def setup(self, args, state, model):
"""
Set up the optional MLflow integration.
Environment:
HF_MLFLOW_LOG_ARTIFACTS (:obj:`str`, `optional`):
Whether to use MLflow .log_artifact() facility to log artifacts.
This only makes sense if logging to a remote server, e.g. s3 or GCS.
If set to `True` or `1`, will copy whatever is in the TrainingArguments' output_dir
to the local or remote artifact storage. Using it without a remote storage
will just copy the files to your artifact location.
"""
log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
if log_artifacts in {"TRUE", "1"}:
self._log_artifacts = True
if state.is_world_process_zero:
mlflow.start_run()
combined_dict = args.to_dict()
if hasattr(model, "config") and model.config is not None:
model_config = model.config.to_dict()
combined_dict = {**model_config, **combined_dict}
# MLflow cannot log more than 100 values in one go, so we have to split it
combined_dict_items = list(combined_dict.items())
for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
mlflow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
self._initialized = True

def on_train_begin(self, args, state, control, model=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)

def on_log(self, args, state, control, logs, model=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
if state.is_world_process_zero:
for k, v in logs.items():
if isinstance(v, (int, float)):
mlflow.log_metric(k, v, step=state.global_step)
else:
logger.warning(
"Trainer is attempting to log a value of "
'"%s" of type %s for key "%s" as a metric. '
"MLflow's log_metric() only accepts float and "
"int types so we dropped this attribute.",
v,
type(v),
k,
)

def on_train_end(self, args, state, control, **kwargs):
if self._initialized and state.is_world_process_zero:
if self._log_artifacts:
logger.info("Logging artifacts. This may take time.")
mlflow.log_artifacts(args.output_dir)
mlflow.end_run()

def __del__(self):
# if the previous run is not terminated correctly, the fluent API will
# not let you start a new run before the previous one is killed
if mlflow.active_run() is not None:
mlflow.end_run(status="KILLED")
6 changes: 6 additions & 0 deletions src/transformers/trainer.py
@@ -41,6 +41,7 @@
default_hp_search_backend,
hp_params,
is_comet_available,
is_mlflow_available,
is_optuna_available,
is_ray_available,
is_tensorboard_available,
@@ -139,6 +140,11 @@

DEFAULT_CALLBACKS.append(CometCallback)

if is_mlflow_available():
from .integrations import MLflowCallback

DEFAULT_CALLBACKS.append(MLflowCallback)

if is_optuna_available():
import optuna

