Correct unit tests to run attention functions instead of Attention module in order to avoid 'final' init on outputs
christinaflo committed Sep 20, 2023
1 parent 710088d commit 6ebcd8b
Showing 2 changed files with 35 additions and 30 deletions.
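
A note on the motivation, with a hedged reading of the commit message: OpenFold's 'final' initializer zero-fills output-projection weights, so a freshly constructed Attention module maps whatever its attention core computes to an all-zero output, and comparing two code paths through such a module is vacuous. The stand-alone sketch below (plain PyTorch; the zero-initialized linear_o is an illustrative stand-in, not OpenFold's module) shows the effect; calling the underlying attention functions directly, as the updated tests do, avoids this projection entirely.

import torch

# Two attention-core results that genuinely disagree.
attn_core_a = torch.rand(2, 4, 16, 8)
attn_core_b = torch.rand(2, 4, 16, 8)

# 'final'-style zero initialization of the output projection.
linear_o = torch.nn.Linear(8, 8, bias=False)
torch.nn.init.zeros_(linear_o.weight)

with torch.no_grad():
    out_a = linear_o(attn_core_a)
    out_b = linear_o(attn_core_b)

print(torch.max(torch.abs(out_a - out_b)))  # tensor(0.) -- the comparison passes vacuously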
30 changes: 16 additions & 14 deletions tests/test_deepspeed_evo_attention.py
@@ -26,7 +26,8 @@
 import pickle
 
 from openfold.model.primitives import (
-    Attention,
+    _attention,
+    _deepspeed_evo_attn
 )
 from tests.config import consts
 import tests.compare_utils as compare_utils
@@ -43,22 +44,26 @@ def test_ds_kernel_vs_attention(self):
         n = 2 ** 12
         n_seq = 12
         no_heads = 4
+        dtype = torch.bfloat16
 
-        q = torch.rand(batch_size, n_seq, n, c_hidden).cuda()
-        kv = torch.rand(batch_size, n_seq, n, c_hidden).cuda()
+        q = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
+        k = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
+        v = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
 
         bias = [torch.rand(batch_size, n_seq, 1, 1, n), torch.rand(batch_size, 1, no_heads, n, n)]
-        bias = [b.cuda() for b in bias]
-
-        a = Attention(
-            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
-        ).cuda()
+        bias = [b.to(dtype=dtype).cuda() for b in bias]
 
         with torch.no_grad():
-            l = a(q, kv, biases=bias, use_deepspeed_evo_attention=True)
-            real = a(q, kv, biases=bias)
+            l = _deepspeed_evo_attn(q, k, v, biases=bias).cpu()
+
+            q = q.transpose(-2, -3)
+            k = k.transpose(-2, -3)
+            v = v.transpose(-2, -3)
+            real = _attention(q, k, v, biases=bias)
+            real = real.transpose(-2, -3).cpu()
 
-        self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
+        err = torch.max(torch.abs(l - real))
+        self.assertTrue(err < consts.eps, f'Error: {err}')
 
     def compare_evoformer(self, dtype):
         """
@@ -112,17 +117,14 @@ def compare_evoformer(self, dtype):
         self.assertTrue(torch.allclose(torch.abs(out_repro_msa), torch.abs(out_repro_msa_ds), atol=eps))
         self.assertTrue(torch.allclose(torch.abs(out_repro_pair), torch.abs(out_repro_pair_ds), atol=eps))
 
-    @unittest.skip('Temporarily disabled')
     def test_compare_evoformer_bf16(self):
         """Run evoformer comparison test with BF16 precision."""
         self.compare_evoformer(torch.bfloat16)
 
-    @unittest.skip('Temporarily disabled')
     def test_compare_evoformer_fp32(self):
         """Run evoformer comparison test with FP32 precision."""
         self.compare_evoformer(torch.float32)
 
-    @unittest.skip('Temporarily disabled')
     def test_compare_model(self):
         """
         Run full model with and without using DeepSpeed Evoformer attention kernel
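
The transposes in the updated DeepSpeed test come from a layout difference between the two paths: judging from the diff, _deepspeed_evo_attn consumes tensors shaped [batch, n_seq, n_res, heads, c_hidden], while _attention expects the head dimension ahead of the residue dimension. Below is a self-contained sketch of that conversion, with reference_attention as a generic scaled-dot-product stand-in (an assumption for illustration, not OpenFold's _attention).

import torch

def reference_attention(q, k, v, biases):
    # q, k, v: [..., heads, n_res, c_hidden]; biases broadcast against [..., heads, n_res, n_res].
    logits = torch.einsum("...hqc,...hkc->...hqk", q, k) / (q.shape[-1] ** 0.5)
    for b in biases:
        logits = logits + b
    return torch.einsum("...hqk,...hkc->...hqc", torch.softmax(logits, dim=-1), v)

batch, n_seq, n_res, heads, c_hidden = 1, 2, 16, 4, 8

# Kernel-style layout: [batch, n_seq, n_res, heads, c_hidden].
q = torch.rand(batch, n_seq, n_res, heads, c_hidden)
k = torch.rand(batch, n_seq, n_res, heads, c_hidden)
v = torch.rand(batch, n_seq, n_res, heads, c_hidden)
biases = [torch.rand(batch, n_seq, 1, 1, n_res), torch.rand(batch, 1, heads, n_res, n_res)]

# Move heads ahead of residues for the reference path, then move them back.
out = reference_attention(q.transpose(-2, -3), k.transpose(-2, -3), v.transpose(-2, -3), biases)
out = out.transpose(-2, -3)
print(out.shape)  # torch.Size([1, 2, 16, 4, 8])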
35 changes: 19 additions & 16 deletions tests/test_primitives.py
@@ -17,7 +17,10 @@
 import unittest
 
 from openfold.model.primitives import (
-    Attention,
+    _lma,
+    _attention,
+    DEFAULT_LMA_Q_CHUNK_SIZE,
+    DEFAULT_LMA_KV_CHUNK_SIZE
 )
 from tests.config import consts
 
@@ -27,26 +30,26 @@ def test_lma_vs_attention(self):
         batch_size = consts.batch_size
         c_hidden = 32
         n = 2**12
+        n_seq = 12
         no_heads = 4
 
-        q = torch.rand(batch_size, n, c_hidden).cuda()
-        kv = torch.rand(batch_size, n, c_hidden).cuda()
+        q = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
+        k = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
+        v = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
 
-        bias = [torch.rand(no_heads, 1, n)]
-        bias = [b.cuda() for b in bias]
-
-        gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
-        o_fill = torch.rand(c_hidden, c_hidden * no_heads)
-
-        a = Attention(
-            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
-        ).cuda()
+        bias = [torch.rand(batch_size, n_seq, 1, 1, n), torch.rand(batch_size, 1, no_heads, n, n)]
+        biases = [b.cuda() for b in bias]
 
         with torch.no_grad():
-            l = a(q, kv, biases=bias, use_lma=True)
-            real = a(q, kv, biases=bias)
-
-        self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
+            lma_biases = [
+                b.expand(b.shape[:-2] + (q.shape[-2],) + (k.shape[-2],))
+                for b in biases
+            ]
+            l = _lma(q, k, v, lma_biases, DEFAULT_LMA_Q_CHUNK_SIZE, DEFAULT_LMA_KV_CHUNK_SIZE).cpu()
+            real = _attention(q, k, v, biases).cpu()
 
+        err = torch.max(torch.abs(l - real))
+        self.assertTrue(err < consts.eps, f'Error: {err}')
 
 
 if __name__ == "__main__":
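
In the new LMA test, the biases are expanded to the full [..., n, n] shape before the _lma call. The likely reason: a chunked (low-memory) attention implementation slices its bias tensors along the query and key dimensions, and that slicing only lines up with the chunks after size-1 broadcast dimensions have been materialized. Below is a minimal query-only-chunked sketch (an illustrative stand-in, not OpenFold's _lma, which also chunks over keys/values).

import torch

def query_chunked_attention(q, k, v, biases, q_chunk_size=4):
    # q, k, v: [..., heads, n, c]; each bias must already have full [..., n, n] trailing
    # dims so it can be sliced per query chunk below.
    chunks = []
    for start in range(0, q.shape[-2], q_chunk_size):
        q_blk = q[..., start:start + q_chunk_size, :]
        logits = torch.einsum("...qc,...kc->...qk", q_blk, k) / (q.shape[-1] ** 0.5)
        for b in biases:
            logits = logits + b[..., start:start + q_chunk_size, :]
        chunks.append(torch.einsum("...qk,...kc->...qc", torch.softmax(logits, dim=-1), v))
    return torch.cat(chunks, dim=-2)

n, heads, c = 16, 4, 8
q, k, v = (torch.rand(2, heads, n, c) for _ in range(3))
biases = [torch.rand(2, 1, 1, n), torch.rand(2, heads, n, n)]

# Expand size-1 broadcast dims to full (n, n), mirroring the lma_biases expansion in the test.
expanded = [b.expand(b.shape[:-2] + (n, n)) for b in biases]
out = query_chunked_attention(q, k, v, expanded, q_chunk_size=4)
print(out.shape)  # torch.Size([2, 4, 16, 8])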
