Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy Repo #206

Merged
merged 6 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cascade/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ def raise_not_implemented(class_name: str, name: str) -> NoReturn:

from .history_logger import HistoryLogger
from .meta_handler import CustomEncoder as JSONEncoder
from .meta_handler import MetaHandler, supported_meta_formats
from .meta_handler import MetaHandler, default_meta_format, supported_meta_formats
from .traceable import Traceable, TraceableOnDisk
1 change: 1 addition & 0 deletions cascade/base/meta_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from . import MetaFromFile

default_meta_format = ".json"
supported_meta_formats = (".json", ".yml", ".yaml")


Expand Down
36 changes: 31 additions & 5 deletions cascade/base/traceable.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import pendulum
from datetime import datetime

from . import PipeMeta, MetaFromFile, supported_meta_formats
from . import PipeMeta, MetaFromFile, default_meta_format, supported_meta_formats


@dataclass
Expand Down Expand Up @@ -265,16 +265,42 @@ class TraceableOnDisk(Traceable):
def __init__(
self,
root: str,
meta_fmt: Literal[".json", ".yml", ".yaml"],
meta_fmt: Literal[".json", ".yml", ".yaml", None],
*args: Any,
meta_prefix: Union[Dict[Any, Any], str, None] = None,
**kwargs: Any,
) -> None:
super().__init__(*args, meta_prefix=meta_prefix, **kwargs)
self._root = root
if meta_fmt not in supported_meta_formats:
raise ValueError(f"Only {supported_meta_formats} are supported formats")
self._meta_fmt = meta_fmt

ext = self._determine_meta_fmt()

if ext is None and meta_fmt is None:
meta_fmt = default_meta_format
elif not ext:
# Here we write meta first time and
# don't know the real ext from file
if meta_fmt not in supported_meta_formats:
raise ValueError(f"Only {supported_meta_formats} are supported formats")
self._meta_fmt = meta_fmt
else:
# Here we know the real extension and will
# strictly use it regardless of what was passed
self._meta_fmt = ext
if meta_fmt != ext:
warnings.warn(
f"Trying to set {meta_fmt} to the object that already has {ext} on path {self._root}"
)

def _determine_meta_fmt(self) -> Union[str, None]:
meta_paths = glob.glob(os.path.join(self._root, "meta.*"))
if len(meta_paths) == 1:
_, ext = os.path.splitext(meta_paths[0])
return ext
else:
warnings.warn(
f"Multiple meta files found in {self._root}"
)

def _create_meta(self) -> None:
meta_path = sorted(glob.glob(os.path.join(self._root, "meta.*")))
Expand Down
4 changes: 2 additions & 2 deletions cascade/models/model_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
self,
folder: str,
model_cls: Type = Model,
meta_fmt: Literal[".json", ".yml", ".yaml"] = ".json",
meta_fmt: Literal[".json", ".yml", ".yaml", None] = None,
**kwargs: Any,
) -> None:
"""
Expand All @@ -49,7 +49,7 @@ def __init__(
If folder does not exist, creates it
model_cls: type, optional
A class of models in line. ModelLine uses this class to reconstruct a model
meta_fmt: Literal[".json", ".yml", ".yaml"], optional
meta_fmt: Literal[".json", ".yml", ".yaml", None], optional
Format in which to store meta data.
See also
--------
Expand Down
85 changes: 49 additions & 36 deletions cascade/models/model_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ def __init__(
super().__init__(*args, meta_prefix=meta_prefix, **kwargs)
self._lines = dict()

def reload(self) -> None:
for line in self._lines:
self._lines[line].reload()

def __getitem__(self, key: str) -> ModelLine:
raise NotImplementedError()

Expand Down Expand Up @@ -75,6 +71,9 @@ def get_line_names(self) -> List[str]:
"""
return list(self._lines.keys())

def reload(self) -> None:
pass


