Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abstract from json #66

Merged
merged 15 commits into from
Jul 29, 2022
1 change: 1 addition & 0 deletions cascade/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .meta_handler import MetaHandler
from .traceable import Traceable
from .meta_handler import CustomEncoder as JSONEncoder
100 changes: 87 additions & 13 deletions cascade/base/meta_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@

import os
import json
from typing import Union, Dict, List
import datetime
from typing import List, Dict
from json import JSONEncoder

import yaml
import numpy as np


Expand All @@ -28,13 +31,13 @@ def default(self, obj):
return str(obj)
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
return obj.isoformat()

elif isinstance(obj, datetime.timedelta):
return (datetime.datetime.min + obj).time().isoformat()

elif isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
np.int16, np.int32, np.int64, np.uint8,
np.uint16, np.uint32, np.uint64)):

return int(obj)

elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
Expand All @@ -46,29 +49,39 @@ def default(self, obj):
elif isinstance(obj, (np.ndarray,)):
return obj.tolist()

elif isinstance(obj, (np.bool_)):
elif isinstance(obj, np.bool_):
return bool(obj)

elif isinstance(obj, (np.void)):
elif isinstance(obj, np.void):
return None

return super(CustomEncoder, self).default(obj)

def obj_to_dict(self, obj):
return json.loads(self.encode(obj))

class MetaHandler:

class BaseHandler:
def read(self, path) -> List[Dict]:
raise NotImplementedError()

def write(self, path, obj, overwrite=True) -> None:
raise NotImplementedError()


class JSONHandler(BaseHandler):
"""
Handles the logic of dumping and loading json files
"""
def read(self, path) -> dict:
def read(self, path) -> Union[Dict, List[Dict]]:
"""
Reads json from path

Parameters
----------
path:
Path to the file. If no extension provided, then .json assumed
Path to the file. If no extension provided, then .json will be added
"""
assert os.path.exists(path)
_, ext = os.path.splitext(path)
if ext == '':
path += '.json'
Expand All @@ -79,16 +92,77 @@ def read(self, path) -> dict:
meta = json.loads(meta)
return meta

def write(self, name, obj, overwrite=True) -> None:
def write(self, name, obj:List[Dict], overwrite=True) -> None:
"""
Writes json to path using custom encoder
"""

if not overwrite and os.path.exists(name):
return

with open(name, 'w') as json_meta:
json.dump(obj, json_meta, cls=CustomEncoder, indent=4)
with open(name, 'w') as f:
json.dump(obj, f, cls=CustomEncoder, indent=4)


class YAMLHandler(BaseHandler):
def read(self, path) -> Union[Dict, List[Dict]]:
"""
Reads yaml from path

Parameters
----------
path:
Path to the file. If no extension provided, then .yml will be added
"""
_, ext = os.path.splitext(path)
if ext == '':
path += '.yml'

with open(path, 'r') as meta_file:
meta = yaml.safe_load(meta_file)
return meta

def write(self, path, obj, overwrite=True) -> None:
if not overwrite and os.path.exists(path):
return

obj = CustomEncoder().obj_to_dict(obj)
with open(path, 'w') as f:
yaml.safe_dump(obj, f)

def encode(self, obj):
return CustomEncoder().encode(obj)

class TextHandler(BaseHandler):
def read(self, path) -> Dict:
"""
Reads text file from path and returns dict in the form {path: 'text from file'}

Parameters
----------
path:
Path to the file
"""

with open(path, 'r') as meta_file:
meta = {path: ''.join(meta_file.readlines())}
return meta

def write(self, path, obj, overwrite=True) -> None:
raise NotImplementedError('MetaHandler does not write text files, only reads')


class MetaHandler:
def read(self, path) -> List[Dict]:
handler = self._get_handler(path)
return handler.read(path)

def write(self, path, obj, overwrite=True) -> None:
handler = self._get_handler(path)
return handler.write(path, obj, overwrite=overwrite)

def _get_handler(self, path) -> BaseHandler:
ext = os.path.splitext(path)[-1]
if ext == '.json':
return JSONHandler()
elif ext == '.yml':
return YAMLHandler()
else:
return TextHandler()
18 changes: 9 additions & 9 deletions cascade/meta/meta_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
"""

import os
import json
from ..base import MetaHandler
from typing import List, Dict, Union
from ..base import MetaHandler, JSONEncoder


class MetaViewer:
Expand Down Expand Up @@ -50,11 +50,11 @@ def __init__(self, root, filt=None) -> None:
if filt is not None:
self.metas = list(filter(self._filter, self.metas))

def __getitem__(self, index) -> dict:
def __getitem__(self, index) -> List[Dict]:
"""
Returns
-------
meta: dict
meta: List[Dict]
object containing meta
"""
return self.metas[index]
Expand Down Expand Up @@ -86,14 +86,14 @@ def pretty(d, indent=0, sep=' '):
out += pretty(meta, 4)
return out

def write(self, name, obj: dict) -> None:
"""
Dumps obj to name
def write(self, name, obj: List[Dict]) -> None:
"""
Dumps obj to name
"""
self.metas.append(obj)
self.mh.write(name, obj)

