Skip to content

Commit

Permalink
test: add unittest for grouping operator features
Browse files Browse the repository at this point in the history
  • Loading branch information
skim0119 committed May 10, 2024
1 parent f0fc544 commit 5a68fc6
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 2 deletions.
33 changes: 31 additions & 2 deletions elastica/modules/feature_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,50 @@


class FeatureGroupFIFO(Iterable):
"""
A class to store the features and their corresponding operators in a FIFO manner.
Examples
--------
>>> feature_group = FeatureGroupFIFO()
>>> feature_group.append_id(obj_1)
>>> feature_group.append_id(obj_2)
>>> feature_group.add_operators(obj_1, [OperatorType.ADD, OperatorType.SUBTRACT])
>>> feature_group.add_operators(obj_2, [OperatorType.SUBTRACT, OperatorType.MULTIPLY])
>>> list(feature_group)
[OperatorType.ADD, OperatorType.SUBTRACT, OperatorType.SUBTRACT, OperatorType.MULTIPLY]
Attributes
----------
_operator_collection : list[list[OperatorType]]
A list of lists of operators. Each list of operators corresponds to a feature.
_operator_ids : list[int]
A list of ids of the features.
Methods
-------
append_id(feature)
Appends the id of the feature to the list of ids.
add_operators(feature, operators)
Adds the operators to the list of operators corresponding to the feature.
"""

def __init__(self):
self._operator_collection: list[list[OperatorType]] = []
self._operator_ids: list[int] = []

def __iter__(self) -> OperatorType:
if not self._operator_collection:
raise RuntimeError("Feature group is not instantiated.")
"""Returns an operator iterator to satisfy the Iterable protocol."""
operator_chain = itertools.chain.from_iterable(self._operator_collection)
for operator in operator_chain:
yield operator

def append_id(self, feature):
"""Appends the id of the feature to the list of ids."""
self._operator_ids.append(id(feature))
self._operator_collection.append([])

def add_operators(self, feature, operators: list[OperatorType]):
"""Adds the operators to the list of operators corresponding to the feature."""
idx = self._operator_ids.index(id(feature))
self._operator_collection[idx].extend(operators)
56 changes: 56 additions & 0 deletions tests/test_modules/test_feature_grouping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from elastica.modules.feature_group import FeatureGroupFIFO


def test_add_ids():
feature_group = FeatureGroupFIFO()
feature_group.append_id(1)
feature_group.append_id(2)
feature_group.append_id(3)

assert feature_group._operator_ids == [id(1), id(2), id(3)]


def test_add_operators():
feature_group = FeatureGroupFIFO()
feature_group.append_id(1)
feature_group.add_operators(1, [1, 2, 3])
feature_group.append_id(2)
feature_group.add_operators(2, [4, 5, 6])
feature_group.append_id(3)
feature_group.add_operators(3, [7, 8, 9])

assert feature_group._operator_collection == [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
assert feature_group._operator_ids == [id(1), id(2), id(3)]

feature_group.append_id(4)
feature_group.add_operators(4, [10, 11, 12])

assert feature_group._operator_collection == [
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12],
]
assert feature_group._operator_ids == [id(1), id(2), id(3), id(4)]


def test_grouping():
feature_group = FeatureGroupFIFO()
feature_group.append_id(1)
feature_group.add_operators(1, [1, 2, 3])
feature_group.append_id(2)
feature_group.add_operators(2, [4, 5, 6])
feature_group.append_id(3)
feature_group.add_operators(3, [7, 8, 9])

assert list(feature_group) == [1, 2, 3, 4, 5, 6, 7, 8, 9]

feature_group.append_id(4)
feature_group.add_operators(4, [10, 11, 12])

assert list(feature_group) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

feature_group.append_id(1)
feature_group.add_operators(1, [13, 14, 15])

assert list(feature_group) == [1, 2, 3, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12]

0 comments on commit 5a68fc6

Please sign in to comment.