diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index d66cd3b41c993a..ddae993f1d2cc6 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -781,17 +781,26 @@ def setup(self, args, state, model): Environment: HF_MLFLOW_LOG_ARTIFACTS (`str`, *optional*): - Whether to use MLflow .log_artifact() facility to log artifacts. - - This only makes sense if logging to a remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy - whatever is in [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it - without a remote storage will just copy the files to your artifact location. + Whether to use MLflow .log_artifact() facility to log artifacts. This only makes sense if logging to a + remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy whatever is in + [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote + storage will just copy the files to your artifact location. + MLFLOW_EXPERIMENT_NAME (`str`, *optional*): + Name of the MLflow experiment under which to launch the run. Defaults to `None`, which will + point to the "Default" experiment in MLflow. Otherwise, it is the case-sensitive name of the experiment + to be activated. If an experiment with this name does not exist, a new experiment with this name is + created. 
""" log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() if log_artifacts in {"TRUE", "1"}: self._log_artifacts = True + experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None) + logger.debug(f"MLFLOW experiment_name={experiment_name}, run_name={args.run_name}") if state.is_world_process_zero: - if self._ml_flow.active_run is None: + if self._ml_flow.active_run() is None: + if experiment_name: + # set_experiment() creates the experiment if it does not already exist + self._ml_flow.set_experiment(experiment_name) self._ml_flow.start_run(run_name=args.run_name) combined_dict = args.to_dict() if hasattr(model, "config") and model.config is not None: @@ -844,7 +853,7 @@ def on_train_end(self, args, state, control, **kwargs): def __del__(self): # if the previous run is not terminated correctly, the fluent API will # not let you start a new run before the previous one is killed - if self._ml_flow.active_run is not None: + if self._ml_flow.active_run() is not None: self._ml_flow.end_run()