diff --git a/cascade/base/__init__.py b/cascade/base/__init__.py index 0b4a2b7c..4248412a 100644 --- a/cascade/base/__init__.py +++ b/cascade/base/__init__.py @@ -1,2 +1,3 @@ from .meta_handler import MetaHandler from .traceable import Traceable +from .meta_handler import CustomEncoder as JSONEncoder diff --git a/cascade/base/meta_handler.py b/cascade/base/meta_handler.py index 521d3503..a1222148 100644 --- a/cascade/base/meta_handler.py +++ b/cascade/base/meta_handler.py @@ -16,9 +16,12 @@ import os import json +from typing import Union, Dict, List import datetime +from typing import List, Dict from json import JSONEncoder +import yaml import numpy as np @@ -28,13 +31,13 @@ def default(self, obj): return str(obj) if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)): return obj.isoformat() + elif isinstance(obj, datetime.timedelta): return (datetime.datetime.min + obj).time().isoformat() elif isinstance(obj, (np.int_, np.intc, np.intp, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64)): - return int(obj) elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): @@ -46,29 +49,39 @@ def default(self, obj): elif isinstance(obj, (np.ndarray,)): return obj.tolist() - elif isinstance(obj, (np.bool_)): + elif isinstance(obj, np.bool_): return bool(obj) - elif isinstance(obj, (np.void)): + elif isinstance(obj, np.void): return None return super(CustomEncoder, self).default(obj) + def obj_to_dict(self, obj): + return json.loads(self.encode(obj)) -class MetaHandler: + +class BaseHandler: + def read(self, path) -> List[Dict]: + raise NotImplementedError() + + def write(self, path, obj, overwrite=True) -> None: + raise NotImplementedError() + + +class JSONHandler(BaseHandler): """ Handles the logic of dumping and loading json files """ - def read(self, path) -> dict: + def read(self, path) -> Union[Dict, List[Dict]]: """ Reads json from path Parameters ---------- path: - Path to the file. If no extension provided, then .json assumed + Path to the file. If no extension provided, then .json will be added """ - assert os.path.exists(path) _, ext = os.path.splitext(path) if ext == '': path += '.json' @@ -79,16 +92,77 @@ def read(self, path) -> dict: meta = json.loads(meta) return meta - def write(self, name, obj, overwrite=True) -> None: + def write(self, name, obj:List[Dict], overwrite=True) -> None: """ Writes json to path using custom encoder """ - if not overwrite and os.path.exists(name): return - with open(name, 'w') as json_meta: - json.dump(obj, json_meta, cls=CustomEncoder, indent=4) + with open(name, 'w') as f: + json.dump(obj, f, cls=CustomEncoder, indent=4) + + +class YAMLHandler(BaseHandler): + def read(self, path) -> Union[Dict, List[Dict]]: + """ + Reads yaml from path + + Parameters + ---------- + path: + Path to the file. If no extension provided, then .yml will be added + """ + _, ext = os.path.splitext(path) + if ext == '': + path += '.yml' + + with open(path, 'r') as meta_file: + meta = yaml.safe_load(meta_file) + return meta + + def write(self, path, obj, overwrite=True) -> None: + if not overwrite and os.path.exists(path): + return + + obj = CustomEncoder().obj_to_dict(obj) + with open(path, 'w') as f: + yaml.safe_dump(obj, f) - def encode(self, obj): - return CustomEncoder().encode(obj) + +class TextHandler(BaseHandler): + def read(self, path) -> Dict: + """ + Reads text file from path and returns dict in the form {path: 'text from file'} + + Parameters + ---------- + path: + Path to the file + """ + + with open(path, 'r') as meta_file: + meta = {path: ''.join(meta_file.readlines())} + return meta + + def write(self, path, obj, overwrite=True) -> None: + raise NotImplementedError('MetaHandler does not write text files, only reads') + + +class MetaHandler: + def read(self, path) -> List[Dict]: + handler = self._get_handler(path) + return handler.read(path) + + def write(self, path, obj, overwrite=True) -> None: + handler = self._get_handler(path) + return handler.write(path, obj, overwrite=overwrite) + + def _get_handler(self, path) -> BaseHandler: + ext = os.path.splitext(path)[-1] + if ext == '.json': + return JSONHandler() + elif ext == '.yml': + return YAMLHandler() + else: + return TextHandler() diff --git a/cascade/meta/meta_viewer.py b/cascade/meta/meta_viewer.py index de623d6b..507d87c1 100644 --- a/cascade/meta/meta_viewer.py +++ b/cascade/meta/meta_viewer.py @@ -15,8 +15,8 @@ """ import os -import json -from ..base import MetaHandler +from typing import List, Dict, Union +from ..base import MetaHandler, JSONEncoder class MetaViewer: @@ -50,11 +50,11 @@ def __init__(self, root, filt=None) -> None: if filt is not None: self.metas = list(filter(self._filter, self.metas)) - def __getitem__(self, index) -> dict: + def __getitem__(self, index) -> List[Dict]: """ Returns ------- - meta: dict + meta: List[Dict] object containing meta """ return self.metas[index] @@ -86,14 +86,14 @@ def pretty(d, indent=0, sep=' '): out += pretty(meta, 4) return out - def write(self, name, obj: dict) -> None: - """ - Dumps obj to name + def write(self, name, obj: List[Dict]) -> None: """ + Dumps obj to name + """ self.metas.append(obj) self.mh.write(name, obj) - def read(self, path) -> dict: + def read(self, path) -> List[Dict]: """ Loads object from path """ @@ -110,4 +110,4 @@ def _filter(self, meta): return True def obj_to_dict(self, obj): - return json.loads(self.mh.encode(obj)) + return JSONEncoder().obj_to_dict(obj) diff --git a/cascade/models/model_line.py b/cascade/models/model_line.py index 6a16d010..88d8feeb 100644 --- a/cascade/models/model_line.py +++ b/cascade/models/model_line.py @@ -31,7 +31,7 @@ class ModelLine(Traceable): A line of models is typically a models with the same hyperparameters and architecture, but different epochs or using different data. """ - def __init__(self, folder, model_cls=Model, **kwargs) -> None: + def __init__(self, folder, model_cls=Model, meta_fmt='.json', **kwargs) -> None: """ All models in line should be instances of the same class. @@ -42,12 +42,16 @@ def __init__(self, folder, model_cls=Model, **kwargs) -> None: if folder does not exist, creates it model_cls: A class of models in repo. ModelLine uses this class to reconstruct a model + meta_fmt: + Format in which to store meta data. '.json', '.yml' are supported. .json is default. See also -------- cascade.models.ModelRepo """ super().__init__(**kwargs) + assert meta_fmt in ['.json', '.yml'], 'Only .json or .yml are supported formats' + self.meta_fmt = meta_fmt self.model_cls = model_cls self.root = folder self.model_names = [] @@ -122,10 +126,10 @@ def save(self, model: Model) -> None: meta[-1]['saved_at'] = pendulum.now(tz='UTC') # Save model's meta - self.meta_viewer.write(os.path.join(self.root, folder_name, 'meta.json'), meta) + self.meta_viewer.write(os.path.join(self.root, folder_name, 'meta' + self.meta_fmt), meta) # Save updated line's meta - self.meta_viewer.write(os.path.join(self.root, 'meta.json'), self.get_meta()) + self.meta_viewer.write(os.path.join(self.root, 'meta' + self.meta_fmt), self.get_meta()) def __repr__(self) -> str: return f'ModelLine of {len(self)} models of {self.model_cls}' diff --git a/cascade/models/model_repo.py b/cascade/models/model_repo.py index f920a520..dd873893 100644 --- a/cascade/models/model_repo.py +++ b/cascade/models/model_repo.py @@ -95,7 +95,7 @@ def __init__(self, folder, lines=None, overwrite=False, **kwargs): self._update_meta() - def add_line(self, name, model_cls): + def add_line(self, name, model_cls, **kwargs): """ Adds new line to repo if it doesn't exist and returns it If line exists, defines it in repo @@ -111,7 +111,7 @@ def add_line(self, name, model_cls): assert type(model_cls) == type, f'You should pass model\'s class, not {type(model_cls)}' folder = os.path.join(self.root, name) - line = ModelLine(folder, model_cls=model_cls, meta_prefix=self.meta_prefix) + line = ModelLine(folder, model_cls=model_cls, meta_prefix=self.meta_prefix, **kwargs) self.lines[name] = line self._update_meta() diff --git a/cascade/tests/conftest.py b/cascade/tests/conftest.py index 5ec4f862..9de78fd0 100644 --- a/cascade/tests/conftest.py +++ b/cascade/tests/conftest.py @@ -17,8 +17,9 @@ import os import sys -import random -import shutil +import datetime +import pendulum +from dateutil import tz import numpy as np import pytest @@ -26,7 +27,7 @@ sys.path.append(os.path.dirname(MODULE_PATH)) from cascade.data import Wrapper, Iterator -from cascade.models import Model, ModelRepo, BasicModel +from cascade.models import Model, ModelLine, ModelRepo, BasicModel class DummyModel(Model): @@ -99,14 +100,30 @@ def number_iterator(request): {'a': 0}, {'b': 1}, {'a': 0, 'b': 'alala'}, - {'c': np.array([1, 2]), 'd': {'a': 0}}]) + {'c': np.array([1, 2]), 'd': {'a': 0}}, + {'e': datetime.datetime(2022, 7, 8, 16, 4, 3, 5, tz.gettz('Europe / Moscow'))}, + {'f': pendulum.datetime(2022, 7, 8, 16, 4, 3, 5, 'Europe/Moscow')}]) def dummy_model(request): - return DummyModel(**request.param) + return DummyModel(**request.param) @pytest.fixture def empty_model(): - return EmptyModel() + return EmptyModel() + + +@pytest.fixture(params=[ + { + 'model_cls': DummyModel, + 'meta_fmt': '.json' + }, + { + 'model_cls': DummyModel, + 'meta_fmt': '.yml' + }]) +def model_line(request, tmp_path): + line = ModelLine(str(tmp_path), **request.param) + return line @pytest.fixture diff --git a/cascade/tests/test_bruteforce_cacher.py b/cascade/tests/test_bruteforce_cacher.py index e9d50380..cfd7596a 100644 --- a/cascade/tests/test_bruteforce_cacher.py +++ b/cascade/tests/test_bruteforce_cacher.py @@ -25,14 +25,12 @@ def test_ds(number_dataset): ds = BruteforceCacher(number_dataset) - assert([number_dataset[i] for i in range(len(number_dataset))] \ - == [item for item in ds]) + assert([number_dataset[i] for i in range(len(number_dataset))] == [item for item in ds]) def test_it(number_iterator): ds = BruteforceCacher(number_iterator) - assert([item for item in number_iterator] \ - == [item for item in ds]) + assert([item for item in number_iterator] == [item for item in ds]) def test_meta(): diff --git a/cascade/tests/test_meta_handler.py b/cascade/tests/test_meta_handler.py index f86a417d..eb759126 100644 --- a/cascade/tests/test_meta_handler.py +++ b/cascade/tests/test_meta_handler.py @@ -18,6 +18,7 @@ import sys import pendulum import numpy as np +import pytest MODULE_PATH = os.path.dirname(os.path.abspath(os.path.dirname(__file__))) sys.path.append(os.path.dirname(MODULE_PATH)) @@ -25,10 +26,16 @@ from cascade.base import MetaHandler -def test(tmp_path): +@pytest.mark.parametrize( + 'ext', [ + '.json', + '.yml' + ] +) +def test(tmp_path, ext): tmp_path = str(tmp_path) mh = MetaHandler() - mh.write(os.path.join(tmp_path, 'meta.json'), + mh.write(os.path.join(tmp_path, 'meta' + ext), { 'name': 'test_mh', 'array': np.zeros(4), @@ -36,22 +43,28 @@ def test(tmp_path): 'date': pendulum.now(tz='UTC') }) - obj = mh.read(os.path.join(tmp_path, 'meta.json')) + obj = mh.read(os.path.join(tmp_path, 'meta' + ext)) assert(obj['name'] == 'test_mh') - assert(all(obj['array'] == np.zeros(4))) + assert(obj['array'] == [0, 0, 0, 0]) assert(obj['none'] is None) -def test_overwrite(tmp_path): - tmp_path = os.path.join(str(tmp_path), 'test_mh_ow.json') +@pytest.mark.parametrize( + 'ext', [ + '.json', + '.yml' + ] +) +def test_overwrite(tmp_path, ext): + tmp_path = os.path.join(str(tmp_path), 'test_mh_ow' + ext) mh = MetaHandler() mh.write( tmp_path, {'name': 'first'}, overwrite=False) - + mh.write( tmp_path, {'name': 'second'}, @@ -59,3 +72,22 @@ def test_overwrite(tmp_path): obj = mh.read(tmp_path) assert(obj['name'] == 'first') + + +@pytest.mark.parametrize( + 'ext', [ + '.txt', + '.md' + ] +) +def test_text(tmp_path, ext): + tmp_path = str(os.path.join(tmp_path, 'meta' + ext)) + mh = MetaHandler() + + info = '#Meta\n\n\n this is object for testing text files' + with open(tmp_path, 'w') as f: + f.write(info) + + obj = mh.read(tmp_path) + + assert(obj[tmp_path] == info) diff --git a/cascade/tests/test_meta_viewer.py b/cascade/tests/test_meta_viewer.py index f179fe06..de1654d7 100644 --- a/cascade/tests/test_meta_viewer.py +++ b/cascade/tests/test_meta_viewer.py @@ -32,7 +32,7 @@ def test(tmp_path): json.dump({'name': 'test0'}, f) # Write also using mv - mv = MetaViewer() + mv = MetaViewer(tmp_path) mv.write(os.path.join(tmp_path, 'model', 'test.json'), {'name': 'test1'}) mv = MetaViewer(tmp_path) diff --git a/cascade/tests/test_model_line.py b/cascade/tests/test_model_line.py index 158b5f80..ff8d0a15 100644 --- a/cascade/tests/test_model_line.py +++ b/cascade/tests/test_model_line.py @@ -24,23 +24,17 @@ from cascade.models.model_repo import ModelLine -def test_save_load(tmp_path): - tmp_path = str(tmp_path) - line = ModelLine(tmp_path, DummyModel) - m = DummyModel() - line.save(m) - model = line[0] - - assert(len(line) == 1) +def test_save_load(model_line, dummy_model): + model_line.save(dummy_model) + model = model_line[0] + + assert(len(model_line) == 1) assert(model.model == "b'model'") -def test_meta(tmp_path): - tmp_path = str(tmp_path) - line = ModelLine(tmp_path, DummyModel) - m = DummyModel() - line.save(m) - meta = line.get_meta() +def test_meta(model_line, dummy_model): + model_line.save(dummy_model) + meta = model_line.get_meta() assert(meta[0]['model_cls'] == repr(DummyModel)) assert(meta[0]['len'] == 1) diff --git a/requirements.txt b/requirements.txt index 179ffa70..0aaed753 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ deepdiff pendulum plotly flatten_json +pyyaml diff --git a/setup.py b/setup.py index 8f098894..011b3089 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ 'deepdiff', 'pendulum', 'plotly', - 'flatten_json' + 'flatten_json', + 'pyyaml' ] )