Skip to content

Commit 98c96e7

Browse files
KimmiShi and loadams authored
Update flops profiler to handle attn and __matmul__ (#4724)
Fixes #4723 - handle `F.scaled_dot_product_attention` in transformer models. - handle expressions like `a@b` --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
1 parent b00533e commit 98c96e7

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

deepspeed/profiling/flops_profiler/profiler.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,15 @@ def _elementwise_flops_compute(input, other):
827827
return flops, 0
828828

829829

830+
def _attn_flops_compute(q, k, v, *args, **kwargs):
831+
"""
832+
Count flops for the scaled_dot_product_attention operation.
833+
"""
834+
macs = _prod(q.shape) * k.shape[-2]
835+
macs += _prod(q.shape[:-1]) * k.shape[-2] * v.shape[-1]
836+
return 2 * macs, macs
837+
838+
830839
def wrapFunc(func, funcFlopCompute):
831840
oldFunc = func
832841
name = func.__str__
@@ -899,10 +908,14 @@ def _patch_functionals():
899908
# embedding
900909
F.embedding = wrapFunc(F.embedding, _embedding_flops_compute)
901910

911+
# attn
912+
F.scaled_dot_product_attention = wrapFunc(F.scaled_dot_product_attention, _attn_flops_compute)
913+
902914

903915
def _patch_tensor_methods():
904916
torch.matmul = wrapFunc(torch.matmul, _matmul_flops_compute)
905917
torch.Tensor.matmul = wrapFunc(torch.Tensor.matmul, _matmul_flops_compute)
918+
torch.Tensor.__matmul__ = wrapFunc(torch.Tensor.__matmul__, _matmul_flops_compute)
906919
torch.mm = wrapFunc(torch.mm, _matmul_flops_compute)
907920
torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute)
908921
torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute)

0 commit comments

Comments
 (0)