Skip to content

Commit

Permalink
Merge pull request #84 from Oxid15/refactor
Browse files Browse the repository at this point in the history
Refactor - make some fields protected
  • Loading branch information
Oxid15 authored Aug 2, 2022
2 parents 12d970a + 3fd05c9 commit bed2fb8
Show file tree
Hide file tree
Showing 23 changed files with 167 additions and 162 deletions.
14 changes: 7 additions & 7 deletions cascade/base/traceable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self, *args, meta_prefix=None, **kwargs) -> None:
meta_prefix = {}
elif isinstance(meta_prefix, str):
meta_prefix = self._read_meta_from_file(meta_prefix)
self.meta_prefix = meta_prefix
self._meta_prefix = meta_prefix

def _read_meta_from_file(self, path: str) -> Union[List[Dict], Dict]:
from . import MetaHandler
Expand All @@ -26,23 +26,23 @@ def get_meta(self) -> List[Dict]:
meta = {
'name': repr(self)
}
if hasattr(self, 'meta_prefix'):
meta.update(self.meta_prefix)
if hasattr(self, '_meta_prefix'):
meta.update(self._meta_prefix)
else:
self._warn_no_prefix()
return [meta]

def update_meta(self, obj: Union[Dict, str]) -> None:
"""
Updates meta_prefix, which is then updates dataset's meta when get_meta() is called
Updates _meta_prefix, which is then updates dataset's meta when get_meta() is called
"""
if isinstance(obj, str):
obj = self._read_meta_from_file(obj)

if hasattr(self, 'meta_prefix'):
self.meta_prefix.update(obj)
if hasattr(self, '_meta_prefix'):
self._meta_prefix.update(obj)
else:
self._warn_no_prefix()

def _warn_no_prefix(self):
warnings.warn('Object doesn\'t have meta_prefix. This may mean super().__init__() wasn\'t called somewhere')
warnings.warn('Object doesn\'t have _meta_prefix. This may mean super().__init__() wasn\'t called somewhere')
6 changes: 3 additions & 3 deletions cascade/data/apply_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ def __init__(self, dataset: Dataset, func: Callable, *args, **kwargs) -> None:
each `__getitem__` would call `func` on an item obtained from a previous dataset
"""
super().__init__(dataset, *args, **kwargs)
self.func = func
self._func = func

def __getitem__(self, index: int) -> T:
item = self._dataset[index]
return self.func(item)
return self._func(item)

def __repr__(self) -> str:
rp = super().__repr__()
return f'{rp}, {repr(self.func)}'
return f'{rp}, {repr(self._func)}'
6 changes: 3 additions & 3 deletions cascade/data/concatenator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ def __init__(self, datasets: Iterable[Dataset], *args, **kwargs) -> None:
"""
self._datasets = datasets
lengths = [len(ds) for ds in self._datasets]
self.shifts = np.cumsum([0] + lengths)
self._shifts = np.cumsum([0] + lengths)
super().__init__(*args, **kwargs)

def __getitem__(self, index) -> T:
ds_index = 0
for sh in self.shifts[1:]:
for sh in self._shifts[1:]:
if index >= sh:
ds_index += 1
return self._datasets[ds_index][index - self.shifts[ds_index]]
return self._datasets[ds_index][index - self._shifts[ds_index]]

