Skip to content

Commit

Permalink
add hypergraph vertex weight (v_weight and W_v)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanfeng97 committed Dec 27, 2022
1 parent 7786e2d commit e17fc95
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 7 deletions.
45 changes: 38 additions & 7 deletions dhg/structure/hypergraphs/hypergraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Hypergraph(BaseHypergraph):
``num_v`` (``int``): The number of vertices in the hypergraph.
``e_list`` (``Union[List[int], List[List[int]]]``, optional): A list of hyperedges describes how the vertices point to the hyperedges. Defaults to ``None``.
``e_weight`` (``Union[float, List[float]]``, optional): A list of weights for hyperedges. If set to ``None``, the value ``1`` is used for all hyperedges. Defaults to ``None``.
``v_weight`` (``Union[List[float]]``, optional): A list of weights for vertices. If set to ``None``, the value ``1`` is used for all vertices. Defaults to ``None``.
``merge_op`` (``str``): The operation to merge those conflicting hyperedges in the same hyperedge group, which can be ``'mean'``, ``'sum'`` or ``'max'``. Defaults to ``'mean'``.
``device`` (``torch.device``, optional): The deivce to store the hypergraph. Defaults to ``torch.device('cpu')``.
"""
Expand All @@ -31,10 +32,18 @@ def __init__(
num_v: int,
e_list: Optional[Union[List[int], List[List[int]]]] = None,
e_weight: Optional[Union[float, List[float]]] = None,
v_weight: Optional[List[float]] = None,
merge_op: str = "mean",
device: torch.device = torch.device("cpu"),
):
super().__init__(num_v, device=device)
# init vertex weight
if v_weight is None:
self._v_weight = [1.0] * self.num_v
else:
assert len(v_weight) == self.num_v, "The length of vertex weight is not equal to the number of vertices."
self._v_weight = v_weight
# init hyperedges
if e_list is not None:
self.add_hyperedges(e_list, e_weight, merge_op=merge_op)

Expand Down Expand Up @@ -495,6 +504,12 @@ def v(self) -> List[int]:
r"""Return the list of vertices.
"""
return super().v

@property
def v_weight(self) -> List[float]:
r"""Return the list of vertex weights.
"""
return self._v_weight

@property
def e(self) -> Tuple[List[List[int]], List[float]]:
Expand Down Expand Up @@ -634,7 +649,7 @@ def vars_for_DL(self) -> List[str]:
Sparse Diagnal Matrices:
.. math::
\mathbf{W}_e, \mathbf{D}_v, \mathbf{D}_v^{-1}, \mathbf{D}_v^{-\frac{1}{2}}, \mathbf{D}_e, \mathbf{D}_e^{-1},
\mathbf{W}_v, \mathbf{W}_e, \mathbf{D}_v, \mathbf{D}_v^{-1}, \mathbf{D}_v^{-\frac{1}{2}}, \mathbf{D}_e, \mathbf{D}_e^{-1},
Vectors:
Expand All @@ -649,6 +664,7 @@ def vars_for_DL(self) -> List[str]:
"L_sym",
"L_rw",
"L_HGNN",
"W_v",
"W_e",
"D_v",
"D_v_neg_1",
Expand Down Expand Up @@ -754,14 +770,14 @@ def e2v_weight_of_group(self, group_name: str) -> torch.Tensor:

@property
def H(self) -> torch.Tensor:
r"""Return the hypergraph incidence matrix :math:`\mathbf{H}` with ``torch.Tensor`` format.
r"""Return the hypergraph incidence matrix :math:`\mathbf{H}` with ``torch.sparse_coo_tensor`` format.
"""
if self.cache.get("H") is None:
self.cache["H"] = self.H_v2e
return self.cache["H"]

def H_of_group(self, group_name: str) -> torch.Tensor:
r"""Return the hypergraph incidence matrix :math:`\mathbf{H}` of the specified hyperedge group with ``torch.Tensor`` format.
r"""Return the hypergraph incidence matrix :math:`\mathbf{H}` of the specified hyperedge group with ``torch.sparse_coo_tensor`` format.
Args:
``group_name`` (``str``): The name of the specified hyperedge group.
Expand All @@ -773,14 +789,14 @@ def H_of_group(self, group_name: str) -> torch.Tensor:

