-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add assortativity * test * doc-string * doc-string * Update torch_geometric/utils/assortativity.py Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de> * Update torch_geometric/utils/assortativity.py Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de> * Update torch_geometric/utils/assortativity.py Co-authored-by: Padarn Wilson <padarn.wilson@grabtaxi.com> * update test * doc-string * fix test * changelog Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de> Co-authored-by: Padarn Wilson <padarn.wilson@grabtaxi.com>
- Loading branch information
1 parent
6ca2332
commit 7029bad
Showing
4 changed files
with
88 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import pytest | ||
import torch | ||
|
||
from torch_geometric.utils import assortativity | ||
|
||
|
||
def test_assortativity(): | ||
# completely assortative graph | ||
edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], | ||
[1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) | ||
out = assortativity(edge_index) | ||
assert pytest.approx(out, abs=1e-5) == 1.0 | ||
|
||
# completely disassortative graph | ||
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 5, 5, 5, 5], | ||
[5, 5, 5, 5, 5, 0, 1, 2, 3, 4]]) | ||
out = assortativity(edge_index) | ||
assert pytest.approx(out, abs=1e-5) == -1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import torch | ||
from torch_sparse import SparseTensor | ||
|
||
from torch_geometric.typing import Adj | ||
from torch_geometric.utils import coalesce, degree | ||
|
||
from .to_dense_adj import to_dense_adj | ||
|
||
|
||
def assortativity(edge_index: Adj) -> float: | ||
r"""The degree assortativity coefficient from the | ||
`"Mixing patterns in networks" | ||
<https://arxiv.org/abs/cond-mat/0209450>`_ paper. | ||
Assortativity in a network refers to the tendency of nodes to | ||
connect with other similar nodes over dissimilar nodes. | ||
It is computed from Pearson correlation coefficient of the node degrees. | ||
Args: | ||
edge_index (Tensor or SparseTensor): The graph connectivity. | ||
Returns: | ||
The value of the degree assortativity coefficient for the input | ||
graph :math:`\in [-1, 1]` | ||
Example: | ||
>>> edge_index = torch.tensor([[0, 1, 2, 3, 2], | ||
... [1, 2, 0, 1, 3]]) | ||
>>> assortativity(edge_index) | ||
-0.666667640209198 | ||
""" | ||
if isinstance(edge_index, SparseTensor): | ||
row, col, _ = edge_index.coo() | ||
else: | ||
row, col = edge_index | ||
|
||
device = row.device | ||
out_deg = degree(row, dtype=torch.long) | ||
in_deg = degree(col, dtype=torch.long) | ||
degrees = torch.unique(torch.cat([out_deg, in_deg])) | ||
mapping = row.new_zeros(degrees.max().item() + 1) | ||
mapping[degrees] = torch.arange(degrees.size(0), device=device) | ||
|
||
# Compute degree mixing matrix (joint probability distribution) `M` | ||
num_degrees = degrees.size(0) | ||
src_deg = mapping[out_deg[row]] | ||
dst_deg = mapping[in_deg[col]] | ||
|
||
pairs = torch.stack([src_deg, dst_deg], dim=0) | ||
occurrence = torch.ones(pairs.size(1), device=device) | ||
pairs, occurrence = coalesce(pairs, occurrence) | ||
M = to_dense_adj(pairs, edge_attr=occurrence, max_num_nodes=num_degrees)[0] | ||
# normalization | ||
M /= M.sum() | ||
|
||
# numeric assortativity coefficient, computed by | ||
# Pearson correlation coefficient of the node degrees | ||
x = y = degrees.float() | ||
a, b = M.sum(0), M.sum(1) | ||
|
||
vara = (a * x**2).sum() - ((a * x).sum())**2 | ||
varb = (b * x**2).sum() - ((b * x).sum())**2 | ||
xy = torch.outer(x, y) | ||
ab = torch.outer(a, b) | ||
out = (xy * (M - ab)).sum() / (vara * varb).sqrt() | ||
return out.item() |