def __len__(self) -> int:
"""
Expand Down
15 changes: 8 additions & 7 deletions cascade/data/folder_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ class FolderDataset(Dataset):
"""
def __init__(self, root, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.root = os.path.abspath(root)
assert os.path.exists(self.root)
self.names = [os.path.join(self.root, name) for name in sorted(os.listdir(self.root)) if not os.path.isdir(name)]
self._root = os.path.abspath(root)
assert os.path.exists(self._root)
self._names = [os.path.join(self._root, name)
for name in sorted(os.listdir(self._root)) if not os.path.isdir(name)]

def __getitem__(self, item) -> T:
raise NotImplementedError()
Expand All @@ -28,14 +29,14 @@ def get_meta(self) -> List[Dict]:
meta[0].update({
'name': repr(self),
'len': len(self),
'paths': self.names,
'paths': self._names,
'md5sums': []
})

for name in self.names:
with open(os.path.join(self.root, name), 'rb') as f:
for name in self._names:
with open(os.path.join(self._root, name), 'rb') as f:
meta[0]['md5sums'].append(md5(f.read()).hexdigest())
return meta

def __len__(self):
return len(self.names)
return len(self._names)
8 changes: 4 additions & 4 deletions cascade/data/pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,18 @@ def __init__(self, path, dataset=None, *args, **kwargs) -> None:
a dataset to be pickled
"""
super().__init__(dataset, *args, **kwargs)
self.path = path
self._path = path

if self._dataset is None:
assert os.path.exists(self.path)
assert os.path.exists(self._path)
self._load()
else:
self._dump()

def _dump(self) -> None:
with open(self.path, 'wb') as f:
with open(self._path, 'wb') as f:
pickle.dump(self._dataset, f)

def _load(self) -> None:
with open(self.path, 'rb') as f:
with open(self._path, 'rb') as f:
self._dataset = pickle.load(f)
8 changes: 4 additions & 4 deletions cascade/data/random_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def __init__(self, dataset: Dataset, num_samples=None, **kwargs) -> None:
if num_samples is None:
num_samples = len(dataset)
super().__init__(dataset, num_samples, **kwargs)
self.indices = [i for i in range(len(dataset))]
shuffle(self.indices)
self.indices = self.indices[:num_samples]
self._indices = [i for i in range(len(dataset))]
shuffle(self._indices)
self._indices = self._indices[:num_samples]

def __getitem__(self, index):
return super().__getitem__(self.indices[index])
return super().__getitem__(self._indices[index])
28 changes: 14 additions & 14 deletions cascade/data/sequential_cacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,28 @@ def __init__(self, dataset: Dataset, batch_size=2, *args, **kwargs) -> None:
# TODO: make something to release this assert
assert hasattr(dataset, '__len__'), 'Dataset should have __len__'
super().__init__(dataset, *args, **kwargs)
self.bs = batch_size
self.num_batches = int(ceil(len(self._dataset) / self.bs))
self.index = -1
self.batch = None
self._bs = batch_size
self._num_batches = int(ceil(len(self._dataset) / self._bs))
self._index = -1
self._batch = None

def _load(self, index) -> None:
del self.batch
self.batch = []
del self._batch
self._batch = []

start = index * self.bs
end = min(start + self.bs, len(self._dataset))
start = index * self._bs
end = min(start + self._bs, len(self._dataset))

for i in range(start, end):
self.batch.append(self._dataset[i])
self._batch.append(self._dataset[i])

self.index += 1
self._index += 1

def __getitem__(self, index) -> T:
batch_index = index // self.bs
in_batch_idx = index % self.bs
batch_index = index // self._bs
in_batch_idx = index % self._bs

if batch_index != self.index:
if batch_index != self._index:
self._load(batch_index)

return self.batch[in_batch_idx]
return self._batch[in_batch_idx]
30 changes: 15 additions & 15 deletions cascade/meta/history_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ class HistoryViewer:
models with different hyperparameters depend on each other
"""
def __init__(self, repo) -> None:
self.repo = repo
self._repo = repo

metas = []
self.params = []
for line in self.repo:
self._params = []
for line in self._repo:
# Try to use viewer only on models using type key
try:
view = MetaViewer(line.root, filt={'type': 'model'})
Expand All @@ -64,11 +64,11 @@ def __init__(self, repo) -> None:
if 'params' in view[i][-1]:
if len(view[i][-1]['params']) > 0:
params.update(flatten({'params': view[i][-1]['params']}))
self.params.append(params)
self._params.append(params)

self.table = pd.DataFrame(metas)
if 'saved_at' in self.table:
self.table = self.table.sort_values('saved_at')
self._table = pd.DataFrame(metas)
if 'saved_at' in self._table:
self._table = self._table.sort_values('saved_at')

def _diff(self, p1, params) -> List:
diff = [DeepDiff(p1, p2) for p2 in params]
Expand Down Expand Up @@ -102,19 +102,19 @@ def plot(self, metric: str) -> None:
# After flatten 'metrics_' will be added to the metric name
if not metric.startswith('metrics_'):
metric = 'metrics_' + metric
assert metric in self.table
assert metric in self._table

# turn time into evenly spaced intervals
time = [i for i in range(len(self.table))]
lines = self.table['line'].unique()
time = [i for i in range(len(self._table))]
lines = self._table['line'].unique()

cmap = px.colors.qualitative.Plotly
cmap_len = len(px.colors.qualitative.Plotly)
line_cols = {line: cmap[i % cmap_len] for i, line in enumerate(lines)}

self.table['time'] = time
self.table['color'] = [line_cols[line] for line in self.table['line']]
table = self.table.fillna('')
self._table['time'] = time
self._table['color'] = [line_cols[line] for line in self._table['line']]
table = self._table.fillna('')

# plot each model against metric
# with all metadata on hover
Expand All @@ -123,15 +123,15 @@ def plot(self, metric: str) -> None:
table,
x='time',
y=metric,
hover_data=[name for name in pd.DataFrame(self.params).columns],
hover_data=[name for name in pd.DataFrame(self._params).columns],
color='line'
)

