Skip to content

Commit

Permalink
minor change
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Al-Saffar committed May 19, 2024
1 parent 3ad6612 commit ac809f7
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 21 deletions.
69 changes: 49 additions & 20 deletions myresources/crocodile/deeplearning.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,21 @@
dl
"""

import numpy as np
import numpy.typing as npt
import pandas as pd
from tqdm import tqdm

from crocodile.matplotlib_management import ImShow
from crocodile.core import List as L, Struct as S, Base
from crocodile.file_management import P, Save, PLike, Read
from crocodile.meta import Experimental

import numpy as np
import numpy.typing as npt
import pandas as pd
from abc import ABC
from typing import TypeVar, Type, Any, Optional, Union, Callable, Literal, TypeAlias
import enum
from tqdm import tqdm
import copy
from abc import ABC
from dataclasses import dataclass, field
from typing import TypeVar, Type, Any, Optional, Union, Callable, Literal, TypeAlias, Protocol


@dataclass
Expand Down Expand Up @@ -69,10 +70,33 @@ class Device(enum.Enum):
SubclassedBaseModel = TypeVar("SubclassedBaseModel", bound='BaseModel')


HPARAMS_SUBPATH: str = 'metadata/hyperparameters' # location within model directory where this will be saved.
PACKAGE: TypeAlias = Literal['tensorflow', 'torch']
PRECISON = Literal['float64', 'float32', 'float16']


class HyperParams(Protocol):
# ================== General ==============================
name: str
root: P
pkg_name: PACKAGE

# ===================== Data ==============================
seed: int
shuffle: bool
precision: PRECISON

# ===================== Training ==========================
test_split: float
learning_rate: float
batch_size: int
epochs: int


def get_hp_save_dir(hp: HyperParams):
return (P(hp.root) / hp.name).create()


@dataclass
class HParams:
# ===================== Data ==============================
Expand All @@ -96,7 +120,8 @@ class HParams:
subpath: str = 'metadata/hyperparameters' # location within model directory where this will be saved.

def save(self):
subpath = self.subpath
# subpath = self.subpath
subpath = HPARAMS_SUBPATH
save_dir = self.save_dir
self_repr = str(self)

Expand Down Expand Up @@ -137,7 +162,7 @@ class DataReader:
subpath = P("metadata/data_reader")
"""This class holds the dataset for training and testing.
"""
def __init__(self, hp: SubclassedHParams, # type: ignore
def __init__(self, hp: HyperParams, # type: ignore
specs: Optional[Specs] = None,
split: Optional[dict[str, Any]] = None) -> None:
# split could be Union[None, 'npt.NDArray[np.float64]', 'pd.DataFrame', 'pd.Series', 'list[Any]', Tf.RaggedTensor etc.
Expand All @@ -149,13 +174,15 @@ def __init__(self, hp: SubclassedHParams, # type: ignore
# self.df_handler = df_handler
def save(self, path: Optional[str] = None, **kwargs: Any) -> None:
_ = kwargs
base = (P(path) if path is not None else self.hp.save_dir).joinpath(self.subpath).create()
base = (P(path) if path is not None else get_hp_save_dir(self.hp)).joinpath(self.subpath).create()
try: data: dict[str, Any] = self.__getstate__()
except AttributeError: data = self.__dict__
Save.pickle(path=base / "data_reader.DataReader.dat.pkl", obj=data)
Save.pickle(path=base / "data_reader.DataReader.pkl", obj=self)
@classmethod
def from_saved_data(cls, path: Union[str, P], hp: SubclassedHParams, # type: ignore
def from_saved_data(cls, path: Union[str, P],
# hp: SubclassedHParams, # type: ignore
hp: HyperParams,
**kwargs: Any):
path = P(path) / cls.subpath / "data_reader.DataReader.dat.pkl"
data: dict[str, Any] = Read.pickle(path)
Expand Down Expand Up @@ -400,7 +427,7 @@ def fit(self, viz: bool = True, weight_name: Optional[str] = None,
self.history.append(copy.deepcopy(hist.history)) # it is paramount to copy, cause source can change.
if viz:
artist = self.plot_loss()
artist.fig.savefig(str(self.hp.save_dir.joinpath(f"metadata/training/loss_curve.png").append(index=True).create(parents_only=True)))
artist.fig.savefig(str(get_hp_save_dir(self.hp).joinpath(f"metadata/training/loss_curve.png").append(index=True).create(parents_only=True)))
return self

def switch_to_sgd(self, epochs: int = 10):
Expand Down Expand Up @@ -473,7 +500,7 @@ def load_weights(self, directory: PLike) -> None:
self.model.load_weights(path) # .expect_partial()
def summary(self):
from contextlib import redirect_stdout
path = self.hp.save_dir.joinpath("metadata/model/model_summary.txt").create(parents_only=True)
path = get_hp_save_dir(self.hp).joinpath("metadata/model/model_summary.txt").create(parents_only=True)
with open(str(path), 'w', encoding='utf-8') as f:
with redirect_stdout(f): self.model.summary()
return self.model.summary()
Expand Down Expand Up @@ -568,10 +595,10 @@ def save_class(self, weights_only: bool = True, version: str = 'v0', strict: boo
"""
self.hp.save() # goes into the meta path.
self.data.save() # goes into the meta path.
Save.pickle(obj=self.history, path=self.hp.save_dir / 'metadata/training/history.pkl', verbose=True, desc="Training History") # goes into the meta path.
try: Experimental.generate_readme(self.hp.save_dir, obj=self.__class__, desc=desc)
Save.pickle(obj=self.history, path=get_hp_save_dir(self.hp) / 'metadata/training/history.pkl', verbose=True, desc="Training History") # goes into the meta path.
try: Experimental.generate_readme(get_hp_save_dir(self.hp), obj=self.__class__, desc=desc)
except Exception as ex: print(ex) # often fails because model is defined in main during experiments.
save_dir = self.hp.save_dir.joinpath(f'{"weights" if weights_only else "model"}_save_{version}')
save_dir = get_hp_save_dir(self.hp).joinpath(f'{"weights" if weights_only else "model"}_save_{version}')
if weights_only: self.save_weights(save_dir.create())
else:
self.save_model(save_dir)
Expand All @@ -598,12 +625,14 @@ def save_class(self, weights_only: bool = True, version: str = 'v0', strict: boo
'module_path_rh': module_path_rh,
'cwd_rh': P.cwd().collapseuser().as_posix(),
}
Save.json(obj=specs, path=self.hp.save_dir.joinpath('metadata/code_specs.json').str, indent=4)
print(f'SAVED Model Class @ {self.hp.save_dir.as_uri()}')
return self.hp.save_dir
Save.json(obj=specs, path=get_hp_save_dir(self.hp).joinpath('metadata/code_specs.json').str, indent=4)
print(f'SAVED Model Class @ {get_hp_save_dir(self.hp).as_uri()}')
return get_hp_save_dir(self.hp)

