Add unbatch_edge_index
#4903
Conversation
Codecov Report

```diff
@@           Coverage Diff            @@
##           master    #4903   +/-   ##
=======================================
  Coverage   82.71%   82.72%
=======================================
  Files         330      330
  Lines       17891    17898      +7
=======================================
+ Hits        14799    14806      +7
  Misses       3092     3092
```

Continue to review full report at Codecov.
Hi, @rusty1s. I rewrote the code per your suggestions, and it was much faster. Here are the benchmark results and code. Could you please review the current implementation when it is convenient for you?
```python
import timeit
from typing import List

import torch
from torch import Tensor

from torch_geometric.data import Batch, Data
from torch_geometric.utils import degree, sort_edge_index


def unbatch_edge_index_v1(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
    # Assign each edge to a graph by bucketizing node indices against the
    # cumulative node counts, then gather per-graph edges one graph at a time.
    boundary = torch.cumsum(degree(batch), dim=0)
    inc = torch.cat([boundary.new_tensor([0]), boundary[:-1]], dim=0)
    edge_assignments = torch.bucketize(edge_index, boundary, right=True)
    out = [
        (edge_index[edge_assignments == batch_idx].view(2, -1) - inc[batch_idx]).to(
            torch.int64
        )
        for batch_idx in range(batch.max().item() + 1)
    ]
    return out


def unbatch_edge_index_v2(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
    # Split by per-graph edge counts in a single pass; assumes edges are
    # already grouped by graph (as produced by Batch.from_data_list).
    boundary = torch.cumsum(degree(batch), dim=0)
    inc = torch.cat([boundary.new_tensor([0]), boundary[:-1]], dim=0)
    edge_batch = batch[edge_index[0]]
    sizes = degree(edge_batch).long().cpu().tolist()
    out = [
        (edge - inc[batch_idx]).to(torch.int64)
        for batch_idx, edge in enumerate(edge_index.split(sizes, dim=1))
    ]
    return out


def unbatch_edge_index_v3(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
    # Like v2, but sorts edge_index first so the grouping assumption holds
    # even for unsorted input.
    boundary = torch.cumsum(degree(batch), dim=0)
    inc = torch.cat([boundary.new_tensor([0]), boundary[:-1]], dim=0)
    edge_index = sort_edge_index(edge_index)
    edge_batch = batch[edge_index[0]]
    sizes = degree(edge_batch).long().cpu().tolist()
    out = [
        (edge - inc[batch_idx]).to(torch.int64)
        for batch_idx, edge in enumerate(edge_index.split(sizes, dim=1))
    ]
    return out


if __name__ == "__main__":
    for batch_size in [1, 10, 100, 1000]:
        for num_edges in [1, 10, 100, 1000]:
            data = [
                Data(
                    edge_index=torch.randint(
                        0, num_edges, [2, num_edges], dtype=torch.int64
                    ),
                    num_nodes=num_edges,
                )
                for _ in range(batch_size)
            ]
            c = Batch.from_data_list(data)
            num_tries = 100
            result_v1 = (
                timeit.timeit(
                    "unbatch_edge_index_v1(c.edge_index, c.batch)",
                    setup="from __main__ import unbatch_edge_index_v1, c",
                    number=num_tries,
                )
                / num_tries
            )
            result_v2 = (
                timeit.timeit(
                    "unbatch_edge_index_v2(c.edge_index, c.batch)",
                    setup="from __main__ import unbatch_edge_index_v2, c",
                    number=num_tries,
                )
                / num_tries
            )
            result_v3 = (
                timeit.timeit(
                    "unbatch_edge_index_v3(c.edge_index, c.batch)",
                    setup="from __main__ import unbatch_edge_index_v3, c",
                    number=num_tries,
                )
                / num_tries
            )
            print(
                f"Batch size: {batch_size}, Number of edges: {num_edges}, "
                f"v1={result_v1}, v2={result_v2}, v3={result_v3}"
            )
```
Thank you!
Add a utility function to split `edge_index` according to the `batch` vector.

Related: #4717
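For reference, the splitting semantics can be sketched in plain PyTorch without any `torch_geometric` dependency. This is a hypothetical standalone version for illustration only (the function name `unbatch_edge_index_sketch` and the use of `torch.bincount` in place of `degree` are my own choices, not the PR's implementation), and it assumes edges are already grouped by graph, as `Batch.from_data_list` produces:

```python
from typing import List

import torch
from torch import Tensor


def unbatch_edge_index_sketch(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
    # Per-graph node counts, e.g. batch = [0, 0, 1, 1, 1] -> counts = [2, 3].
    counts = torch.bincount(batch)
    # Node-index offset of each graph inside the batched graph: [0, 2].
    inc = torch.cat([counts.new_zeros(1), counts.cumsum(0)[:-1]])
    # Assign each edge to a graph via its source node, then count edges per graph.
    edge_batch = batch[edge_index[0]]
    sizes = torch.bincount(edge_batch, minlength=counts.numel()).tolist()
    # Split column-wise and shift node indices back to per-graph numbering.
    return [e - inc[i] for i, e in enumerate(edge_index.split(sizes, dim=1))]


# Two graphs batched together: graph 0 owns nodes {0, 1},
# graph 1 owns nodes {2, 3, 4} (i.e. local nodes {0, 1, 2} offset by 2).
batch = torch.tensor([0, 0, 1, 1, 1])
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 4]])
out = unbatch_edge_index_sketch(edge_index, batch)
print(out[0].tolist())  # [[0, 1], [1, 0]]
print(out[1].tolist())  # [[0, 1], [1, 2]]
```

Note the offsets recovered here are exactly the `inc` tensor in the benchmark's v2/v3 variants; v3 additionally sorts `edge_index` so the grouped-by-graph assumption holds for arbitrary input.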