@property
def H_T(self) -> torch.Tensor:
r"""Return the transpose of the hypergraph incidence matrix :math:`\mathbf{H}^\top` with ``torch.Tensor`` format.
r"""Return the transpose of the hypergraph incidence matrix :math:`\mathbf{H}^\top` with ``torch.sparse_coo_tensor`` format.
"""
if self.cache.get("H_T") is None:
self.cache["H_T"] = self.H.t()
return self.cache["H_T"]

def H_T_of_group(self, group_name: str) -> torch.Tensor:
r"""Return the transpose of the hypergraph incidence matrix :math:`\mathbf{H}^\top` of the specified hyperedge group with ``torch.Tensor`` format.
r"""Return the transpose of the hypergraph incidence matrix :math:`\mathbf{H}^\top` of the specified hyperedge group with ``torch.sparse_coo_tensor`` format.
Args:
``group_name`` (``str``): The name of the specified hyperedge group.
Expand All @@ -789,10 +805,25 @@ def H_T_of_group(self, group_name: str) -> torch.Tensor:
if self.group_cache[group_name].get("H_T") is None:
self.group_cache[group_name]["H_T"] = self.H_of_group(group_name).t()
return self.group_cache[group_name]["H_T"]

@property
def W_v(self) -> torch.Tensor:
r"""Return the weight matrix :math:`\mathbf{W}_v` of vertices with ``torch.sparse_coo_tensor`` format.
"""
if self.cache.get("W_v") is None:
_tmp = torch.Tensor(self.v_weight)
_num_v = _tmp.size(0)
self.cache["W_v"] = torch.sparse_coo_tensor(
torch.arange(0, _num_v).view(1, -1).repeat(2, 1),
_tmp,
torch.Size([_num_v, _num_v]),
device=self.device,
).coalesce()
return self.cache["W_v"]

@property
def W_e(self) -> torch.Tensor:
r"""Return the weight matrix :math:`\mathbf{W}_e` of hyperedges with ``torch.Tensor`` format.
r"""Return the weight matrix :math:`\mathbf{W}_e` of hyperedges with ``torch.sparse_coo_tensor`` format.
"""
if self.cache.get("W_e") is None:
_tmp = [self.W_e_of_group(name)._values().clone() for name in self.group_names]
Expand All @@ -807,7 +838,7 @@ def W_e(self) -> torch.Tensor:
return self.cache["W_e"]

def W_e_of_group(self, group_name: str) -> torch.Tensor:
r"""Return the weight matrix :math:`\mathbf{W}_e` of hyperedges of the specified hyperedge group with ``torch.Tensor`` format.
r"""Return the weight matrix :math:`\mathbf{W}_e` of hyperedges of the specified hyperedge group with ``torch.sparse_coo_tensor`` format.
Args:
``group_name`` (``str``): The name of the specified hyperedge group.
Expand Down
6 changes: 6 additions & 0 deletions tests/structure/test_hypergraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,12 @@ def test_H_T_group(g1):
assert (g1.H_T_of_group("knn").to_dense().cpu() == torch.tensor([[1, 0, 0, 0, 1, 1]])).all()


def test_W_v(g2):
assert (g2.W_v.cpu()._values() == torch.tensor([1, 1, 1, 1, 1])).all()
hg = Hypergraph(5, [[1, 2], [0, 2, 3, 4]], v_weight=[0.1, 1, 2, 1, 1])
assert (hg.W_v.cpu()._values() == torch.tensor([0.1, 1, 2, 1, 1])).all()


def test_W_e(g2):
assert (g2.W_e.cpu()._values() == torch.tensor([0.5, 1, 0.5, 1, 0.5])).all()

Expand Down

0 comments on commit e17fc95

Please sign in to comment.