Skip to content

Commit

Permalink
Merge pull request #413 from pyt-team/frantzen/path-attribute
Browse files Browse the repository at this point in the history
Fix attribute precedence for path complexes
  • Loading branch information
ffl096 authored Dec 12, 2024
2 parents 60ddcc7 + 494249e commit 0edefd1
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 16 deletions.
70 changes: 69 additions & 1 deletion test/classes/test_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@

import pytest

from toponetx.classes.complex import Complex
from toponetx.classes.cell import Cell
from toponetx.classes.cell_complex import CellComplex
from toponetx.classes.colored_hypergraph import ColoredHyperGraph
from toponetx.classes.complex import Atom, Complex
from toponetx.classes.hyperedge import HyperEdge
from toponetx.classes.path import Path
from toponetx.classes.path_complex import PathComplex
from toponetx.classes.simplex import Simplex
from toponetx.classes.simplicial_complex import SimplicialComplex


class TestAtom:
Expand All @@ -26,6 +32,68 @@ def test_atoms_equal(self):
class TestComplex:
"""Test the Complex abstract class."""

complex_classes = (CellComplex, ColoredHyperGraph, PathComplex, SimplicialComplex)
atom_classes = (Cell, HyperEdge, Path, Simplex)
add_atom_method = ("add_cell", "add_cell", "add_path", "add_simplex")

@pytest.mark.parametrize(
"complex_class,atom_class,add_method",
zip(complex_classes, atom_classes, add_atom_method, strict=True),
)
def test_add_atom_with_attribute(
self, complex_class: type[Complex], atom_class: type[Atom], add_method: str
) -> None:
"""Test adding an atom with an attribute.
Parameters
----------
complex_class : type[Complex]
The complex class to test.
atom_class : type[Atom]
The atom class to test.
add_method : str
The name of the method to add the atom to the complex.
"""
complex_ = complex_class()
atom1 = atom_class((1, 2, 3), weight=1)
atom2 = atom_class((2, 3, 4))
add_func = getattr(complex_, add_method)

add_func(atom1)
assert atom1 in complex_
assert complex_[atom1]["weight"] == 1

add_func(atom2, weight=2)
assert atom2 in complex_
assert complex_[atom2]["weight"] == 2

@pytest.mark.parametrize(
"complex_class,atom_class,add_method",
zip(complex_classes, atom_classes, add_atom_method, strict=True),
)
def test_add_atom_attribute_precedence(
self, complex_class: type[Complex], atom_class: type[Atom], add_method: str
) -> None:
"""Test that explicitly added attributes take precedence.
Parameters
----------
complex_class : type[Complex]
The complex class to test.
atom_class : type[Atom]
The atom class to test.
add_method : str
The name of the method to add the atom to the complex.
"""
complex_ = complex_class()
atom = atom_class((1, 2, 3), weight=1)

add_func = getattr(complex_, add_method)
add_func(atom, weight=2)

assert atom in complex_
assert complex_[atom]["weight"] == 2

def test_complex_is_abstract(self):
"""Test if the Complex abstract class is abstract."""
with pytest.raises(TypeError):
Expand Down
26 changes: 11 additions & 15 deletions toponetx/classes/path_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,7 @@ def add_path(self, path: Hashable | Sequence[Hashable] | Path, **attr) -> None:
"""
new_paths = set()
if isinstance(path, int | str):
path = [
path,
]
path = [path]
if isinstance(path, list | tuple | Path):
if not isinstance(path, Path): # path is a list or tuple
path_ = tuple(path)
Expand Down Expand Up @@ -1095,19 +1093,17 @@ def _update_attributes(self, path, **attr):
**attr : keyword arguments
The attributes to update.
"""
path_ = path.elements if isinstance(path, Path) else tuple(path)
if isinstance(path, Path): # update attributes for PathView() and _G
self._path_set.faces_dict[len(path_) - 1][path_].update(path._attributes)
if len(path_) == 1:
self._G.add_node(path_[0], **path._attributes)
elif len(path_) == 2:
self._G.add_edge(path_[0], path_[1], **path._attributes)
if isinstance(path, Path):
path_ = path.elements
attr = path._attributes | attr
else:
self._path_set.faces_dict[len(path_) - 1][path_].update(attr)
if len(path_) == 1:
self._G.add_node(path_[0], **attr)
elif len(path_) == 2:
self._G.add_edge(path_[0], path_[1], **attr)
path_ = tuple(path)

self._path_set.faces_dict[len(path_) - 1][path_].update(attr)
if len(path_) == 1:
self._G.add_node(path_[0], **attr)
elif len(path_) == 2:
self._G.add_edge(path_[0], path_[1], **attr)

def __contains__(self, atom: Any) -> bool:
"""Check if an atom is in the path complex.
Expand Down

0 comments on commit 0edefd1

Please sign in to comment.