diff --git a/hydrolib/core/basemodel.py b/hydrolib/core/basemodel.py index 8ab92d2d6..29e398fd9 100644 --- a/hydrolib/core/basemodel.py +++ b/hydrolib/core/basemodel.py @@ -22,6 +22,7 @@ Tuple, Type, TypeVar, + Union, ) from weakref import WeakValueDictionary @@ -673,16 +674,17 @@ class FileModel(BaseModel, ABC): # Absolute anchor is used to resolve the save location when the filepath is relative. _absolute_anchor_path: Path = PrivateAttr(default_factory=Path.cwd) - def __new__(cls, filepath: Optional[Path] = None, *args, **kwargs): + def __new__(cls, filepath: Optional[Union[Path,str]] = None, *args, **kwargs): """Create a new model. If the file at the provided file path was already parsed, this instance is returned. Args: - filepath (Optional[Path], optional): The file path to the file. Defaults to None. + filepath (Optional[Union[Path,str]], optional): The file path to the file. Defaults to None. Returns: FileModel: A file model. """ + filepath = FileModel._change_to_path(filepath) with file_load_context() as context: if (file_model := context.retrieve_model(filepath)) is not None: return file_model @@ -691,7 +693,7 @@ def __new__(cls, filepath: Optional[Path] = None, *args, **kwargs): def __init__( self, - filepath: Optional[Path] = None, + filepath: Optional[Union[Path,str]] = None, resolve_casing: bool = False, recurse: bool = True, *args, @@ -704,7 +706,7 @@ def __init__( If the filepath is provided, it is read from disk. Args: - filepath (Optional[Path], optional): The file path. Defaults to None. + filepath (Optional[Union[Path,str]], optional): The file path. Defaults to None. resolve_casing (bool, optional): Whether or not to resolve the file name references so that they match the case with what is on disk. Defaults to False. recurse (bool, optional): Whether or not to recursively load the model. Defaults to True. """ @@ -712,6 +714,8 @@ def __init__( super().__init__(*args, **kwargs) return + filepath = FileModel._change_to_path(filepath) + with file_load_context() as context: context.initialize_load_settings(recurse, resolve_casing) @@ -1011,6 +1015,19 @@ def _load(self, filepath: Path) -> Dict: def __str__(self) -> str: return str(self.filepath if self.filepath else "") + + @staticmethod + def _change_to_path(filepath): + if filepath is None: + return filepath + if isinstance(filepath, Path): + return filepath + else: + return Path(filepath) + + @validator("filepath") + def _conform_filepath_to_hydrolib_standard(cls, value): + return FileModel._change_to_path(value) class SerializerConfig(BaseModel, ABC): diff --git a/tests/test_basemodel.py b/tests/test_basemodel.py index dc4115cda..01adda04c 100644 --- a/tests/test_basemodel.py +++ b/tests/test_basemodel.py @@ -314,6 +314,49 @@ def test_initialize_model_with_resolve_casing_updates_file_references_recursivel ) + @pytest.mark.parametrize( + ("given_path", "expected_path"), + [ + pytest.param + ( + Path("test/path"), Path("test/path") + ), + pytest.param + ( + "test/path", Path("test/path") + ), + pytest.param + ( + None,None + ), + ], + ) + def test_setting_filepath(self, given_path, expected_path): + model = FMModel() + model.filepath = given_path + assert model.filepath == expected_path + + @pytest.mark.parametrize( + ("given_path", "expected_path"), + [ + pytest.param + ( + f"{test_input_dir}/file_load_test/fm.mdu", Path(f"{test_input_dir}/file_load_test/fm.mdu") + ), + pytest.param + ( + Path(f"{test_input_dir}/file_load_test/fm.mdu"), Path(f"{test_input_dir}/file_load_test/fm.mdu") + ), + pytest.param + ( + None,None + ), + ], + ) + def test_constuctor_filepath(self, given_path, expected_path): + model = FMModel(given_path) + assert model.filepath == expected_path + class TestContextManagerFileLoadContext: def test_context_is_created_and_disposed_properly(self): assert context_file_loading.get(None) is None