Skip to content

Commit

Permalink
added the ability to filter with custom functions
Browse files Browse the repository at this point in the history
  • Loading branch information
nwlandry committed Mar 9, 2024
1 parent 4667614 commit 62d3a07
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
18 changes: 14 additions & 4 deletions xgi/core/diviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def filterby(self, stat, val, mode="eq"):
val : Any
Value of the statistic. Usually a single numeric value. When mode is
'between', must be a tuple of exactly two values.
mode : str, optional
mode : str or function, optional
How to compare each value to `val`. Can be one of the following.
* 'eq' (default): Return IDs whose value is exactly equal to `val`.
Expand All @@ -201,6 +201,7 @@ def filterby(self, stat, val, mode="eq"):
* 'geq': Return IDs whose value is greater than or equal to `val`.
* 'between': In this mode, `val` must be a tuple `(val1, val2)`. Return IDs
whose value `v` satisfies `val1 <= v <= val2`.
* function, must be able to call `mode(statistic, val)` and have it map to a bool.
See Also
--------
Expand Down Expand Up @@ -256,6 +257,8 @@ def filterby(self, stat, val, mode="eq"):
bunch = [idx for idx in self if values[idx] >= val]
elif mode == "between":
bunch = [node for node in self if val[0] <= values[node] <= val[1]]
elif callable(mode):
bunch = [idx for idx in self if mode(values[idx], val)]
else:
raise ValueError(
f"Unrecognized mode {mode}. mode must be one of 'eq', 'neq', 'lt', 'gt', 'leq', 'geq', or 'between'."
Expand All @@ -271,9 +274,10 @@ def filterby_attr(self, attr, val, mode="eq", missing=None):
The name of the attribute
val : Any
A single value or, in the case of 'between', a list of length 2
mode : str, optional
mode : str or function, optional
Comparison mode. Valid options are 'eq' (default), 'neq', 'lt', 'gt',
'leq', 'geq', or 'between'.
'leq', 'geq', or 'between'. If a function, must be able to call
`mode(attribute, val)` and have it map to a bool.
missing : Any, optional
The default value if the attribute is missing. If None (default),
ignores those IDs.
Expand Down Expand Up @@ -323,9 +327,15 @@ def filterby_attr(self, attr, val, mode="eq", missing=None):
for idx in self
if values[idx] is not None and val[0] <= values[idx] <= val[1]
]
elif callable(mode):
bunch = [
idx
for idx in self
if values[idx] is not None and mode(values[idx], val)
]
else:
raise ValueError(
f"Unrecognized mode {mode}. mode must be one of 'eq', 'neq', 'lt', 'gt', 'leq', 'geq', or 'between'."
f"Unrecognized mode {mode}. mode must be one of 'eq', 'neq', 'lt', 'gt', 'leq', 'geq', 'between', or a callable function."
)
return type(self).from_view(self, bunch)

Expand Down
18 changes: 14 additions & 4 deletions xgi/core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def filterby(self, stat, val, mode="eq"):
val : Any
Value of the statistic. Usually a single numeric value. When mode is
'between', must be a tuple of exactly two values.
mode : str, optional
mode : str or function, optional
How to compare each value to `val`. Can be one of the following.
* 'eq' (default): Return IDs whose value is exactly equal to `val`.
Expand All @@ -194,6 +194,7 @@ def filterby(self, stat, val, mode="eq"):
* 'geq': Return IDs whose value is greater than or equal to `val`.
* 'between': In this mode, `val` must be a tuple `(val1, val2)`. Return IDs
whose value `v` satisfies `val1 <= v <= val2`.
* function, must be able to call `mode(statistic, val)` and have it map to a bool.
See Also
--------
Expand Down Expand Up @@ -254,7 +255,9 @@ def filterby(self, stat, val, mode="eq"):
elif mode == "geq":
bunch = [idx for idx in self if values[idx] >= val]
elif mode == "between":
bunch = [idx for idx in self if val[0] <= values[idx] <= val[1]]
bunch = [node for node in self if val[0] <= values[node] <= val[1]]
elif callable(mode):
bunch = [idx for idx in self if mode(values[idx], val)]
else:
raise ValueError(
f"Unrecognized mode {mode}. mode must be one of "
Expand All @@ -271,9 +274,10 @@ def filterby_attr(self, attr, val, mode="eq", missing=None):
The name of the attribute
val : Any
A single value or, in the case of 'between', a list of length 2
mode : str, optional
mode : str or function, optional
Comparison mode. Valid options are 'eq' (default), 'neq', 'lt', 'gt',
'leq', 'geq', or 'between'.
'leq', 'geq', or 'between'. If a function, must be able to call
`mode(attribute, val)` and have it map to a bool.
missing : Any, optional
The default value if the attribute is missing. If None (default),
ignores those IDs.
Expand Down Expand Up @@ -323,6 +327,12 @@ def filterby_attr(self, attr, val, mode="eq", missing=None):
for idx in self
if values[idx] is not None and val[0] <= values[idx] <= val[1]
]
elif callable(mode):
bunch = [
idx
for idx in self
if values[idx] is not None and mode(values[idx], val)
]
else:
raise ValueError(
f"Unrecognized mode {mode}. mode must be one of "
Expand Down

0 comments on commit 62d3a07

Please sign in to comment.