Skip to content

Commit

Permalink
Fix astar bugs (#94)
Browse files Browse the repository at this point in the history
1) fix bugs when use astar with cuda
2) Resolve the error occurred when the number of nodes in graph1 inputted by astar is greater than the number of nodes in graph2
  • Loading branch information
heatingma authored Dec 30, 2023
1 parent d61001d commit fab69ab
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
26 changes: 19 additions & 7 deletions pygmtools/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,11 +926,21 @@ def forward(self, data: GraphPair):
max_nodes_num_2 = torch.max(data.g2.nodes_num) + 1
x_pred = torch.zeros(num, max_nodes_num_1, max_nodes_num_2)
for i in range(num):
cur_data = GraphPair(data.g1.x[i], data.g2.x[i], data.g1.adj[i], data.g2.adj[i],
data.g1.nodes_num[i], data.g2.nodes_num[i])
x1 = data.g1.x[i]
x2 = data.g2.x[i]
adj1 = data.g1.adj[i]
adj2 = data.g2.adj[i]
n1 = data.g1.nodes_num[i]
n2 = data.g2.nodes_num[i]
exchange = True if x1.shape[0] > x2.shape[0] else False
if not exchange:
cur_data = GraphPair(x1, x2, adj1, adj2, n1, n2)
else:
cur_data = GraphPair(x2, x1, adj2, adj1, n2, n1)
num_nodes_1 = data.g1.nodes_num[i] + 1
num_nodes_2 = data.g2.nodes_num[i] + 1
x_pred[i][:num_nodes_1, :num_nodes_2] = self._astar(cur_data)
_x_pred = self._astar(cur_data)
x_pred[i][:num_nodes_1, :num_nodes_2] = _x_pred.T if exchange else _x_pred
return x_pred[:, :-1, :-1]

def _astar(self, data: GraphPair):
Expand All @@ -945,8 +955,8 @@ def _astar(self, data: GraphPair):
edge_attr_2 = data.g2.edge_weight
node_1 = data.g1.x.squeeze()
node_2 = data.g2.x.squeeze()
batch_1 = data.g1.batch
batch_2 = data.g2.batch
batch_1 = data.g1.batch.to(device)
batch_2 = data.g2.batch.to(device)
batch_num = data.g1.num_graphs

ns_1 = torch.bincount(data.g1.batch)
Expand Down Expand Up @@ -1016,8 +1026,8 @@ def net_prediction_cache(self, data: GraphPair, partial_pmat=None, return_ged_no
"""
features_1 = data.g1.x.squeeze()
features_2 = data.g2.x.squeeze()
batch_1 = data.g1.batch
batch_2 = data.g2.batch
batch_1 = data.g1.batch.to(features_1.device)
batch_2 = data.g2.batch.to(features_2.device)
adj1 = data.g1.adj.squeeze()
adj2 = data.g2.adj.squeeze()

Expand Down Expand Up @@ -1123,6 +1133,8 @@ def astar_kernel(feat1, feat2, A1, A2, n1, n2, channel, filters_1, filters_2, fi

if network is None:
args = default_parameter()
if device != torch.device('cpu'):
args["cuda"] = True
if forward_pass:
if channel is None:
args['channel'] = feat1.shape[-1]
Expand Down
2 changes: 1 addition & 1 deletion pygmtools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,7 +1539,7 @@ def from_networkx(G: nx.Graph):
"""
is_directed = isinstance(G, nx.DiGraph)
adj_matrix = nx.to_numpy_matrix(G,nodelist=G.nodes()) if is_directed else nx.to_numpy_matrix(G)
adj_matrix = nx.to_numpy_matrix(G, nodelist=G.nodes()) if is_directed else nx.to_numpy_matrix(G)
return adj_matrix


Expand Down

0 comments on commit fab69ab

Please sign in to comment.