diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index 23fb3eb250..80ae34180e 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -83,6 +83,8 @@ class AutoRunner: zip url will be downloaded and extracted into the work_dir. allow_skip: a switch passed to BundleGen process which determines if some Algo in the default templates can be skipped based on the analysis on the dataset from Auto3DSeg DataAnalyzer. + mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of the remote + tracking Server; MLflow runs will be recorded locally in algorithms' model folder if the value is None. kwargs: image writing parameters for the ensemble inference. The kwargs format follows the SaveImage transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage. @@ -209,6 +211,7 @@ def __init__( not_use_cache: bool = False, templates_path_or_url: str | None = None, allow_skip: bool = True, + mlflow_tracking_uri: str | None = None, **kwargs: Any, ): logger.info(f"AutoRunner using work directory {work_dir}") @@ -220,6 +223,7 @@ def __init__( self.algos = algos self.templates_path_or_url = templates_path_or_url self.allow_skip = allow_skip + self.mlflow_tracking_uri = mlflow_tracking_uri self.kwargs = deepcopy(kwargs) if input is None and os.path.isfile(self.data_src_cfg_name): @@ -783,6 +787,7 @@ def run(self): templates_path_or_url=self.templates_path_or_url, data_stats_filename=self.datastats_filename, data_src_cfg_name=self.data_src_cfg_name, + mlflow_tracking_uri=self.mlflow_tracking_uri, ) if self.gpu_customization: diff --git a/monai/apps/auto3dseg/bundle_gen.py b/monai/apps/auto3dseg/bundle_gen.py index a091739dd3..03b9c8bbf4 100644 --- a/monai/apps/auto3dseg/bundle_gen.py +++ b/monai/apps/auto3dseg/bundle_gen.py @@ -85,6 +85,7 @@ def __init__(self, template_path: PathLike): self.template_path = template_path self.data_stats_files = "" self.data_list_file = "" + self.mlflow_tracking_uri = None self.output_path = "" self.name = "" self.best_metric = None @@ -129,6 +130,17 @@ def set_data_source(self, data_src_cfg: str) -> None: """ self.data_list_file = data_src_cfg + def set_mlflow_tracking_uri(self, mlflow_tracking_uri: str | None) -> None: + """ + Set the tracking URI for MLflow server + + Args: + mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of + the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if + the value is None. + """ + self.mlflow_tracking_uri = mlflow_tracking_uri # type: ignore + def fill_template_config(self, data_stats_filename: str, algo_path: str, **kwargs: Any) -> dict: """ The configuration files defined when constructing this Algo instance might not have a complete training @@ -432,6 +444,9 @@ class BundleGen(AlgoGen): data_stats_filename: the path to the data stats file (generated by DataAnalyzer). data_src_cfg_name: the path to the data source config YAML file. The config will be in a form of {"modality": "ct", "datalist": "path_to_json_datalist", "dataroot": "path_dir_data"}. + mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of + the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if + the value is None. .. code-block:: bash python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/datastats.yaml" @@ -444,6 +459,7 @@ def __init__( templates_path_or_url: str | None = None, data_stats_filename: str | None = None, data_src_cfg_name: str | None = None, + mlflow_tracking_uri: str | None = None, ): if algos is None or isinstance(algos, (list, tuple, str)): if templates_path_or_url is None: @@ -496,6 +512,7 @@ def __init__( self.data_stats_filename = data_stats_filename self.data_src_cfg_name = data_src_cfg_name + self.mlflow_tracking_uri = mlflow_tracking_uri self.history: list[dict] = [] def set_data_stats(self, data_stats_filename: str) -> None: @@ -524,6 +541,21 @@ def get_data_src(self): """Get the data source filename""" return self.data_src_cfg_name + def set_mlflow_tracking_uri(self, mlflow_tracking_uri): + """ + Set the tracking URI for MLflow server + + Args: + mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of + the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if + the value is None. + """ + self.mlflow_tracking_uri = mlflow_tracking_uri + + def get_mlflow_tracking_uri(self): + """Get the tracking URI for MLflow server""" + return self.mlflow_tracking_uri + def get_history(self) -> list: """Get the history of the bundleAlgo object with their names/identifiers""" return self.history @@ -575,9 +607,11 @@ def generate( for f_id in ensure_tuple(fold_idx): data_stats = self.get_data_stats() data_src_cfg = self.get_data_src() + mlflow_tracking_uri = self.get_mlflow_tracking_uri() gen_algo = deepcopy(algo) gen_algo.set_data_stats(data_stats) gen_algo.set_data_source(data_src_cfg) + gen_algo.set_mlflow_tracking_uri(mlflow_tracking_uri) name = f"{gen_algo.name}_{f_id}" if allow_skip: