Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow any string name for metadata #363

Merged
merged 19 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 59 additions & 11 deletions pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .metadata_class import _MetadataMixin
from .time_index import TsIndex
from .utils import (
_convert_iter_to_str,
_get_terminal_size,
_IntervalSetSliceHelper,
check_filename,
Expand Down Expand Up @@ -214,10 +215,12 @@
self.index = np.arange(data.shape[0], dtype="int")
self.columns = np.array(["start", "end"])
self.nap_class = self.__class__.__name__
if drop_meta:
_MetadataMixin.__init__(self)
else:
_MetadataMixin.__init__(self, metadata)
# initialize metadata to get all attributes before setting metadata
_MetadataMixin.__init__(self)
self._class_attributes = self.__dir__() # get list of all attributes
self._class_attributes.append("_class_attributes") # add this property
if drop_meta is False:
self.set_info(metadata)
self._initialized = True

def __repr__(self):
Expand All @@ -229,7 +232,14 @@
# By default, the first three columns should always show.

# Adding an extra column between actual values and metadata
col_names = self._metadata.columns
try:
metadata = self._metadata
col_names = metadata.columns
except Exception:

Check warning on line 238 in pynapple/core/interval_set.py

View check run for this annotation

Codecov / codecov/patch

pynapple/core/interval_set.py#L238

Added line #L238 was not covered by tests
# Necessary for backward compatibility when saving IntervalSet as pickle
metadata = pd.DataFrame(index=self.index)
col_names = []

Check warning on line 241 in pynapple/core/interval_set.py

View check run for this annotation

Codecov / codecov/patch

pynapple/core/interval_set.py#L240-L241

Added lines #L240 - L241 were not covered by tests

headers = ["index", "start", "end"]
if len(col_names):
headers += [""] + [c for c in col_names]
Expand All @@ -249,7 +259,7 @@
self.index[0:n_rows, None],
self.values[0:n_rows],
separator,
self._metadata.values[0:n_rows],
_convert_iter_to_str(metadata.values[0:n_rows]),
),
dtype=object,
),
Expand All @@ -259,7 +269,7 @@
self.index[-n_rows:, None],
self.values[0:n_rows],
separator,
self._metadata.values[-n_rows:],
_convert_iter_to_str(metadata.values[-n_rows:]),
),
dtype=object,
),
Expand All @@ -271,7 +281,12 @@
else:
separator = np.empty((len(self), 0))
data = np.hstack(
(self.index[:, None], self.values, separator, self._metadata.values),
(
self.index[:, None],
self.values,
separator,
_convert_iter_to_str(metadata.values),
),
dtype=object,
)

Expand All @@ -286,19 +301,53 @@
def __setattr__(self, name, value):
# necessary setter to allow metadata to be set as an attribute
if self._initialized:
_MetadataMixin.__setattr__(self, name, value)
if name in self._class_attributes:
raise AttributeError(
f"Cannot set attribute '{name}'; IntervalSet is immutable. Use 'set_info()' to set '{name}' as metadata."
)
else:
_MetadataMixin.__setattr__(self, name, value)
else:
object.__setattr__(self, name, value)

def __getattr__(self, name):
# Necessary for backward compatibility with pickle

# avoid infinite recursion when pickling due to
# self._metadata.column having attributes '__reduce__', '__reduce_ex__'
if name in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"):
raise AttributeError(name)

try:
metadata = self._metadata
except Exception:
metadata = pd.DataFrame(index=self.index)

Check warning on line 324 in pynapple/core/interval_set.py

View check run for this annotation

Codecov / codecov/patch

pynapple/core/interval_set.py#L323-L324

Added lines #L323 - L324 were not covered by tests

if name == "_metadata":
return metadata

Check warning on line 327 in pynapple/core/interval_set.py

View check run for this annotation

Codecov / codecov/patch

pynapple/core/interval_set.py#L327

Added line #L327 was not covered by tests
elif name in metadata.columns:
return _MetadataMixin.__getattr__(self, name)
else:
return super().__getattr__(name)

