Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unbatch_edge_index #4903

Merged
merged 10 commits into from
Jul 11, 2022
Merged

Add unbatch_edge_index #4903

merged 10 commits into from
Jul 11, 2022

Conversation

chenzhekl
Copy link
Contributor

Add a utility function to split edge_index according to the batch vector.

Related: #4717

@rusty1s rusty1s changed the title Add unbatch_edge_index Add unbatch_edge_index Jul 1, 2022
test/utils/test_unbatch_edge_index.py Outdated Show resolved Hide resolved
torch_geometric/utils/unbatch_edge_index.py Outdated Show resolved Hide resolved
torch_geometric/utils/unbatch_edge_index.py Outdated Show resolved Hide resolved
torch_geometric/utils/unbatch_edge_index.py Outdated Show resolved Hide resolved
@codecov
Copy link

codecov bot commented Jul 2, 2022

Codecov Report

Merging #4903 (9f74a62) into master (64d44fe) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##           master    #4903   +/-   ##
=======================================
  Coverage   82.71%   82.72%           
=======================================
  Files         330      330           
  Lines       17891    17898    +7     
=======================================
+ Hits        14799    14806    +7     
  Misses       3092     3092           
Impacted Files Coverage Δ
torch_geometric/utils/__init__.py 100.00% <100.00%> (ø)
torch_geometric/utils/unbatch.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 64d44fe...9f74a62. Read the comment docs.

@chenzhekl
Copy link
Contributor Author

chenzhekl commented Jul 6, 2022

Hi, @rusty1s. I rewrote the code per your suggestions, and it was much faster. Here are the benchmark results and code. Could you please review the current implementation when it is convenient for you?

Batch size: 1, Number of edges: 1, v1=0.00015095502138137817, v2=0.00013493932783603668, v3=0.00022899847477674483
Batch size: 1, Number of edges: 10, v1=0.00012184974737465381, v2=0.00012676811777055264, v3=0.00017614605836570264
Batch size: 1, Number of edges: 100, v1=0.0001273123547434807, v2=0.0001291760616004467, v3=0.00018301083706319333
Batch size: 1, Number of edges: 1000, v1=0.00018263477832078934, v2=0.000146225243806839, v3=0.000257051857188344
Batch size: 10, Number of edges: 1, v1=0.00041488190181553363, v2=0.00022862684912979602, v3=0.0002795036416500807
Batch size: 10, Number of edges: 10, v1=0.0004219593573361635, v2=0.00023178191855549813, v3=0.0002914824988692999
Batch size: 10, Number of edges: 100, v1=0.0005695968773216009, v2=0.00027313867583870886, v3=0.00036626373417675495
Batch size: 10, Number of edges: 1000, v1=0.0013866874296218157, v2=0.0004070508200675249, v3=0.001070710765197873
Batch size: 100, Number of edges: 1, v1=0.0034643872361630203, v2=0.0012199555058032274, v3=0.0013013050612062215
Batch size: 100, Number of edges: 10, v1=0.00422879139892757, v2=0.0012447251845151186, v3=0.0013499038200825452
Batch size: 100, Number of edges: 100, v1=0.010962332878261805, v2=0.0014868732821196318, v3=0.0021350229624658825
Batch size: 100, Number of edges: 1000, v1=0.021530504962429406, v2=0.0027764803543686867, v3=0.008312329994514584
Batch size: 1000, Number of edges: 1, v1=0.04029915944673121, v2=0.011021380107849836, v3=0.011044777976348997
Batch size: 1000, Number of edges: 10, v1=0.10351377192884684, v2=0.011458314461633564, v3=0.012150439545512199
Batch size: 1000, Number of edges: 100, v1=0.17790262822061778, v2=0.013426397806033491, v3=0.017951723579317333
Batch size: 1000, Number of edges: 1000, v1=0.45610245333984495, v2=0.026584064653143287, v3=0.08407091032713651
import timeit
from typing import List

import torch
from torch import Tensor
from torch_geometric.data import Batch, Data
from torch_geometric.utils import degree, sort_edge_index


def unbatch_edge_index_v1(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
    boundary = torch.cumsum(degree(batch), dim=0)
    inc = torch.cat([boundary.new_tensor([0]), boundary[:-1]], dim=0)

    edge_assignments = torch.bucketize(edge_index, boundary, right=True)

    out = [
        (edge_index[edge_assignments == batch_idx].view(2, -1) - inc[batch_idx]).to(
            torch.int64
        )
        for batch_idx in range(batch.max().item() + 1)
    ]

    return out


def unbatch_edge_index_v2(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
    boundary = torch.cumsum(degree(batch), dim=0)
    inc = torch.cat([boundary.new_tensor([0]), boundary[:-1]], dim=0)

    edge_batch = batch[edge_index[0]]
    sizes = degree(edge_batch).long().cpu().tolist()

    out = [
        (edge - inc[batch_idx]).to(torch.int64)
        for batch_idx, edge in enumerate(edge_index.split(sizes, dim=1))
    ]

    return out


def unbatch_edge_index_v3(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
    boundary = torch.cumsum(degree(batch), dim=0)
    inc = torch.cat([boundary.new_tensor([0]), boundary[:-1]], dim=0)

    edge_index = sort_edge_index(edge_index)

    edge_batch = batch[edge_index[0]]
    sizes = degree(edge_batch).long().cpu().tolist()

    out = [
        (edge - inc[batch_idx]).to(torch.int64)
        for batch_idx, edge in enumerate(edge_index.split(sizes, dim=1))
    ]

    return out


if __name__ == "__main__":
    for batch_size in [1, 10, 100, 1000]:
        for num_edges in [1, 10, 100, 1000]:
            data = [
                Data(
                    edge_index=torch.randint(
                        0, num_edges, [2, num_edges], dtype=torch.int64
                    ),
                    num_nodes=num_edges,
                )
                for _ in range(batch_size)
            ]
            c = Batch.from_data_list(data)

            num_tries = 100

            result_v1 = (
                timeit.timeit(
                    "unbatch_edge_index_v1(c.edge_index, c.batch)",
                    setup="from __main__ import unbatch_edge_index_v1, c",
                    number=num_tries,
                )
                / num_tries
            )

            result_v2 = (
                timeit.timeit(
                    "unbatch_edge_index_v2(c.edge_index, c.batch)",
                    setup="from __main__ import unbatch_edge_index_v2, c",
                    number=num_tries,
                )
                / num_tries
            )

            result_v3 = (
                timeit.timeit(
                    "unbatch_edge_index_v3(c.edge_index, c.batch)",
                    setup="from __main__ import unbatch_edge_index_v3, c",
                    number=num_tries,
                )
                / num_tries
            )

            print(
                f"Batch size: {batch_size}, Number of edges: {num_edges}, v1={result_v1}, v2={result_v2}, v3={result_v3}"
            )

@rusty1s rusty1s added the data label Jul 7, 2022
Copy link
Member

@rusty1s rusty1s left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@rusty1s rusty1s merged commit 423b923 into pyg-team:master Jul 11, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants