diff --git a/test/classes/test_complex.py b/test/classes/test_complex.py index c521962a..7851a76e 100644 --- a/test/classes/test_complex.py +++ b/test/classes/test_complex.py @@ -2,9 +2,15 @@ import pytest -from toponetx.classes.complex import Complex +from toponetx.classes.cell import Cell +from toponetx.classes.cell_complex import CellComplex +from toponetx.classes.colored_hypergraph import ColoredHyperGraph +from toponetx.classes.complex import Atom, Complex from toponetx.classes.hyperedge import HyperEdge +from toponetx.classes.path import Path +from toponetx.classes.path_complex import PathComplex from toponetx.classes.simplex import Simplex +from toponetx.classes.simplicial_complex import SimplicialComplex class TestAtom: @@ -26,6 +32,68 @@ def test_atoms_equal(self): class TestComplex: """Test the Complex abstract class.""" + complex_classes = (CellComplex, ColoredHyperGraph, PathComplex, SimplicialComplex) + atom_classes = (Cell, HyperEdge, Path, Simplex) + add_atom_method = ("add_cell", "add_cell", "add_path", "add_simplex") + + @pytest.mark.parametrize( + "complex_class,atom_class,add_method", + zip(complex_classes, atom_classes, add_atom_method, strict=True), + ) + def test_add_atom_with_attribute( + self, complex_class: type[Complex], atom_class: type[Atom], add_method: str + ) -> None: + """Test adding an atom with an attribute. + + Parameters + ---------- + complex_class : type[Complex] + The complex class to test. + atom_class : type[Atom] + The atom class to test. + add_method : str + The name of the method to add the atom to the complex. + """ + complex_ = complex_class() + atom1 = atom_class((1, 2, 3), weight=1) + atom2 = atom_class((2, 3, 4)) + add_func = getattr(complex_, add_method) + + add_func(atom1) + assert atom1 in complex_ + assert complex_[atom1]["weight"] == 1 + + add_func(atom2, weight=2) + assert atom2 in complex_ + assert complex_[atom2]["weight"] == 2 + + @pytest.mark.parametrize( + "complex_class,atom_class,add_method", + zip(complex_classes, atom_classes, add_atom_method, strict=True), + ) + def test_add_atom_attribute_precedence( + self, complex_class: type[Complex], atom_class: type[Atom], add_method: str + ) -> None: + """Test that explicitly added attributes take precedence. + + Parameters + ---------- + complex_class : type[Complex] + The complex class to test. + atom_class : type[Atom] + The atom class to test. + add_method : str + The name of the method to add the atom to the complex. + """ + complex_ = complex_class() + atom = atom_class((1, 2, 3), weight=1) + + add_func = getattr(complex_, add_method) + add_func(atom, weight=2) + + assert atom in complex_ + assert complex_[atom]["weight"] == 2 + def test_complex_is_abstract(self): """Test if the Complex abstract class is abstract.""" with pytest.raises(TypeError): diff --git a/toponetx/classes/path_complex.py b/toponetx/classes/path_complex.py index 85cf7e44..00b816c9 100644 --- a/toponetx/classes/path_complex.py +++ b/toponetx/classes/path_complex.py @@ -213,9 +213,7 @@ def add_path(self, path: Hashable | Sequence[Hashable] | Path, **attr) -> None: """ new_paths = set() if isinstance(path, int | str): - path = [ - path, - ] + path = [path] if isinstance(path, list | tuple | Path): if not isinstance(path, Path): # path is a list or tuple path_ = tuple(path) @@ -1095,19 +1093,17 @@ def _update_attributes(self, path, **attr): **attr : keyword arguments The attributes to update. """ - path_ = path.elements if isinstance(path, Path) else tuple(path) - if isinstance(path, Path): # update attributes for PathView() and _G - self._path_set.faces_dict[len(path_) - 1][path_].update(path._attributes) - if len(path_) == 1: - self._G.add_node(path_[0], **path._attributes) - elif len(path_) == 2: - self._G.add_edge(path_[0], path_[1], **path._attributes) + if isinstance(path, Path): + path_ = path.elements + attr = path._attributes | attr else: - self._path_set.faces_dict[len(path_) - 1][path_].update(attr) - if len(path_) == 1: - self._G.add_node(path_[0], **attr) - elif len(path_) == 2: - self._G.add_edge(path_[0], path_[1], **attr) + path_ = tuple(path) + + self._path_set.faces_dict[len(path_) - 1][path_].update(attr) + if len(path_) == 1: + self._G.add_node(path_[0], **attr) + elif len(path_) == 2: + self._G.add_edge(path_[0], path_[1], **attr) def __contains__(self, atom: Any) -> bool: """Check if an atom is in the path complex.