def __setitem__(self, key, value):
if (isinstance(key, str)) and (key not in self.columns):
if key in self.columns:
raise RuntimeError(
"IntervalSet is immutable. Starts and ends have been already sorted."
)
elif isinstance(key, str):
_MetadataMixin.__setitem__(self, key, value)
else:
raise RuntimeError(
"IntervalSet is immutable. Starts and ends have been already sorted."
)

def __getitem__(self, key):
try:
metadata = _MetadataMixin.__getitem__(self, key)
except Exception:
metadata = pd.DataFrame(index=self.index)

if isinstance(key, str):
# self[str]
if key == "start":
Expand All @@ -323,7 +372,6 @@
elif isinstance(key, Number):
# self[Number]
output = self.values.__getitem__(key)
metadata = _MetadataMixin.__getitem__(self, key)
return IntervalSet(start=output[0], end=output[1], metadata=metadata)
elif isinstance(key, (slice, list, np.ndarray, pd.Series)):
# self[array_like]
Expand Down
67 changes: 39 additions & 28 deletions pynapple/core/metadata_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

"""

def __init__(self, metadata=None, **kwargs):
def __init__(self, metadata=None):
"""
Metadata initializer

Expand All @@ -21,7 +21,6 @@
List of pandas.DataFrame
**kwargs : dict
Dictionary containing metadata information

"""
if self.__class__.__name__ == "TsdFrame":
# metadata index is the same as the columns for TsdFrame
Expand All @@ -31,12 +30,8 @@
self.metadata_index = self.index

self._metadata = pd.DataFrame(index=self.metadata_index)
if len(kwargs):
warnings.warn(
"initializing metadata with variable keyword arguments may be unsupported in a future version of Pynapple. Instead, initialize using the metadata argument.",
FutureWarning,
)
self.set_info(metadata, **kwargs)

self.set_info(metadata)

def __dir__(self):
"""
Expand Down Expand Up @@ -115,32 +110,48 @@
raise TypeError(
f"Invalid metadata type {type(name)}. Metadata column names must be strings!"
)
if hasattr(self, name) and (name not in self.metadata_columns):
# existing non-metadata attribute
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name must differ from "
f"{type(self).__dict__.keys()} attribute names!"
)
if hasattr(self, "columns") and name in self.columns:
# existing column (since TsdFrame columns are not attributes)
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name must differ from "
f"{self.columns} column names!"
)
if name[0].isalpha() is False:
# starts with a number
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name cannot start with a number"
)
# warnings for metadata names that cannot be accessed as attributes or keys
if name in self._class_attributes:
if (self.nap_class == "TsGroup") and (name == "rate"):
# special exception for TsGroup rate attribute
raise ValueError(
f"Invalid metadata name '{name}'. Cannot overwrite TsGroup 'rate'!"
)
else:
# existing non-metadata attribute
warnings.warn(
f"Metadata name '{name}' overlaps with an existing attribute, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata."
)
elif hasattr(self, "columns") and name in self.columns:
if self.nap_class == "TsdFrame":
# special exception for TsdFrame columns
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name must differ from {list(self.columns)} column names!"
)
else:
# existing non-metadata column
warnings.warn(

Check warning on line 133 in pynapple/core/metadata_class.py

View check run for this annotation

Codecov / codecov/patch

pynapple/core/metadata_class.py#L133

Added line #L133 was not covered by tests
f"Metadata name '{name}' overlaps with an existing property, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata."
)
# elif name in self.metadata_columns:
# # warnings for metadata that already exists
# warnings.warn(f"Overwriting existing metadata column '{name}'.")

# warnings for metadata that cannot be accessed as attributes
if name.replace("_", "").isalnum() is False:
# contains invalid characters
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name cannot contain special characters except for underscores"
warnings.warn(
f"Metadata name '{name}' contains a special character, and cannot be accessed as an attribute. Use 'get_info()' or key indexing to access metadata."
)
elif (name[0].isalpha() is False) and (name[0] != "_"):
# starts with a number
warnings.warn(
f"Metadata name '{name}' starts with a number, and cannot be accessed as an attribute. Use 'get_info()' or key indexing to access metadata."
)