class SingleLineRepo(Repo):
def __init__(
Expand All @@ -86,10 +85,15 @@ def __init__(
) -> None:
self._root = line.get_root()
super().__init__(*args, meta_prefix=meta_prefix, **kwargs)
self._lines = {os.path.split(self._root)[-1]: line}
self._lines = {os.path.split(self._root)[-1]: {}}
self._line = line

def __getitem__(self, key: str) -> ModelLine:
return self._lines[key]
if key in self._lines:
return self._line
else:
raise KeyError(
f"The only line is {list(self._lines.keys())[0]}, {key} does not exist")

def __repr__(self) -> str:
return f"SingleLine in {self._root}"
Expand Down Expand Up @@ -162,26 +166,23 @@ def __init__(
shutil.rmtree(self._root)

os.makedirs(self._root, exist_ok=True)
self._load_lines()
self._lines = {
name: {}
for name in sorted(os.listdir(self._root))
if os.path.isdir(os.path.join(self._root, name))
}

if lines is not None:
for line in lines:
self.add_line(**line)
name = line["name"]
del line["name"]

self._create_meta()
self._lines[name] = {
"args": [],
"kwargs": line
}

def _load_lines(self) -> None:
self._lines = {
name: ModelLine(
os.path.join(self._root, name),
model_cls=self._model_cls
if isinstance(self._model_cls, type)
else self._model_cls[name],
meta_fmt=self._meta_fmt,
)
for name in sorted(os.listdir(self._root))
if os.path.isdir(os.path.join(self._root, name))
}
self._create_meta()

def add_line(
self,
Expand Down Expand Up @@ -223,10 +224,17 @@ def add_line(
folder = os.path.join(self._root, name)
if meta_fmt is None:
meta_fmt = self._meta_fmt
line = ModelLine(folder, *args, meta_fmt=meta_fmt, **kwargs)
self._lines[name] = line

self._lines[name] = {
"args": args,
"kwargs": {
"meta_fmt": meta_fmt,
**kwargs
}
}
self._update_meta()

line = ModelLine(folder, *args, meta_fmt=meta_fmt, **kwargs)
return line

def __getitem__(self, key: Union[str, int]) -> ModelLine:
Expand All @@ -236,13 +244,20 @@ def __getitem__(self, key: Union[str, int]) -> ModelLine:
line: ModelLine
existing line of the name passed in `key`
"""
if isinstance(key, str):
return self._lines[key]
elif isinstance(key, int):
return self._lines[list(self._lines.keys())[key]]
else:
if isinstance(key, int):
key = list(self._lines.keys())[key]
elif not isinstance(key, str):
raise TypeError(f"{type(key)} is not supported as key")

if key in self._lines:
return ModelLine(
os.path.join(self._root, key),
*self._lines[key]["args"],
**self._lines[key]["kwargs"],
)
else:
raise KeyError(f"Line {key} does not exist in {self}")

def __repr__(self) -> str:
return f"ModelRepo in {self._root} of {len(self)} lines"

Expand Down Expand Up @@ -277,8 +292,9 @@ def load_model_meta(self, model: str) -> MetaFromFile:
Raises if failed to find the model with slug specified
"""

for line in self._lines.values():
for name in self._lines:
try:
line = ModelLine(os.path.join(self._root, name), *self._lines[name]["args"], **self._lines[name]["kwargs"])
meta = line.load_model_meta(model)
except FileNotFoundError:
continue
Expand All @@ -294,13 +310,10 @@ def _update_lines(self) -> None:
os.path.isdir(os.path.join(self._root, name))
and name not in self._lines
):
self._lines[name] = ModelLine(
os.path.join(self._root, name),
model_cls=self._model_cls
if isinstance(self._model_cls, type)
else self._model_cls[name],
meta_fmt=self._meta_fmt,
)
self._lines[name] = {
"args": [],
"kwargs": dict()
}


class ModelRepoConcatenator(Repo):
Expand Down
6 changes: 3 additions & 3 deletions cascade/tests/test_model_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def test_failed_line_meta(tmp_path, ext):
f.write("\t{{{: 'sorry, i am broken'")

repo = ModelRepo(
repo_path, lines=[dict(name="0", model_cls=DummyModel, meta_fmt=ext)]
repo_path, lines=[dict(name="0", model_cls=DummyModel)], meta_fmt=ext
)
model = repo["0"][0]

Expand Down Expand Up @@ -299,8 +299,8 @@ def test_integer_indices(tmp_path, ext):
first_line = repo.add_line("a")
last_line = repo.add_line("b")

assert first_line == repo[0]
assert last_line == repo[-1]
assert first_line.get_root() == repo[0].get_root()
assert last_line.get_root() == repo[-1].get_root()


@pytest.mark.parametrize("ext", [".json", ".yml", ".yaml"])
Expand Down