Skip to content

Commit

Permalink
added mypy settings
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Al-Saffar committed Aug 31, 2023
1 parent 6d6b90f commit 62af7be
Showing 1 changed file with 34 additions and 10 deletions.
44 changes: 34 additions & 10 deletions myresources/crocodile/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def get_attributes(self, remove_base_attrs: bool = True, return_objects: bool =
return List([getattr(self, x) for x in attrs]) if return_objects else List(attrs)
def print(self, dtype: bool = False, attrs: bool = False, **kwargs: Any): return Struct(self.__dict__).update(attrs=self.get_attributes() if attrs else None).print(dtype=dtype, **kwargs)
@staticmethod
def get_state(obj: Any, repr_func: Callable[[Any], str] = lambda x: x, exclude: Optional[list[str]] = None) -> dict[str, Any]: return repr_func(obj) if not any([hasattr(obj, "__getstate__"), hasattr(obj, "__dict__")]) else (tmp if type(tmp := obj.__getstate__() if hasattr(obj, "__getstate__") else obj.__dict__) is not dict else Struct(tmp).filter(lambda k, v: k not in (exclude or [])).apply2values(lambda k, v: Base.get_state(v, exclude=exclude, repr_func=repr_func)).__dict__)
def get_state(obj: Any, repr_func: Callable[[Any], dict[str, Any]] = lambda x: x, exclude: Optional[list[str]] = None) -> dict[str, Any]:
if not any([hasattr(obj, "__getstate__"), hasattr(obj, "__dict__")]): return repr_func(obj)
return (tmp if type(tmp := obj.__getstate__() if hasattr(obj, "__getstate__") else obj.__dict__) is not dict else Struct(tmp).filter(lambda k, v: k not in (exclude or [])).apply2values(lambda k, v: Base.get_state(v, exclude=exclude, repr_func=repr_func)).__dict__)
def viz_composition_heirarchy(self, depth: int = 3, obj: Any = None, filt: Optional[Callable[[Any], None]] = None):
install_n_import("objgraph").show_refs([self] if obj is None else [obj], max_depth=depth, filename=str(filename := Path(__import__("tempfile").gettempdir()).joinpath("graph_viz_" + randstr(noun=True) + ".png")), filter=filt)
_ = __import__("os").startfile(str(filename.absolute())) if __import__("sys").platform == "win32" else None; return filename
Expand Down Expand Up @@ -142,11 +144,20 @@ def __call__(self, *args: Any, **kwargs: Any) -> 'List[Any]': return List([i.__c
# ======================== Access Methods ==========================================
def __setitem__(self, key: int, value: T) -> None: self.list[key] = value
def sample(self, size: int = 1, replace: bool = False, p: Optional[list[float]] = None) -> 'List[T]':
return self[list(__import__("numpy").random.choice(len(self), size, replace=replace, p=p))]
import numpy as np
tmp = np.random.choice(len(self), size, replace=replace, p=p)
return List([self.list[item] for item in tmp.tolist()])
def split(self, every: int = 1, to: Optional[int] = None) -> 'List[List[T]]':
every = every if to is None else __import__("math").ceil(len(self) / to)
return List([(self[ix:ix + every] if ix + every < len(self) else self[ix:len(self)]) for ix in range(0, len(self), every)])
def filter(self, func: Callable[[T], bool], which: Optional[Callable[[int, T], T2]] = lambda _idx, x: x) -> 'List[T2]': return List([which(idx, x) for idx, x in enumerate(self.list) if func(x)])
res: list[List[T]] = []
for ix in range(0, len(self), every):
if ix + every < len(self):
tmp = self.list[ix:ix + every]
else:
tmp = self.list[ix:len(self)]
res.append(List(tmp))
return List(res)
def filter(self, func: Callable[[T], bool], which: Callable[[int, T], T2] = lambda _idx, x: x) -> 'List[T2]': return List([which(idx, x) for idx, x in enumerate(self.list) if func(x)])
# ======================= Modify Methods ===============================
def reduce(self, func: Callable[[T, T], T] = lambda x, y: x + y, default: Optional[T] = None) -> list[T]:
args = (func, self.list) + ((default,) if default is not None else ())
Expand All @@ -159,7 +170,10 @@ def sort(self, key=None, reverse: bool = False) -> 'List[T]': self.list.sort(key
def sorted(self, *args: list[Any], **kwargs: Any) -> 'List[T]': return List(sorted(self.list, *args, **kwargs))
def insert(self, __index: int, __object: T): self.list.insert(__index, __object); return self
# def modify(self, expr: str, other: Optional['List[T]'] = None) -> 'List[T]': _ = [exec(expr) for idx, x in enumerate(self.list)] if other is None else [exec(expr) for idx, (x, y) in enumerate(zip(self.list, other))]; return self
def remove(self, value: Optional[T] = None, values: Optional[list[T]] = None, strict: bool = True) -> 'List[T]': _ = [self.list.remove(a_val) for a_val in ((values or []) + ([value] if value else [])) if strict or value in self.list]; return self
def remove(self, value: Optional[T] = None, values: Optional[list[T]] = None, strict: bool = True) -> 'List[T]':
for a_val in ((values or []) + ([value] if value else [])):
if strict or value in self.list: self.list.remove(a_val)
return self
def to_series(self): return __import__("pandas").Series(self.list)
def to_list(self) -> list[T]: return self.list
def to_numpy(self, **kwargs: Any) -> 'Any': import numpy as np; return np.array(self.list, **kwargs)
Expand All @@ -172,10 +186,16 @@ def __getitem__(self, key: Union[int, list[int], 'slice']) -> Union[T, 'List[T]'
elif isinstance(key, int): return self.list[key]
# assert isinstance(key, slice)
return List(self.list[key]) # for slices # type: ignore
def apply(self, func: Callable[[T], T2], *args: Any, other: Optional['List[T]'] = None, filt: Optional[Callable[[T], bool]] = lambda x: True, jobs: Optional[int] = None, prefer: Optional[str] = [None, 'processes', 'threads'][0], depth: int = 1, verbose: bool = False, desc: Optional[str] = None, **kwargs: Any) -> 'List[T2]':
def apply(self, func: Callable[[T], T2], *args: Any, other: Optional['List[T]'] = None, filt: Callable[[T], bool] = lambda x: True, jobs: Optional[int] = None, prefer: Optional[str] = [None, 'processes', 'threads'][0], depth: int = 1, verbose: bool = False, desc: Optional[str] = None, **kwargs: Any) -> 'List[T2]':
if depth > 1: self.apply(lambda x: x.apply(func, *args, other=other, jobs=jobs, depth=depth - 1, **kwargs))
iterator = (self.list if not verbose else install_n_import("tqdm").tqdm(self.list, desc=desc)) if other is None else (zip(self.list, other) if not verbose else install_n_import("tqdm").tqdm(zip(self.list, other), desc=desc))
if jobs: from joblib import Parallel, delayed; return List(Parallel(n_jobs=jobs, prefer=prefer)(delayed(func)(x, *args, **kwargs) for x in iterator)) if other is None else List(Parallel(n_jobs=jobs, prefer=prefer)(delayed(func)(x, y) for x, y in iterator))
from tqdm import tqdm
if other is None:
iterator = (self.list if not verbose else tqdm(self.list, desc=desc))
else: iterator = (zip(self.list, other) if not verbose else tqdm(zip(self.list, other), desc=desc))
if jobs:
from joblib import Parallel, delayed
if other is None: return List(Parallel(n_jobs=jobs, prefer=prefer)(delayed(func)(x, *args, **kwargs) for x in iterator)) # type: ignore
return List(Parallel(n_jobs=jobs, prefer=prefer)(delayed(func)(x, y) for x, y in iterator)) # type: ignore
return List([func(x, *args, **kwargs) for x in iterator if filt(x)]) if other is None else List([func(x, y) for x, y in iterator])
def to_dataframe(self, names: Optional[list[str]] = None, minimal: bool = False, obj_included: bool = True):
df = __import__("pandas").DataFrame(columns=(['object'] if obj_included or names else []) + list(self.list[0].__dict__.keys()))
Expand All @@ -189,7 +209,7 @@ def to_dataframe(self, names: Optional[list[str]] = None, minimal: bool = False,
class Struct(Base): # inheriting from dict gives `get` method, should give `__contains__` but not working. # Inheriting from Base gives `save` method.
"""Use this class to keep bits and sundry items. Combines the power of dot notation in classes with strings in dictionaries to provide Pandas-like experience"""
def __init__(self, dictionary: Union[dict[Any, Any], Type[object], None] = None, **kwargs: Any):
if dictionary is None or isinstance(dictionary, dict): final_dict = dict() if dictionary is None else dictionary
if dictionary is None or isinstance(dictionary, dict): final_dict = {} if dictionary is None else dictionary
else:
final_dict = (dict(dictionary) if dictionary.__class__.__name__ == "mappingproxy" else dictionary.__dict__) # type: ignore
final_dict.update(kwargs) # type ignore
Expand Down Expand Up @@ -254,7 +274,11 @@ def print(self, dtype: bool = True, return_str: bool = False, justify: int = 30,
if not bool(self): res = f"Empty Struct."
else:
if as_yaml or as_config: res = __import__("yaml").dump(self.__dict__) if as_yaml else config(self.__dict__, justify=justify, **kwargs)
else: res = self._pandas_repr(justify=justify, return_str=False, limit=limit).drop(columns=[] if dtype else ["dtype"])
else:
import pandas as pd
tmp = self._pandas_repr(justify=justify, return_str=False, limit=limit)
assert isinstance(tmp, pd.DataFrame)
res = tmp.drop(columns=[] if dtype else ["dtype"])
if not return_str:
if ("DataFrame" in res.__class__.__name__ and install_n_import("tabulate")): install_n_import("rich").print(res.to_markdown())
else: print(res)
Expand Down

0 comments on commit 62af7be

Please sign in to comment.