diff --git a/mlflow_export_import/model/export_model.py b/mlflow_export_import/model/export_model.py index aae4ded1..db530966 100644 --- a/mlflow_export_import/model/export_model.py +++ b/mlflow_export_import/model/export_model.py @@ -27,36 +27,43 @@ def __init__(self, mlflow_client, export_source_tags=False, notebook_formats=No self.export_run = export_run - def export_model(self, model_name, output_dir): + def export_model(self, model_name, output_dir, filter_version=None): """ :param model_name: Registered model name. :param output_dir: Output directory. - :return: Returns bool if export succeeded and the model name. + :return: Returns the manifest holding the run id """ try: - self._export_model(model_name, output_dir) - return True, model_name + return self._export_model(model_name, output_dir, filter_version) except Exception as e: print("ERROR:", e) - return False, model_name - + return [] - def _export_model(self, model_name, output_dir): + def _export_model(self, model_name, output_dir, filter_version): fs = _filesystem.get_filesystem(output_dir) - model = self.http_client.get(f"registered-models/get", {"name": model_name}) + model = self.http_client.get("registered-models/get", {"name": model_name}) fs.mkdirs(output_dir) - output_versions = [] + model["registered_model"]["latest_versions"] = [] versions = self.mlflow_client.search_model_versions(f"name='{model_name}'") print(f"Found {len(versions)} versions for model {model_name}") manifest = [] exported_versions = 0 for vr in versions: + try: + vr_version_int = int(vr.version) + except ValueError: + # Handle the exception + print(f"ERROR: {vr.version} no valid integer number") + if len(self.stages) > 0 and not vr.current_stage.lower() in self.stages: continue + if filter_version is not None and (filter_version != vr_version_int): + print(f"skipping model {vr.name} in version {vr.version}") + continue run_id = vr.run_id - opath = os.path.join(output_dir,run_id) + opath = os.path.join(output_dir, run_id) opath = opath.replace("dbfs:", "/dbfs") - dct = { "version": vr.version, "stage": vr.current_stage, "run_id": run_id } + dct = {"version": vr.version, "stage": vr.current_stage, "run_id": run_id} print(f"Exporting: {dct}") manifest.append(dct) try: @@ -67,18 +74,18 @@ def _export_model(self, model_name, output_dir): dct["_run_artifact_uri"] = run.info.artifact_uri experiment = mlflow.get_experiment(run.info.experiment_id) dct["_experiment_name"] = experiment.name - output_versions.append(dct) + model["registered_model"]["latest_versions"].append(dct) exported_versions += 1 except mlflow.exceptions.RestException as e: if "RESOURCE_DOES_NOT_EXIST: Run" in str(e): print(f"WARNING: Run for version {vr.version} does not exist. {e}") else: import traceback - traceback.print_exc() - output_versions.sort(key=lambda x: x["version"], reverse=False) - model["registered_model"]["latest_versions"] = output_versions - print(f"Exported {exported_versions}/{len(output_versions)} versions for model {model_name}") + traceback.print_exc() + print( + f"Exported {exported_versions}/{len(versions)} versions for model {model_name}" + ) path = os.path.join(output_dir, "model.json") utils.write_json_file(fs, path, model) return manifest