# determine connections between models
# plot each one with respected color

for line in lines:
params = [p for p in self.params if p['line'] == line]
params = [p for p in self._params if p['line'] == line]
edges = []
for i in range(len(params)):
if i == 0:
Expand Down
10 changes: 5 additions & 5 deletions cascade/meta/meta_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,16 @@ def __init__(self, dataset: Dataset, root=None) -> None:
default is './.cascade'
"""
super().__init__(dataset, lambda x: True)
self.mh = MetaHandler()
self._mh = MetaHandler()
if root is None:
root = './.cascade'
os.makedirs(root, exist_ok=True)
self.root = root
self._root = root

meta = self._dataset.get_meta()
name = md5(str.encode(' '.join([m['name'] for m in meta]), 'utf-8')).hexdigest()
name += '.json'
name = os.path.join(self.root, name)
name = os.path.join(self._root, name)

if os.path.exists(name):
self.base_meta = self._load(name)
Expand All @@ -83,11 +83,11 @@ def __init__(self, dataset: Dataset, root=None) -> None:
self._save(meta, name)

def _save(self, meta, name) -> None:
self.mh.write(name, meta)
self._mh.write(name, meta)
print(f'Saved as {name}!')

def _load(self, name) -> dict:
return self.mh.read(name)
return self._mh.read(name)

def _check(self, query_meta):
diff = DeepDiff(self.base_meta, query_meta, verbose_level=2)
Expand Down
20 changes: 10 additions & 10 deletions cascade/meta/meta_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@ def __init__(self, root, filt=None) -> None:
cascade.meta.MetaHandler
"""
assert os.path.exists(root)
self.root = root
self.filt = filt
self.mh = MetaHandler()
self._root = root
self._filt = filt
self._mh = MetaHandler()

names = []
for root, _, files in os.walk(self.root):
for root, _, files in os.walk(self._root):
names += [os.path.join(root, name) for name in files if os.path.splitext(name)[-1] == '.json']
names = sorted(names)

self.metas = []
for name in names:
self.metas.append(self.mh.read(name))
self.metas.append(self._mh.read(name))
if filt is not None:
self.metas = list(filter(self._filter, self.metas))

Expand Down Expand Up @@ -83,7 +83,7 @@ def pretty(d, indent=0, sep=' '):
out += '\n'
return out

out = f'MetaViewer at {self.root}:\n'
out = f'MetaViewer at {self._root}:\n'
for i, meta in enumerate(self.metas):
out += '-' * 20 + '\n'
out += f' Meta {i}:\n'
Expand All @@ -96,21 +96,21 @@ def write(self, name, obj: List[Dict]) -> None:
Dumps obj to name
"""
self.metas.append(obj)
self.mh.write(name, obj)
self._mh.write(name, obj)

def read(self, path) -> List[Dict]:
"""
Loads object from path
"""
return self.mh.read(path)
return self._mh.read(path)

def _filter(self, meta):
meta = meta[-1] # Takes last meta
for key in self.filt:
for key in self._filt:
if key not in meta:
raise KeyError(f"'{key}' key is not in\n{meta}")

if self.filt[key] != meta[key]:
if self._filt[key] != meta[key]:
return False
return True

Expand Down
12 changes: 6 additions & 6 deletions cascade/meta/metric_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ def __init__(self, repo) -> None:
repo: ModelRepo
ModelRepo object to extract metrics from
"""
self.repo = repo
self._repo = repo

self.metrics = []
for line in self.repo:
self._metrics = []
for line in self._repo:
viewer_root = line.root

# Try to use viewer only on models using type key
Expand Down Expand Up @@ -77,8 +77,8 @@ def __init__(self, repo) -> None:
if 'params' in meta:
metric.update(meta['params'])

self.metrics.append(metric)
self.table = pd.DataFrame(self.metrics)
self._metrics.append(metric)
self.table = pd.DataFrame(self._metrics)

def __repr__(self) -> str:
return repr(self.table)
Expand Down Expand Up @@ -137,7 +137,7 @@ def serve(self, page_size=50, include=None, exclude=None, **kwargs) -> None:

app.layout = html.Div([
html.H1(
children=f'MetricViewer in {self.repo.root}',
children=f'MetricViewer in {self._repo.root}',
style={
'textAlign': 'center',
'color': '#084c61',
Expand Down
Loading

0 comments on commit bed2fb8

Please sign in to comment.