def read(self, path) -> dict:
def read(self, path) -> List[Dict]:
"""
Loads object from path
"""
Expand All @@ -110,4 +110,4 @@ def _filter(self, meta):
return True

def obj_to_dict(self, obj):
return json.loads(self.mh.encode(obj))
return JSONEncoder().obj_to_dict(obj)
10 changes: 7 additions & 3 deletions cascade/models/model_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ModelLine(Traceable):
A line of models is typically a models with the same hyperparameters and architecture,
but different epochs or using different data.
"""
def __init__(self, folder, model_cls=Model, **kwargs) -> None:
def __init__(self, folder, model_cls=Model, meta_fmt='.json', **kwargs) -> None:
"""
All models in line should be instances of the same class.

Expand All @@ -42,12 +42,16 @@ def __init__(self, folder, model_cls=Model, **kwargs) -> None:
if folder does not exist, creates it
model_cls:
A class of models in repo. ModelLine uses this class to reconstruct a model
meta_fmt:
Format in which to store meta data. '.json', '.yml' are supported. .json is default.
See also
--------
cascade.models.ModelRepo
"""
super().__init__(**kwargs)

assert meta_fmt in ['.json', '.yml'], 'Only .json or .yml are supported formats'
self.meta_fmt = meta_fmt
self.model_cls = model_cls
self.root = folder
self.model_names = []
Expand Down Expand Up @@ -122,10 +126,10 @@ def save(self, model: Model) -> None:
meta[-1]['saved_at'] = pendulum.now(tz='UTC')

# Save model's meta
self.meta_viewer.write(os.path.join(self.root, folder_name, 'meta.json'), meta)
self.meta_viewer.write(os.path.join(self.root, folder_name, 'meta' + self.meta_fmt), meta)

# Save updated line's meta
self.meta_viewer.write(os.path.join(self.root, 'meta.json'), self.get_meta())
self.meta_viewer.write(os.path.join(self.root, 'meta' + self.meta_fmt), self.get_meta())

def __repr__(self) -> str:
return f'ModelLine of {len(self)} models of {self.model_cls}'
Expand Down
4 changes: 2 additions & 2 deletions cascade/models/model_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(self, folder, lines=None, overwrite=False, **kwargs):

self._update_meta()

def add_line(self, name, model_cls):
def add_line(self, name, model_cls, **kwargs):
"""
Adds new line to repo if it doesn't exist and returns it
If line exists, defines it in repo
Expand All @@ -111,7 +111,7 @@ def add_line(self, name, model_cls):
assert type(model_cls) == type, f'You should pass model\'s class, not {type(model_cls)}'

folder = os.path.join(self.root, name)
line = ModelLine(folder, model_cls=model_cls, meta_prefix=self.meta_prefix)
line = ModelLine(folder, model_cls=model_cls, meta_prefix=self.meta_prefix, **kwargs)
self.lines[name] = line

self._update_meta()
Expand Down
29 changes: 23 additions & 6 deletions cascade/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@

import os
import sys
import random
import shutil
import datetime
import pendulum
from dateutil import tz
import numpy as np
import pytest

MODULE_PATH = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
sys.path.append(os.path.dirname(MODULE_PATH))

from cascade.data import Wrapper, Iterator
from cascade.models import Model, ModelRepo, BasicModel
from cascade.models import Model, ModelLine, ModelRepo, BasicModel


class DummyModel(Model):
Expand Down Expand Up @@ -99,14 +100,30 @@ def number_iterator(request):
{'a': 0},
{'b': 1},
{'a': 0, 'b': 'alala'},
{'c': np.array([1, 2]), 'd': {'a': 0}}])
{'c': np.array([1, 2]), 'd': {'a': 0}},
{'e': datetime.datetime(2022, 7, 8, 16, 4, 3, 5, tz.gettz('Europe / Moscow'))},
{'f': pendulum.datetime(2022, 7, 8, 16, 4, 3, 5, 'Europe/Moscow')}])
def dummy_model(request):
return DummyModel(**request.param)
return DummyModel(**request.param)


@pytest.fixture
def empty_model():
return EmptyModel()
return EmptyModel()


@pytest.fixture(params=[
{
'model_cls': DummyModel,
'meta_fmt': '.json'
},
{
'model_cls': DummyModel,
'meta_fmt': '.yml'
}])
def model_line(request, tmp_path):
line = ModelLine(str(tmp_path), **request.param)
return line


@pytest.fixture
Expand Down
6 changes: 2 additions & 4 deletions cascade/tests/test_bruteforce_cacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,12 @@

def test_ds(number_dataset):
ds = BruteforceCacher(number_dataset)
assert([number_dataset[i] for i in range(len(number_dataset))] \
== [item for item in ds])
assert([number_dataset[i] for i in range(len(number_dataset))] == [item for item in ds])


def test_it(number_iterator):
ds = BruteforceCacher(number_iterator)
assert([item for item in number_iterator] \
== [item for item in ds])
assert([item for item in number_iterator] == [item for item in ds])


def test_meta():
Expand Down
Loading