@classmethod
def from_class_weights(cls, path: PLike, hparam_class: Optional[Type[SubclassedHParams]] = None, data_class: Optional[Type[SubclassedDataReader]] = None,
def from_class_weights(cls, path: PLike,
hparam_class: Optional[Type[SubclassedHParams]] = None,
data_class: Optional[Type[SubclassedDataReader]] = None,
device_name: Optional[Device] = None, verbose: bool = True):
path = P(path)
if hparam_class is not None:
Expand Down Expand Up @@ -676,7 +705,7 @@ def from_path(path_model: PLike, **kwargs: Any) -> 'SubclassedBaseModel': # typ

def plot_model(self, dpi: int = 150, strict: bool = False, **kwargs: Any): # alternative viz via tf2onnx then Netron.
import keras
path = self.hp.save_dir.joinpath("metadata/model/model_plot.png")
path = get_hp_save_dir(self.hp).joinpath("metadata/model/model_plot.png")
try:
keras.utils.plot_model(self.model, to_file=str(path), show_shapes=True, show_layer_names=True, show_layer_activations=True, show_dtype=True, expand_nested=True, dpi=dpi, **kwargs)
print(f"Successfully plotted the model @ {path.as_uri()}")
Expand Down
2 changes: 1 addition & 1 deletion myresources/crocodile/deeplearning_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class HParams(dl.HParams):
subpath: str = 'metadata/hyperparameters' # location within model directory where this will be saved.
name: str = field(default_factory=lambda: "model-" + randstr(noun=True))
root: P = P.tmp(folder="tmp_models")
pkg_name = 'tensorflow'
pkg_name: dl.PACKAGE = 'tensorflow'
# device_name: Device=Device.gpu0
# ===================== Data ==============================
seed: int = 234
Expand Down

0 comments on commit ac809f7

Please sign in to comment.