Skip to content

Commit

Permalink
Merge pull request #377 from pynapple-org/fix_class_instantiations
Browse files Browse the repository at this point in the history
Fix class instantiations
  • Loading branch information
gviejo authored Dec 14, 2024
2 parents 17faa1e + 96a0014 commit 8d9db8b
Show file tree
Hide file tree
Showing 7 changed files with 536 additions and 408 deletions.
91 changes: 37 additions & 54 deletions pynapple/core/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ def __init__(self, t, time_units="s", time_support=None):
self.rate = np.nan
self.time_support = IntervalSet(start=[], end=[])

@abc.abstractmethod
def _define_instance(self, time_index, time_support, values=None, **kwargs):
"""Return a new class instance.
Pass "columns", "metadata" and other attributes of self
to the new instance unless specified in kwargs.
"""
pass

@property
def t(self):
"""The time index of the time series"""
Expand Down Expand Up @@ -197,8 +206,15 @@ def value_from(self, data, ep=None):
>>> print(len(ts.restrict(ep)), len(newts))
52 52
"""
if not isinstance(data, _Base) and not hasattr(data, "values"):
raise TypeError(
"First argument should be an instance of Tsd, TsdFrame or TsdTensor"
)
if ep is None:
ep = data.time_support
if not isinstance(ep, IntervalSet):
raise TypeError("Argument ep should be of type IntervalSet or None")

time_array = self.index.values
time_target_array = data.index.values
data_target_array = data.values
Expand All @@ -211,15 +227,9 @@ def value_from(self, data, ep=None):

time_support = IntervalSet(start=starts, end=ends)

kwargs = {}
if hasattr(data, "columns"):
kwargs["columns"] = data.columns
if hasattr(data, "_metadata"):
kwargs["metadata"] = data._metadata

return t, d, time_support, kwargs
return data._define_instance(time_index=t, time_support=time_support, values=d)

def count(self, *args, dtype=None, **kwargs):
def count(self, bin_size=None, ep=None, time_units="s", dtype=None):
"""
Count occurences of events within bin_size or within a set of bins defined as an IntervalSet.
You can call this function in multiple ways :
Expand Down Expand Up @@ -276,37 +286,20 @@ def count(self, *args, dtype=None, **kwargs):
start end
0 100.0 800.0
"""
bin_size = None
if "bin_size" in kwargs:
bin_size = kwargs["bin_size"]

if bin_size is not None:
if isinstance(bin_size, int):
bin_size = float(bin_size)
if not isinstance(bin_size, float):
raise ValueError("bin_size argument should be float.")
else:
for a in args:
if isinstance(a, (float, int)):
bin_size = float(a)

time_units = "s"
if "time_units" in kwargs:
time_units = kwargs["time_units"]
if not isinstance(time_units, str):
raise TypeError("bin_size argument should be float or int.")

if not isinstance(time_units, str) or time_units not in ["s", "ms", "us"]:
raise ValueError("time_units argument should be 's', 'ms' or 'us'.")
else:
for a in args:
if isinstance(a, str) and a in ["s", "ms", "us"]:
time_units = a

ep = self.time_support
if "ep" in kwargs:
ep = kwargs["ep"]
if not isinstance(ep, IntervalSet):
raise ValueError("ep argument should be IntervalSet")
else:
for a in args:
if isinstance(a, IntervalSet):
ep = a

if ep is None:
ep = self.time_support
if not isinstance(ep, IntervalSet):
raise TypeError("ep argument should be of type IntervalSet")

if dtype is None:
dtype = np.dtype(np.int64)
Expand All @@ -326,7 +319,7 @@ def count(self, *args, dtype=None, **kwargs):

t, d = _count(time_array, starts, ends, bin_size, dtype=dtype)

return t, d, ep
return self._define_instance(t, ep, values=d)

def restrict(self, iset):
"""
Expand Down Expand Up @@ -360,33 +353,23 @@ def restrict(self, iset):
0 0.0 500.0
"""

assert isinstance(iset, IntervalSet), "Argument should be IntervalSet"
if not isinstance(iset, IntervalSet):
raise TypeError("Argument should be IntervalSet")

time_array = self.index.values
starts = iset.start
ends = iset.end

idx = _restrict(time_array, starts, ends)

kwargs = {}
if hasattr(self, "columns"):
kwargs["columns"] = self.columns

if hasattr(self, "_metadata"):
kwargs["metadata"] = self._metadata

if hasattr(self, "values"):
data_array = self.values
return self.__class__(
t=time_array[idx], d=data_array[idx], time_support=iset, **kwargs
)
else:
return self.__class__(t=time_array[idx], time_support=iset)
data = None if not hasattr(self, "values") else self.values[idx]
return self._define_instance(time_array[idx], iset, values=data)

def copy(self):
"""Copy the data, index and time support"""
return self.__class__(t=self.index.copy(), time_support=self.time_support)
data = getattr(self, "values", None)
if data is not None:
data = data.copy() if hasattr(data, "copy") else data[:].copy()
return self._define_instance(self.index.copy(), self.time_support, values=data)

def find_support(self, min_gap, time_units="s"):
"""
Expand Down
Loading

0 comments on commit 8d9db8b

Please sign in to comment.