-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support PyG #93
Merged
Merged
Support PyG #93
Changes from 58 commits
Commits
Show all changes
59 commits
Select commit
Hold shift + click to select a range
5d82d97
change python tests/test_a_star/prepare_for_test.py to python pygmtoo…
heatingma 4c05275
add bdist_wheel
heatingma 4155230
delete test/test_a_star
heatingma b083a7b
add astar ori_files
heatingma 678a49e
Create publish_req.txt
heatingma 6cf8381
Delete a_star.tar.gz
heatingma 3678388
version is a23
heatingma ef3f41c
Update setup.py
heatingma 439c757
delete some notes
heatingma 6e0ea0e
add alternate url to fix download problem
heatingma 18d3873
add return
heatingma be38bf0
fix the md5 problem with genn_astar pretrained models
heatingma d07498e
add the missed ','
heatingma b5f3133
fix the "diff=200.0" problem
heatingma 57230e8
swap the url and the url_alter
heatingma 55379dc
change the astar_pretrain_path
heatingma bde8106
Revert "swap the url and the url_alter"
heatingma 220f782
Revert "change the astar_pretrain_path"
heatingma 3143592
change the pretrained path
heatingma 61181f8
only small files has url_alter
heatingma 0642a14
add new url for pretrained modles
heatingma 0eec2ee
add new download pretrained models' paths for jittor backend
heatingma 52143ae
add new download pretrained models' paths for jittor backend
heatingma 329a115
add new download pretrained models' paths for jittor backend
heatingma f3a78c0
add new download path for pytorch backend pretrained models
heatingma d8f821b
add new download path for paddle backend
heatingma c368697
delete some unused url
heatingma 0ae85c7
add new alternate download path for cie and pca
heatingma 35a4e02
don't test neural_solvers now
heatingma 6f42e7c
only test neural_solvers
heatingma 62f2bef
only test neural_solvers
heatingma 9c76481
only test neural_solvers
heatingma a591e7e
only test neural
heatingma ceeaf60
add the forget ","
heatingma 18c98a9
add all tests
heatingma f408a51
add new url_path and change the download func
heatingma 7e4a9de
delete dropout, trust_fact and no_pred_size for astar
heatingma 6e6f773
delete dropout for genn_astar
heatingma a3ef42e
delete some parameters for astar and genn_astar
heatingma 86ab428
Merge branch 'Thinklab-SJTU:main' into main
heatingma ba4d408
python
heatingma 6df2f3f
Merge branch 'Thinklab-SJTU:main' into main
heatingma 66a4b05
Merge branch 'Thinklab-SJTU:main' into main
heatingma 37f88af
Merge branch 'Thinklab-SJTU:main' into main
heatingma 8616fb4
change a_star to astar
heatingma de06854
change the function name "astar" from cython to "c_astar"
heatingma b0c8c0d
fix: 'module' object is not callable
heatingma 57b0c26
change astar to c_astar
heatingma 3a1373a
Merge branch 'Thinklab-SJTU:main' into main
heatingma 086e182
Merge branch 'Thinklab-SJTU:main' into main
heatingma ed731b3
Merge branch 'Thinklab-SJTU:main' into main
heatingma e6237d9
Fix: for the same k, node_cost is double counted
heatingma cfe5b5d
Merge branch 'main' of https://github.com/heatingma/pygmtools
heatingma 8bb5841
Add alter urls for datasets
heatingma 1b44be0
fix: No such file or directory: 'data/SPair-71k/Layout/small/trn.txt'
heatingma b69de38
Merge branch 'Thinklab-SJTU:main' into main
heatingma a07035f
Merge branch 'Thinklab-SJTU:main' into main
heatingma fd122a0
support pyg
heatingma e743cbb
fix "Support for PyG"
heatingma File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1347,9 +1347,9 @@ | |
################################################### | ||
|
||
|
||
def build_aff_mat_from_networkx(G1:nx.Graph, G2:nx.Graph, node_aff_fn=None, edge_aff_fn=None, backend=None): | ||
def build_aff_mat_from_networkx(G1: nx.Graph, G2: nx.Graph, node_aff_fn=None, edge_aff_fn=None, backend=None): | ||
r""" | ||
Convert networkx object to Adjacency matrix | ||
Convert networkx object to affinity matrix | ||
|
||
:param G1: networkx object, whose type must be networkx.Graph | ||
:param G2: networkx object, whose type must be networkx.Graph | ||
|
@@ -1381,7 +1381,7 @@ | |
# Obtain Affinity Matrix | ||
>>> K = pygm.utils.build_aff_mat_from_networkx(G1, G2) | ||
>>> K.shape | ||
(20,20) | ||
(20, 20) | ||
|
||
# The affinity matrices K can be further processed by GM solvers | ||
""" | ||
|
@@ -1397,7 +1397,7 @@ | |
|
||
def build_aff_mat_from_graphml(G1_path, G2_path, node_aff_fn=None, edge_aff_fn=None, backend=None): | ||
r""" | ||
Convert networkx object to Adjacency matrix | ||
Convert networkx object to affinity matrix | ||
|
||
:param G1_path: The file path of the graphml object | ||
:param G2_path: The file path of the graphml object | ||
|
@@ -1427,7 +1427,7 @@ | |
# Obtain Affinity Matrix | ||
>>> K = pygm.utils.build_aff_mat_from_graphml(G1_path, G2_path) | ||
>>> K.shape | ||
(121,121) | ||
(121, 121) | ||
|
||
# The affinity matrices K can be further processed by GM solvers | ||
""" | ||
|
@@ -1441,9 +1441,9 @@ | |
return K | ||
|
||
|
||
def from_networkx(G:nx.Graph): | ||
def from_networkx(G: nx.Graph): | ||
r""" | ||
Convert networkx object to Adjacency matrix | ||
Convert networkx object to adjacency matrix | ||
|
||
:param G: networkx object, whose type must be networkx.Graph | ||
:return: the adjacency matrix corresponding to the networkx object | ||
|
@@ -1522,7 +1522,7 @@ | |
|
||
def from_graphml(filename): | ||
r""" | ||
Convert graphml object to Adjacency matrix | ||
Convert graphml object to adjacency matrix | ||
|
||
:param filename: graphml file path | ||
:return: the adjacency matrix corresponding to the graphml object | ||
|
@@ -1545,7 +1545,7 @@ | |
|
||
>>> G1 = pygm.utils.from_graphml(G2_path) | ||
>>> G2.shape | ||
(11,11) | ||
(11, 11) | ||
""" | ||
if not filename.endswith('.graphml'): | ||
raise ValueError("File name should end with '.graphml'") | ||
|
@@ -1591,3 +1591,104 @@ | |
""" | ||
nx.write_graphml(to_networkx(adj_matrix, backend), filename) | ||
|
||
|
||
################################################### | ||
# Support PyG # | ||
################################################### | ||
|
||
|
||
def build_aff_mat_from_pyg(G1, G2, node_aff_fn=None, edge_aff_fn=None, backend=None): | ||
r""" | ||
Convert torch_geometric.data.Data object to affinity matrix | ||
|
||
:param G1: Graph object, whose type must be torch_geometric.data.Data | ||
:param G2: Graph object, whose type must be torch_geometric.data.Data | ||
:param node_aff_fn: (default: inner_prod_aff_fn) the node affinity function with the characteristic | ||
``node_aff_fn(2D Tensor, 2D Tensor) -> 2D Tensor``, which accepts two node feature tensors and | ||
outputs the node-wise affinity tensor. See :func:`~pygmtools.utils.inner_prod_aff_fn` as an | ||
example. | ||
:param edge_aff_fn: (default: inner_prod_aff_fn) the edge affinity function with the characteristic | ||
``edge_aff_fn(2D Tensor, 2D Tensor) -> 2D Tensor``, which accepts two edge feature tensors and | ||
outputs the edge-wise affinity tensor. See :func:`~pygmtools.utils.inner_prod_aff_fn` as an | ||
example. | ||
:param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation. | ||
:return: the affinity matrix corresponding to the networkx object G1 and G2 | ||
|
||
.. dropdown:: Example | ||
|
||
:: | ||
|
||
>>> import networkx as nx | ||
>>> from torch_geometric.data import Data | ||
>>> import pygmtools as pygm | ||
>>> pygm.set_backend('pytorch') | ||
|
||
# Generate Graph object | ||
>>> x1 = torch.rand((4, 2), dtype=torch.float) | ||
>>> e1 = torch.tensor([[0, 0, 1, 1, 2, 2, 3], [1, 2, 0, 2, 0, 3, 1]], dtype=torch.long) | ||
>>> G1 = Data(x=x1, edge_index=e1) | ||
>>> x2 = torch.rand((5, 2), dtype=torch.float) | ||
>>> e2 = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4, 4], [1, 3, 2, 3, 1, 3, 4, 2, 3]], dtype=torch.long) | ||
>>> G2 = Data(x=x2, edge_index=e2) | ||
|
||
# Obtain Affinity Matrix | ||
>>> K = pygm.utils.build_aff_mat_from_pyg(G1, G2) | ||
>>> K.shape | ||
(20, 20) | ||
|
||
# The affinity matrices K can be further processed by GM solvers | ||
""" | ||
from torch_geometric.data import Data | ||
assert type(G1) == Data, f"G1 must be torch_geometric.data.Data" | ||
assert type(G2) == Data, f"G2 must be torch_geometric.data.Data" | ||
if backend is None: | ||
backend = 'pytorch' | ||
else: | ||
assert backend == 'pytorch', f"Function 'build_aff_mat_from_pyg' only supports pytorch backend." | ||
pygmtools.set_backend(backend) | ||
node1 = G1.x | ||
edge1 = G1.edge_attr | ||
conn1 = G1.edge_index | ||
node2 = G2.x | ||
edge2 = G2.edge_attr | ||
conn2 = G2.edge_index | ||
K = build_aff_mat(node1, edge1, conn1, node2, edge2, conn2, node_aff_fn=node_aff_fn, edge_aff_fn=edge_aff_fn, backend=backend) | ||
return K | ||
|
||
|
||
def to_pyg(adj_matrix, backend=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should there be "from_pyg" and "to_networkx", "to_graphml"? |
||
""" | ||
Convert adjacency matrix to torch_geometric.data.Data object | ||
|
||
:param adj_matrix: the adjacency matrix to convert, whose type must be torch.Tensor | ||
:param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation. | ||
:return: the torch_geometric.data.Data object corresponding to the adjacency matrix | ||
|
||
.. dropdown:: Example | ||
|
||
:: | ||
|
||
>>> import torch | ||
>>> import pygmtools as pygm | ||
>>> pygm.set_backend('pytorch') | ||
|
||
# Generate adjacency matrix | ||
>>> adj_matrix = torch.rand((4, 4)) | ||
|
||
# Obtain torch_geometric.data.Data object | ||
>>> pygm.utils.to_pyg(adj_matrix) | ||
Data(edge_index=[16, 2], edge_attr=[16, 1]) | ||
""" | ||
import torch | ||
from torch_geometric.data import Data | ||
if backend is None: | ||
backend = 'pytorch' | ||
else: | ||
assert backend == 'pytorch', f"Function 'build_aff_mat_from_pyg' only supports pytorch backend." | ||
pygmtools.set_backend(backend) | ||
assert type(adj_matrix) == torch.Tensor, f"the adj_matrix's type must be torch.Tensor" | ||
|
||
conn1, edge1 = dense_to_sparse(adj_matrix, backend=backend) | ||
G = Data(x=None, edge_index=conn1, edge_attr=edge1) | ||
return G | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -313,7 +313,30 @@ def _test_graphml(graph_num_nodes, backends): | |
assert accuracy == 1, f'When testing the graphml function with rrwm algorithm, there is an error in accuracy, \ | ||
and the accuracy is {accuracy}, the num_node is {num_node},.' | ||
|
||
|
||
|
||
# The testing function for networkx | ||
def _test_pyg(graph_num_nodes, backends): | ||
""" | ||
Test the RRWM algorithm on pairs of isomorphic graphs using NetworkX | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NetworkX -> PyG |
||
|
||
:param graph_num_nodes: list, the numbers of nodes in the graphs to test | ||
""" | ||
for working_backend in backends: | ||
pygm.BACKEND = working_backend | ||
for num_node in tqdm(graph_num_nodes): | ||
As_b, X_gt = pygm.utils.generate_isomorphic_graphs(num_node) | ||
X_gt = pygm.utils.to_numpy(X_gt, backend=working_backend) | ||
A1 = As_b[0] | ||
A2 = As_b[1] | ||
G1 = pygm.utils.to_pyg(A1) | ||
G2 = pygm.utils.to_pyg(A2) | ||
K = pygm.utils.build_aff_mat_from_pyg(G1, G2) | ||
X = pygm.rrwm(K, n1=num_node, n2=num_node) | ||
accuracy = (pygm.utils.to_numpy(pygm.hungarian(X, num_node, num_node)) * X_gt).sum() / X_gt.sum() | ||
assert accuracy == 1, f'When testing the networkx function with rrwm algorithm, there is an error in accuracy, \ | ||
and the accuracy is {accuracy}, the num_node is {num_node},.' | ||
|
||
|
||
def test_hungarian(get_backend): | ||
backends = get_backends(get_backend) | ||
_test_classic_solver_on_linear_assignment(list(range(10, 30, 2)), list(range(30, 10, -2)), 10, pygm.hungarian, { | ||
|
@@ -477,6 +500,11 @@ def test_graphml(): | |
_test_graphml(list(range(10, 30, 2)), backends=backends) | ||
|
||
|
||
def test_pyg(): | ||
backends = ['pytorch'] | ||
_test_pyg(list(range(10, 30, 2)), backends=backends) | ||
|
||
|
||
if __name__ == '__main__': | ||
test_hungarian('all') | ||
test_sinkhorn('all') | ||
|
@@ -486,3 +514,4 @@ def test_graphml(): | |
test_astar('') | ||
test_networkx() | ||
test_graphml() | ||
test_pyg() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please replace assert with ValueError