Skip to content
13 changes: 13 additions & 0 deletions deepspeed/profiling/flops_profiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,15 @@ def _elementwise_flops_compute(input, other):
return flops, 0


def _attn_flops_compute(q, k, v, *args, **kwargs):
"""
Count flops for the scaled_dot_product_attention operation.
"""
macs = _prod(q.shape) * k.shape[-2]
macs += _prod(q.shape[:-1]) * k.shape[-2] * v.shape[-1]
return 2 * macs, macs


def wrapFunc(func, funcFlopCompute):
oldFunc = func
name = func.__str__
Expand Down Expand Up @@ -899,10 +908,14 @@ def _patch_functionals():
# embedding
F.embedding = wrapFunc(F.embedding, _embedding_flops_compute)

# attn
F.scaled_dot_product_attention = wrapFunc(F.scaled_dot_product_attention, _attn_flops_compute)


def _patch_tensor_methods():
torch.matmul = wrapFunc(torch.matmul, _matmul_flops_compute)
torch.Tensor.matmul = wrapFunc(torch.Tensor.matmul, _matmul_flops_compute)
torch.Tensor.__matmul__ = wrapFunc(torch.Tensor.__matmul__, _matmul_flops_compute)
torch.mm = wrapFunc(torch.mm, _matmul_flops_compute)
torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute)
torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute)
Expand Down