def _check_metadata_column_names(self, metadata=None, **kwargs):
"""
Check that metadata column names don't conflict with existing attributes, don't start with a number, and don't contain invalid characters.
Throw warnings when metadata names cannot be accessed as attributes or keys.
"""

if metadata is not None:
Expand Down
47 changes: 37 additions & 10 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .time_index import TsIndex
from .utils import (
_concatenate_tsd,
_convert_iter_to_str,
_get_terminal_size,
_split_tsd,
_TsdFrameSliceHelper,
Expand Down Expand Up @@ -949,11 +950,22 @@

self.columns = pd.Index(c)
self.nap_class = self.__class__.__name__
_MetadataMixin.__init__(self, metadata)
# initialize metadata for class attributes
_MetadataMixin.__init__(self)
# get current list of attributes
self._class_attributes = self.__dir__()
self._class_attributes.append("_class_attributes")
# set metadata
self.set_info(metadata)
self._initialized = True

@property
def loc(self):
    """
    Return a slice helper wrapping this object (slated for deprecation).

    Accessing this property emits a ``DeprecationWarning``; bracket
    indexing on the object itself is the recommended replacement.

    Returns
    -------
    _TsdFrameSliceHelper
        Helper constructed from ``self``.
    """
    # stacklevel=2 attributes the warning to the code that accessed ``.loc``
    # rather than to this line inside the library. With Python's default
    # filters a DeprecationWarning is only displayed when it is attributed
    # to ``__main__``, so without it end users would never see the notice.
    warnings.warn(
        "'loc' will be deprecated in a future version. Use bracket indexing instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return _TsdFrameSliceHelper(self)

def __repr__(self):
Expand Down Expand Up @@ -1030,7 +1042,9 @@
np.hstack(
(
col_names[:, None],
self._metadata.values[0:max_cols].T,
_convert_iter_to_str(
self._metadata.values[0:max_cols].T
),
ends,
),
dtype=object,
Expand All @@ -1045,7 +1059,12 @@
def __setattr__(self, name, value):
# necessary setter to allow metadata to be set as an attribute
if self._initialized:
_MetadataMixin.__setattr__(self, name, value)
if name in self._class_attributes:
raise AttributeError(
f"Cannot set attribute: '{name}' is a reserved attribute. Use 'set_info()' to set '{name}' as metadata."
)
else:
_MetadataMixin.__setattr__(self, name, value)
else:
super().__setattr__(name, value)

Expand All @@ -1056,7 +1075,15 @@
# self._metadata.column having attributes '__reduce__', '__reduce_ex__'
if name in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"):
raise AttributeError(name)
if name in self._metadata.columns:

try:
metadata = self._metadata
except (AttributeError, RecursionError):
metadata = pd.DataFrame(index=self.columns)

Check warning on line 1082 in pynapple/core/time_series.py

View check run for this annotation

Codecov / codecov/patch

pynapple/core/time_series.py#L1081-L1082

Added lines #L1081 - L1082 were not covered by tests

if name == "_metadata":
return metadata

Check warning on line 1085 in pynapple/core/time_series.py

View check run for this annotation

Codecov / codecov/patch

pynapple/core/time_series.py#L1085

Added line #L1085 was not covered by tests
elif name in metadata.columns:
return _MetadataMixin.__getattr__(self, name)
else:
return super().__getattr__(name)
Expand Down Expand Up @@ -1096,18 +1123,18 @@
"When indexing with a Tsd, it must contain boolean values"
)
key = key.d
elif isinstance(key, str) and (key in self.metadata_columns):
return _MetadataMixin.__getitem__(self, key)
elif (
isinstance(key, str)
or hasattr(key, "__iter__")
and all([isinstance(k, str) for k in key])
):
if all(k in self.metadata_columns for k in key):
return _MetadataMixin.__getitem__(self, key)
if all(k in self.columns for k in key):
with warnings.catch_warnings():
# ignore deprecated warning for loc
warnings.simplefilter("ignore")
return self.loc[key]
else:
return self.loc[key]

return _MetadataMixin.__getitem__(self, key)
else:
if isinstance(key, pd.Series) and key.index.equals(self.columns):
# if indexing with a pd.Series from metadata, transform it to tuple with slice(None) in first position
Expand Down
Loading
Loading