Skip to content

Commit

Permalink
Add sampled_op benchmark (#160)
Browse files Browse the repository at this point in the history
```
Function: add
=========================
Vanilla forward:  0.3757s
Vanilla backward: 0.4279s
=========================
pyg_lib forward:  0.1756s
pyg_lib backward: 0.3006s

Function: sub
=========================
Vanilla forward:  0.3762s
Vanilla backward: 0.4979s
=========================
pyg_lib forward:  0.1764s
pyg_lib backward: 0.3141s

Function: mul
=========================
Vanilla forward:  0.3755s
Vanilla backward: 0.7078s
=========================
pyg_lib forward:  0.1775s
pyg_lib backward: 0.6206s

Function: div
=========================
Vanilla forward:  0.3758s
Vanilla backward: 1.1376s
=========================
pyg_lib forward:  0.1868s
pyg_lib backward: 1.0707s
```
  • Loading branch information
rusty1s authored Dec 6, 2022
1 parent 8f5b0b9 commit 511e8f6
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.2.0] - 2023-MM-DD
### Added
- Added `sampled_op` impementation ([#156](https://github.com/pyg-team/pyg-lib/pull/156), [#159](https://github.com/pyg-team/pyg-lib/pull/159))
- Added `sampled_op` impementation ([#156](https://github.com/pyg-team/pyg-lib/pull/156), [#159](https://github.com/pyg-team/pyg-lib/pull/159), [#160](https://github.com/pyg-team/pyg-lib/pull/160))
### Changed
- Improved `[segment|grouped]_matmul` CPU implementation via `at::matmul_out` and MKL BLAS `gemm_batch` ([#146](https://github.com/pyg-team/pyg-lib/pull/146))
### Removed
Expand Down
89 changes: 89 additions & 0 deletions benchmark/ops/sampled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import argparse
import time

import torch

import pyg_lib

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--backward', action='store_true')
args = parser.parse_args()

num_nodes, num_edges, num_feats = 10000, 50000, 64

num_warmups, num_steps = 500, 1000
if args.device == 'cpu':
num_warmups, num_steps = num_warmups // 10, num_steps // 10

a_index = torch.randint(0, num_nodes, (num_edges, ), device=args.device)
b_index = torch.randint(0, num_nodes, (num_edges, ), device=args.device)
out_grad = torch.randn(num_edges, num_feats, device=args.device)

for fn in ['add', 'sub', 'mul', 'div']:
print(f'Function: {fn}')
print('=========================')

op = getattr(torch, fn)
t_forward = t_backward = 0
for i in range(num_warmups + num_steps):
a = torch.randn(num_nodes, num_feats, device=args.device)
b = torch.randn(num_nodes, num_feats, device=args.device)
if args.backward:
a.requires_grad_(True)
b.requires_grad_(True)

torch.cuda.synchronize()
t_start = time.perf_counter()

out = op(a[a_index], b[b_index])

torch.cuda.synchronize()
if i >= num_warmups:
t_forward += time.perf_counter() - t_start

if args.backward:
t_start = time.perf_counter()
out.backward(out_grad)

torch.cuda.synchronize()
if i >= num_warmups:
t_backward += time.perf_counter() - t_start

print(f'Vanilla forward: {t_forward:.4f}s')
if args.backward:
print(f'Vanilla backward: {t_backward:.4f}s')
print('=========================')

op = getattr(pyg_lib.ops, f'sampled_{fn}')
t_forward = t_backward = 0
for i in range(num_warmups + num_steps):
a = torch.randn(num_nodes, num_feats, device=args.device)
b = torch.randn(num_nodes, num_feats, device=args.device)
if args.backward:
a.requires_grad_(True)
b.requires_grad_(True)

torch.cuda.synchronize()
t_start = time.perf_counter()

out = op(a, b, a_index, b_index)

torch.cuda.synchronize()
if i >= num_warmups:
t_forward += time.perf_counter() - t_start

if args.backward:
t_start = time.perf_counter()
out.backward(out_grad)

torch.cuda.synchronize()
if i >= num_warmups:
t_backward += time.perf_counter() - t_start

print(f'pyg_lib forward: {t_forward:.4f}s')
if args.backward:
print(f'pyg_lib backward: {t_backward:.4f}s')

print()

0 comments on commit 511e8f6

Please sign in to comment.