[NPU]testcase for fa with gqa (#1122)
iansheng authored Apr 10, 2024
1 parent 75e6171 commit 4091df3
Showing 1 changed file with 104 additions and 2 deletions.
backends/npu/tests/unittests/test_flashattention_npu.py (104 additions, 2 deletions)
@@ -32,6 +32,11 @@
 
 
 def attention_naive(q, k, v, mask=None):
+    if q.shape[1] != k.shape[1]:
+        # GQA: q has more heads than k/v; repeat the KV heads to match
+        n_rep = q.shape[1] // k.shape[1]
+        k = repeat_kv(k, n_rep)
+        v = repeat_kv(v, n_rep)
     scale = 1.0 / np.sqrt(q.shape[-1])
     s = paddle.matmul(q, paddle.transpose(k, [0, 1, 3, 2]))
     s = paddle.scale(s, scale)
@@ -44,6 +49,17 @@ def attention_naive(q, k, v, mask=None):
     return paddle.transpose(o, [0, 2, 1, 3])


+def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        [batch, num_key_value_heads, n_rep, slen, head_dim]
+    )
+    return hidden_states.reshape([batch, num_key_value_heads * n_rep, slen, head_dim])
+
+
 class TestNPUFAFP16(unittest.TestCase):
     def setUp(self):
         self.npu_place = paddle.CustomPlace("npu", 0)
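For reference, repeat_kv follows the standard GQA head-expansion pattern: insert a new axis, broadcast it n_rep times, then fold it back into the head dimension, so each KV head serves a contiguous block of query heads. A minimal NumPy sketch of the same expand-and-reshape (illustration only; repeat_kv_np and the toy shapes are not part of the commit):

import numpy as np

def repeat_kv_np(x, n_rep):
    # x: (batch, num_kv_heads, seq_len, head_dim)
    batch, kv_heads, slen, head_dim = x.shape
    if n_rep == 1:
        return x
    expanded = np.broadcast_to(
        x[:, :, None, :, :], (batch, kv_heads, n_rep, slen, head_dim)
    )
    return expanded.reshape(batch, kv_heads * n_rep, slen, head_dim)

x = np.arange(8, dtype=np.float32).reshape(1, 2, 2, 2)  # 2 KV heads
y = repeat_kv_np(x, n_rep=4)  # -> (1, 8, 2, 2), matching 8 query heads
assert (y[:, 0] == y[:, 3]).all()  # head 0 is repeated across indices 0..3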
@@ -56,6 +72,7 @@ def setUp(self):
         self.return_softmax = False
         self.is_test = False
         self.is_triangle_upper_mask = True
+        self.pass_line = 0.9999
         self.init_dtype()
 
     def init_dtype(self):
@@ -77,8 +94,16 @@ def check_result(self, golden_res, fused_res):
         )
         golden_y, golden_dx = golden_res
         fused_y, fused_dx = fused_res
-        np.testing.assert_allclose(golden_y, fused_y, rtol=rtol, atol=atol)
-        np.testing.assert_allclose(golden_dx, fused_dx, rtol=rtol, atol=atol)
+        y_pass_ratio = (
+            np.sum(np.isclose(golden_y, fused_y, rtol=rtol, atol=atol)) / golden_y.size
+        )
+        dx_pass_ratio = (
+            np.sum(np.isclose(golden_dx, fused_dx, rtol=rtol, atol=atol))
+            / golden_dx.size
+        )
+        self.assertTrue(
+            y_pass_ratio > self.pass_line and dx_pass_ratio > self.pass_line
+        )
 
     def golden_fa(self, query_, key_, value_, mask=None):
         query = query_.cast("float32")
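The reworked check trades the all-or-nothing np.testing.assert_allclose for an element-wise pass ratio: a result is accepted when more than pass_line (99.99%) of the output and gradient elements fall within tolerance, which absorbs rare per-element mismatches between the NPU kernel and the golden reference. A standalone sketch of the same acceptance rule (tolerances here are placeholders; in the test, rtol and atol come from code truncated above):

import numpy as np

golden = np.random.randn(1000).astype(np.float32)
fused = golden + 1e-6 * np.random.randn(1000).astype(np.float32)

rtol, atol = 1e-3, 1e-3  # assumed values for illustration
pass_ratio = np.isclose(golden, fused, rtol=rtol, atol=atol).mean()  # == sum / size
assert pass_ratio > 0.9999  # pass_line: at most 0.01% of elements may miss tolerance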
@@ -216,6 +241,83 @@ def init_dtype(self):
         self.dtype = "bfloat16"
 
 
+class TestNPUFABF16_GQA(TestNPUFAFP16):
+    def setUp(self):
+        super().setUp()
+        # query shape (B, N, S, D) = (batch, num_heads, seq_len, head_dim)
+        self.shape = (1, 8, 4096, 128)
+        self.num_keys = 1
+
+    def init_dtype(self):
+        self.dtype = "bfloat16"
+
+    def gen_input(self):
+        np_query = np.random.randn(
+            self.shape[0], self.shape[1], self.shape[2], self.shape[3]
+        )
+        np_key = np.random.randn(
+            self.shape[0], self.num_keys, self.shape[2], self.shape[3]
+        )
+        np_value = np.random.randn(
+            self.shape[0], self.num_keys, self.shape[2], self.shape[3]
+        )
+        mask = paddle.full(
+            (self.shape[2], self.shape[2]), paddle.finfo(paddle.float16).min
+        )
+        mask = paddle.triu(mask, diagonal=1)
+        mask = mask.astype(paddle.bool)
+        np_uint16_query = convert_float_to_uint16(np_query)
+        np_uint16_key = convert_float_to_uint16(np_key)
+        np_uint16_value = convert_float_to_uint16(np_value)
+        return np_uint16_query, np_uint16_key, np_uint16_value, mask
+
+
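gen_input builds a strictly upper-triangular boolean mask (True marks the future positions to exclude) and, for the bfloat16 case, hands the kernel uint16-encoded bf16 data via Paddle's convert_float_to_uint16 helper. A small sketch of both ideas (illustration only; float32_to_bfloat16_bits is a hypothetical stand-in that truncates the low bits, which may differ from Paddle's helper):

import numpy as np

S = 4
causal_mask = np.triu(np.ones((S, S), dtype=bool), k=1)  # True strictly above the diagonal

def float32_to_bfloat16_bits(x):
    # Keep the high 16 bits of each float32 bit pattern (truncation).
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

print(causal_mask.astype(int))
print(float32_to_bfloat16_bits([1.0, -2.5, 3.14159]))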
+class TestNPUFAFP16_GQA(TestNPUFABF16_GQA):
+    def init_dtype(self):
+        self.dtype = "float16"
+
+    def gen_input(self):
+        np_query = np.random.randn(
+            self.shape[0], self.shape[1], self.shape[2], self.shape[3]
+        )
+        np_key = np.random.randn(
+            self.shape[0], self.num_keys, self.shape[2], self.shape[3]
+        )
+        np_value = np.random.randn(
+            self.shape[0], self.num_keys, self.shape[2], self.shape[3]
+        )
+        mask = paddle.full(
+            (self.shape[2], self.shape[2]), paddle.finfo(paddle.float16).min
+        )
+        mask = paddle.triu(mask, diagonal=1)
+        mask = mask.astype(paddle.bool)
+        return np_query, np_key, np_value, mask
+
+
+class TestNPUFABF16_GQA_NK2(TestNPUFABF16_GQA):
+    def setUp(self):
+        super().setUp()
+        self.num_keys = 2
+
+
+class TestNPUFAFP16_GQA_NK2(TestNPUFAFP16_GQA):
+    def setUp(self):
+        super().setUp()
+        self.num_keys = 2
+
+
+class TestNPUFABF16_GQA_NK4(TestNPUFABF16_GQA):
+    def setUp(self):
+        super().setUp()
+        self.num_keys = 4
+
+
+class TestNPUFAFP16_GQA_NK4(TestNPUFAFP16_GQA):
+    def setUp(self):
+        super().setUp()
+        self.num_keys = 4
+
+
 if __name__ == "__main__":
     np.random.seed(2024)
     unittest.main()
