diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 45ef3c3c840b..330fccb20d4f 100644 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -131,13 +131,6 @@ def is_mlflow_available(): return importlib.util.find_spec("mlflow") is not None -def get_mlflow_version(): - try: - return importlib.metadata.version("mlflow") - except importlib.metadata.PackageNotFoundError: - return importlib.metadata.version("mlflow-skinny") - - def is_dagshub_available(): return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")] @@ -1005,12 +998,12 @@ def setup(self, args, state, model): self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None) self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES self._run_id = os.getenv("MLFLOW_RUN_ID", None) - self._async_log = False + # "synchronous" flag is only available with mlflow version >= 2.8.0 # https://github.com/mlflow/mlflow/pull/9705 # https://github.com/mlflow/mlflow/releases/tag/v2.8.0 - if packaging.version.parse(get_mlflow_version()) >= packaging.version.parse("2.8.0"): - self._async_log = True + self._async_log = packaging.version.parse(self._ml_flow.__version__) >= packaging.version.parse("2.8.0") + logger.debug( f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run}," f" tags={self._nested_run}, tracking_uri={self._tracking_uri}"