Fix MLflowCallback and add support for MLFLOW_EXPERIMENT_NAME #17091

Merged on May 5, 2022 (2 commits)
23 changes: 16 additions & 7 deletions src/transformers/integrations.py
@@ -781,17 +781,26 @@ def setup(self, args, state, model):

         Environment:
             HF_MLFLOW_LOG_ARTIFACTS (`str`, *optional*):
-                Whether to use MLflow .log_artifact() facility to log artifacts.
-
-                This only makes sense if logging to a remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy
-                whatever is in [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it
-                without a remote storage will just copy the files to your artifact location.
+                Whether to use MLflow .log_artifact() facility to log artifacts. This only makes sense if logging to a
+                remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy whatever is in
+                [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote
+                storage will just copy the files to your artifact location.
+            MLFLOW_EXPERIMENT_NAME (`str`, *optional*):
+                Whether to use an MLflow experiment_name under which to launch the run. Default to "None" which will
+                point to the "Default" experiment in MLflow. Otherwise, it is a case sensitive name of the experiment
+                to be activated. If an experiment with this name does not exist, a new experiment with this name is
+                created.
         """
         log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
         if log_artifacts in {"TRUE", "1"}:
             self._log_artifacts = True
+        experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
+        logger.debug(f"MLFLOW experiment_name={experiment_name}, run_name={args.run_name}")
         if state.is_world_process_zero:
-            if self._ml_flow.active_run is None:
+            if self._ml_flow.active_run() is None:
+                if experiment_name:
+                    # Use of set_experiment() ensure that Experiment is created if not exists
+                    self._ml_flow.set_experiment(experiment_name)
                 self._ml_flow.start_run(run_name=args.run_name)
             combined_dict = args.to_dict()
             if hasattr(model, "config") and model.config is not None:
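
The experiment-selection branch in this hunk can be exercised without an MLflow install. The sketch below is only an illustration: `FakeMlflow` and `start_run_from_env` are hypothetical names that do not exist in `transformers` or `mlflow`, and the stub merely records the calls that the patched `setup()` would make on the real module.

```python
import os


class FakeMlflow:
    """Hypothetical stand-in that records calls; this is NOT the real mlflow module."""

    def __init__(self):
        self.calls = []
        self._run = None

    def active_run(self):
        return self._run

    def set_experiment(self, name):
        self.calls.append(("set_experiment", name))

    def start_run(self, run_name=None):
        self._run = {"run_name": run_name}
        self.calls.append(("start_run", run_name))


def start_run_from_env(ml_flow, run_name, environ=os.environ):
    """Mirror the patched branch: choose the experiment first, then start the run."""
    experiment_name = environ.get("MLFLOW_EXPERIMENT_NAME")
    if ml_flow.active_run() is None:
        # An unset or empty variable skips set_experiment(); per the docstring,
        # the run then lands in MLflow's "Default" experiment.
        if experiment_name:
            ml_flow.set_experiment(experiment_name)
        ml_flow.start_run(run_name=run_name)
    return ml_flow.calls


with_name = start_run_from_env(FakeMlflow(), "demo", {"MLFLOW_EXPERIMENT_NAME": "my-exp"})
without_name = start_run_from_env(FakeMlflow(), "demo", {})
print(with_name)
print(without_name)
```

Passing the environment as a plain dict keeps the example independent of the caller's shell. With the real integration, the variable would be exported before the training script starts.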
@@ -844,7 +853,7 @@ def on_train_end(self, args, state, control, **kwargs):
     def __del__(self):
         # if the previous run is not terminated correctly, the fluent API will
         # not let you start a new run before the previous one is killed
-        if self._ml_flow.active_run is not None:
+        if self._ml_flow.active_run() is not None:
             self._ml_flow.end_run()
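
Both `active_run` changes fix the same mistake: the old code compared the function object itself with `None` instead of its return value. A minimal sketch of why that matters, using a hypothetical `FakeMlflow` stub rather than the real `mlflow` module:

```python
class FakeMlflow:
    """Hypothetical stub: active_run() returns None while no run is active."""

    def __init__(self):
        self._run = None

    def active_run(self):
        return self._run


ml_flow = FakeMlflow()

# Old check: a bound method object is never None, so this is always False
# and the guarded start_run() branch could never execute.
old_check = ml_flow.active_run is None

# New check: calling the function returns the current run, or None if there is none.
new_check = ml_flow.active_run() is None

print(old_check, new_check)
```

By the same logic, the old guard in `__del__` (`active_run is not None`) was always true, so `end_run()` was called whether or not a run was active.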

