diff --git a/cascade/base/__init__.py b/cascade/base/__init__.py index 692a5d92..b2404e65 100644 --- a/cascade/base/__init__.py +++ b/cascade/base/__init__.py @@ -47,5 +47,5 @@ def raise_not_implemented(class_name: str, name: str) -> NoReturn: from .history_logger import HistoryLogger from .meta_handler import CustomEncoder as JSONEncoder -from .meta_handler import MetaHandler, supported_meta_formats +from .meta_handler import MetaHandler, default_meta_format, supported_meta_formats from .traceable import Traceable, TraceableOnDisk diff --git a/cascade/base/meta_handler.py b/cascade/base/meta_handler.py index a835da5f..a2c653d2 100644 --- a/cascade/base/meta_handler.py +++ b/cascade/base/meta_handler.py @@ -26,6 +26,7 @@ from . import MetaFromFile +default_meta_format = ".json" supported_meta_formats = (".json", ".yml", ".yaml") diff --git a/cascade/base/traceable.py b/cascade/base/traceable.py index 702bbeaa..348069f7 100644 --- a/cascade/base/traceable.py +++ b/cascade/base/traceable.py @@ -24,7 +24,7 @@ import pendulum from datetime import datetime -from . import PipeMeta, MetaFromFile, supported_meta_formats +from . import PipeMeta, MetaFromFile, default_meta_format, supported_meta_formats @dataclass @@ -265,16 +265,42 @@ class TraceableOnDisk(Traceable): def __init__( self, root: str, - meta_fmt: Literal[".json", ".yml", ".yaml"], + meta_fmt: Literal[".json", ".yml", ".yaml", None], *args: Any, meta_prefix: Union[Dict[Any, Any], str, None] = None, **kwargs: Any, ) -> None: super().__init__(*args, meta_prefix=meta_prefix, **kwargs) self._root = root - if meta_fmt not in supported_meta_formats: - raise ValueError(f"Only {supported_meta_formats} are supported formats") - self._meta_fmt = meta_fmt + + ext = self._determine_meta_fmt() + + if ext is None and meta_fmt is None: + meta_fmt = default_meta_format + elif not ext: + # Here we write meta first time and + # don't know the real ext from file + if meta_fmt not in supported_meta_formats: + raise ValueError(f"Only {supported_meta_formats} are supported formats") + self._meta_fmt = meta_fmt + else: + # Here we know the real extension and will + # strictly use it regardless of what was passed + self._meta_fmt = ext + if meta_fmt != ext: + warnings.warn( + f"Trying to set {meta_fmt} to the object that already has {ext} on path {self._root}" + ) + + def _determine_meta_fmt(self) -> Union[str, None]: + meta_paths = glob.glob(os.path.join(self._root, "meta.*")) + if len(meta_paths) == 1: + _, ext = os.path.splitext(meta_paths[0]) + return ext + else: + warnings.warn( + f"Multiple meta files found in {self._root}" + ) def _create_meta(self) -> None: meta_path = sorted(glob.glob(os.path.join(self._root, "meta.*"))) diff --git a/cascade/models/model_line.py b/cascade/models/model_line.py index 79e4bb51..4476081d 100644 --- a/cascade/models/model_line.py +++ b/cascade/models/model_line.py @@ -36,7 +36,7 @@ def __init__( self, folder: str, model_cls: Type = Model, - meta_fmt: Literal[".json", ".yml", ".yaml"] = ".json", + meta_fmt: Literal[".json", ".yml", ".yaml", None] = None, **kwargs: Any, ) -> None: """ @@ -49,7 +49,7 @@ def __init__( If folder does not exist, creates it model_cls: type, optional A class of models in line. ModelLine uses this class to reconstruct a model - meta_fmt: Literal[".json", ".yml", ".yaml"], optional + meta_fmt: Literal[".json", ".yml", ".yaml", None], optional Format in which to store meta data. See also -------- diff --git a/cascade/models/model_repo.py b/cascade/models/model_repo.py index 36ec0843..4636a20a 100644 --- a/cascade/models/model_repo.py +++ b/cascade/models/model_repo.py @@ -44,10 +44,6 @@ def __init__( super().__init__(*args, meta_prefix=meta_prefix, **kwargs) self._lines = dict() - def reload(self) -> None: - for line in self._lines: - self._lines[line].reload() - def __getitem__(self, key: str) -> ModelLine: raise NotImplementedError() @@ -75,6 +71,9 @@ def get_line_names(self) -> List[str]: """ return list(self._lines.keys()) + def reload(self) -> None: + pass + class SingleLineRepo(Repo): def __init__( @@ -86,10 +85,15 @@ def __init__( ) -> None: self._root = line.get_root() super().__init__(*args, meta_prefix=meta_prefix, **kwargs) - self._lines = {os.path.split(self._root)[-1]: line} + self._lines = {os.path.split(self._root)[-1]: {}} + self._line = line def __getitem__(self, key: str) -> ModelLine: - return self._lines[key] + if key in self._lines: + return self._line + else: + raise KeyError( + f"The only line is {list(self._lines.keys())[0]}, {key} does not exist") def __repr__(self) -> str: return f"SingleLine in {self._root}" @@ -162,26 +166,23 @@ def __init__( shutil.rmtree(self._root) os.makedirs(self._root, exist_ok=True) - self._load_lines() + self._lines = { + name: {} + for name in sorted(os.listdir(self._root)) + if os.path.isdir(os.path.join(self._root, name)) + } if lines is not None: for line in lines: - self.add_line(**line) + name = line["name"] + del line["name"] - self._create_meta() + self._lines[name] = { + "args": [], + "kwargs": line + } - def _load_lines(self) -> None: - self._lines = { - name: ModelLine( - os.path.join(self._root, name), - model_cls=self._model_cls - if isinstance(self._model_cls, type) - else self._model_cls[name], - meta_fmt=self._meta_fmt, - ) - for name in sorted(os.listdir(self._root)) - if os.path.isdir(os.path.join(self._root, name)) - } + self._create_meta() def add_line( self, @@ -223,10 +224,17 @@ def add_line( folder = os.path.join(self._root, name) if meta_fmt is None: meta_fmt = self._meta_fmt - line = ModelLine(folder, *args, meta_fmt=meta_fmt, **kwargs) - self._lines[name] = line + self._lines[name] = { + "args": args, + "kwargs": { + "meta_fmt": meta_fmt, + **kwargs + } + } self._update_meta() + + line = ModelLine(folder, *args, meta_fmt=meta_fmt, **kwargs) return line def __getitem__(self, key: Union[str, int]) -> ModelLine: @@ -236,13 +244,20 @@ def __getitem__(self, key: Union[str, int]) -> ModelLine: line: ModelLine existing line of the name passed in `key` """ - if isinstance(key, str): - return self._lines[key] - elif isinstance(key, int): - return self._lines[list(self._lines.keys())[key]] - else: + if isinstance(key, int): + key = list(self._lines.keys())[key] + elif not isinstance(key, str): raise TypeError(f"{type(key)} is not supported as key") + if key in self._lines: + return ModelLine( + os.path.join(self._root, key), + *self._lines[key]["args"], + **self._lines[key]["kwargs"], + ) + else: + raise KeyError(f"Line {key} does not exist in {self}") + def __repr__(self) -> str: return f"ModelRepo in {self._root} of {len(self)} lines" @@ -277,8 +292,9 @@ def load_model_meta(self, model: str) -> MetaFromFile: Raises if failed to find the model with slug specified """ - for line in self._lines.values(): + for name in self._lines: try: + line = ModelLine(os.path.join(self._root, name), *self._lines[name]["args"], **self._lines[name]["kwargs"]) meta = line.load_model_meta(model) except FileNotFoundError: continue @@ -294,13 +310,10 @@ def _update_lines(self) -> None: os.path.isdir(os.path.join(self._root, name)) and name not in self._lines ): - self._lines[name] = ModelLine( - os.path.join(self._root, name), - model_cls=self._model_cls - if isinstance(self._model_cls, type) - else self._model_cls[name], - meta_fmt=self._meta_fmt, - ) + self._lines[name] = { + "args": [], + "kwargs": dict() + } class ModelRepoConcatenator(Repo): diff --git a/cascade/tests/test_model_repo.py b/cascade/tests/test_model_repo.py index a2920cac..14e8371b 100644 --- a/cascade/tests/test_model_repo.py +++ b/cascade/tests/test_model_repo.py @@ -254,7 +254,7 @@ def test_failed_line_meta(tmp_path, ext): f.write("\t{{{: 'sorry, i am broken'") repo = ModelRepo( - repo_path, lines=[dict(name="0", model_cls=DummyModel, meta_fmt=ext)] + repo_path, lines=[dict(name="0", model_cls=DummyModel)], meta_fmt=ext ) model = repo["0"][0] @@ -299,8 +299,8 @@ def test_integer_indices(tmp_path, ext): first_line = repo.add_line("a") last_line = repo.add_line("b") - assert first_line == repo[0] - assert last_line == repo[-1] + assert first_line.get_root() == repo[0].get_root() + assert last_line.get_root() == repo[-1].get_root() @pytest.mark.parametrize("ext", [".json", ".yml", ".yaml"])