Skip to content

Commit

Permalink
Fix MLflowCallback and add support for MLFLOW_EXPERIMENT_NAME (huggingface#17091)
Browse files Browse the repository at this point in the history

* Fix use of mlflow.active_run() and add proper support for MLFLOW_EXPERIMENT_NAME

* Fix code style (make style)
  • Loading branch information
orieg authored and nandwalritik committed May 10, 2022
1 parent 70e3256 commit fa4babf
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions src/transformers/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,17 +781,26 @@ def setup(self, args, state, model):
Environment:
HF_MLFLOW_LOG_ARTIFACTS (`str`, *optional*):
Whether to use MLflow .log_artifact() facility to log artifacts.
This only makes sense if logging to a remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy
whatever is in [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it
without a remote storage will just copy the files to your artifact location.
Whether to use MLflow .log_artifact() facility to log artifacts. This only makes sense if logging to a
remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy whatever is in
[`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote
storage will just copy the files to your artifact location.
MLFLOW_EXPERIMENT_NAME (`str`, *optional*):
Name of the MLflow experiment under which to launch the run. Defaults to `None`, which points to the
"Default" experiment in MLflow. Otherwise, it is the case-sensitive name of the experiment to be
activated. If an experiment with this name does not exist, a new experiment with this name is
created.
"""
log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
if log_artifacts in {"TRUE", "1"}:
self._log_artifacts = True
experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
logger.debug(f"MLFLOW experiment_name={experiment_name}, run_name={args.run_name}")
if state.is_world_process_zero:
if self._ml_flow.active_run is None:
if self._ml_flow.active_run() is None:
if experiment_name:
# set_experiment() creates the experiment if it does not already exist
self._ml_flow.set_experiment(experiment_name)
self._ml_flow.start_run(run_name=args.run_name)
combined_dict = args.to_dict()
if hasattr(model, "config") and model.config is not None:
Expand Down Expand Up @@ -844,7 +853,7 @@ def on_train_end(self, args, state, control, **kwargs):
def __del__(self):
# if the previous run is not terminated correctly, the fluent API will
# not let you start a new run before the previous one is killed
if self._ml_flow.active_run is not None:
if self._ml_flow.active_run() is not None:
self._ml_flow.end_run()


Expand Down

0 comments on commit fa4babf

Please sign in to comment.