diff --git a/nemo/core/connectors/save_restore_connector.py b/nemo/core/connectors/save_restore_connector.py index 856ade3a3eda..4b610fefe4d9 100644 --- a/nemo/core/connectors/save_restore_connector.py +++ b/nemo/core/connectors/save_restore_connector.py @@ -359,6 +359,8 @@ def register_artifact(self, model, config_path: str, src: str, verify_src_exists """ app_state = AppState() + artifact_item = model_utils.ArtifactItem() + # This is for backward compatibility, if the src objects exists simply inside of the tarfile # without its key having been overriden, this pathway will be used. src_obj_name = os.path.basename(src) @@ -370,18 +372,18 @@ def register_artifact(self, model, config_path: str, src: str, verify_src_exists # src is a local existing path - register artifact and return exact same path for usage by the model if os.path.exists(os.path.abspath(src)): return_path = os.path.abspath(src) - path_type = model_utils.ArtifactPathType.LOCAL_PATH + artifact_item.path_type = model_utils.ArtifactPathType.LOCAL_PATH # this is the case when artifact must be retried from the nemo file # we are assuming that the location of the right nemo file is available from _MODEL_RESTORE_PATH elif src.startswith("nemo:"): return_path = os.path.abspath(os.path.join(app_state.nemo_file_folder, src[5:])) - path_type = model_utils.ArtifactPathType.TAR_PATH + artifact_item.path_type = model_utils.ArtifactPathType.TAR_PATH # backward compatibility implementation elif os.path.exists(src_obj_path): return_path = src_obj_path - path_type = model_utils.ArtifactPathType.TAR_PATH + artifact_item.path_type = model_utils.ArtifactPathType.TAR_PATH else: if verify_src_exists: raise FileNotFoundError( @@ -396,7 +398,7 @@ def register_artifact(self, model, config_path: str, src: str, verify_src_exists assert os.path.exists(return_path) - artifact_item = model_utils.ArtifactItem(path=os.path.abspath(src), path_type=path_type,) + artifact_item.path = os.path.abspath(src) model.artifacts[config_path] = artifact_item # we were called by ModelPT if hasattr(model, "cfg"): @@ -487,9 +489,9 @@ def _handle_artifacts(self, model, nemo_file_folder): shutil.copy2(artifact_base_name, os.path.join(nemo_file_folder, artifact_uniq_name)) # Update artifacts registry - new_artiitem = model_utils.ArtifactItem( - path="nemo:" + artifact_uniq_name, path_type=model_utils.ArtifactPathType.TAR_PATH, - ) + new_artiitem = model_utils.ArtifactItem() + new_artiitem.path = "nemo:" + artifact_uniq_name + new_artiitem.path_type = model_utils.ArtifactPathType.TAR_PATH model.artifacts[conf_path] = new_artiitem finally: # change back working directory diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index 122db3c400db..b2a6abbf54aa 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -56,8 +56,8 @@ class ArtifactPathType(Enum): @dataclass class ArtifactItem: - path: str - path_type: ArtifactPathType + path: str = "" + path_type: ArtifactPathType = ArtifactPathType.LOCAL_PATH hashed_path: Optional[str] = None