Skip to content

Commit

Permalink
Add comprehensive tests to test the kernel across available dtypes.
Browse files Browse the repository at this point in the history
Added softmax and gemm kernel to test across the available float and int dtypes.
  • Loading branch information
Prashant Kumar committed Feb 27, 2024
1 parent 971231c commit f09dfab
Showing 1 changed file with 145 additions and 0 deletions.
145 changes: 145 additions & 0 deletions core/tests/kernel/coverage_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import torch
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import pytest


TKL_TO_TORCH_DTYPE = {
tkl.f16: torch.half,
tkl.f32: torch.float,
tkl.f64: torch.double,
tkl.bool: torch.bool,
tkl.i8: torch.int8,
tkl.i16: torch.int16,
tkl.i32: torch.int32,
tkl.i64: torch.int64,
}

FLOAT_DTYPES = [tkl.f16, tkl.f32, tkl.f64]
INT_DTYPES = [
tkl.bool,
tkl.i4,
tkl.i8,
tkl.i16,
tkl.i32,
tkl.i64,
tkl.index,
]


def iota_krnl(dtype, input):
M = tkl.sym.M

@tk.gen.thread(M)
def iota_kernel(out: tkl.OutputBuffer[M, dtype]):
a = (
tkl.constant((17, 37, 19), dtype, 5)
if dtype in INT_DTYPES
else tkl.constant((17, 37, 19), dtype, 5.0)
)
b = (
tkl.constant((17, 37, 19), dtype, 10)
if dtype in INT_DTYPES
else tkl.constant((17, 37, 19), dtype, 10.0)
)
c = (
tkl.constant((17, 37, 19), dtype, 2)
if dtype in INT_DTYPES
else tkl.constant((17, 37, 19), dtype, 2.0)
)
if dtype in INT_DTYPES:
c = (a * b) // c
else:
c = (a * b) / c
c = c + a - b

with tk.gen.TestLaunchContext():
iota_kernel(input)


def softmax_krnl(dtype, input, output):
M = tkl.sym.M
K = tkl.sym.K

@tk.gen.thread(M)
def softmax_kernel(
input: tk.lang.InputBuffer[M, K, dtype],
output: tk.lang.OutputBuffer[M, K, dtype],
):
row_index = tk.lang.program_id(0)
input_row = input[row_index, :]
numerator = tkl.exp2(input_row - tkl.max(input_row))
if dtype in INT_DTYPES:
output_row = numerator // tkl.sum(numerator)
else:
output_row = numerator / tkl.sum(numerator)
output[row_index, :] = output_row

with tk.gen.TestLaunchContext():
softmax_kernel(input, output)


def gemm_fx_kernel(dtype, A, B, output):
N = tkl.sym.N
M = tkl.sym.M
K = tkl.sym.K
BLOCK_SIZE = tkl.sym.BLOCK_SIZE

@tk.gen.thread(N // BLOCK_SIZE, M // BLOCK_SIZE)
def gemm_kernel(
A: tkl.InputBuffer[N, K, dtype],
B: tkl.InputBuffer[K, M, dtype],
output: tkl.OutputBuffer[N, M, dtype],
):
grid_n = tkl.program_id(0)
grid_m = tkl.program_id(1)

acc = None
# TODO: Only considering the float and integer cases.
if dtype in INT_DTYPES:
acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0)
else:
acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0.0)

@tkl.for_loop(0, K // BLOCK_SIZE, init_args=[acc])
def body(i, c):
a = tkl.load(A, (grid_n, i * BLOCK_SIZE), (BLOCK_SIZE, BLOCK_SIZE))
b = tkl.load(B, (i * BLOCK_SIZE, grid_m), (BLOCK_SIZE, BLOCK_SIZE))
return (tkl.dot(a, b, c),)

tkl.store(output, (grid_n, grid_m), body[0])

with tk.gen.TestLaunchContext({BLOCK_SIZE: 32}):
gemm_kernel(A, B, output)


@pytest.mark.parametrize(
("dtype",),
[(x,) for x in FLOAT_DTYPES + INT_DTYPES],
)
def test_iota_krnl(dtype):
input = torch.zeros(17)
iota_krnl(dtype, input)


@pytest.mark.parametrize(
("dtype",),
[(x,) for x in FLOAT_DTYPES],
)
def test_softmax_krnl(dtype):
if dtype in TKL_TO_TORCH_DTYPE:
input = torch.randn(128, 64).to(TKL_TO_TORCH_DTYPE[dtype])
output = torch.randn(128, 64).to(TKL_TO_TORCH_DTYPE[dtype])
softmax_krnl(dtype, input, output)


@pytest.mark.parametrize(
("dtype",),
[(x,) for x in FLOAT_DTYPES + INT_DTYPES],
)
def test_gemm_krnl(dtype):
if dtype in TKL_TO_TORCH_DTYPE:
A = torch.randn(512, 1024).to(TKL_TO_TORCH_DTYPE[dtype])
B = torch.randn(1024, 2048).to(TKL_TO_TORCH_DTYPE[dtype])
output = torch.zeros(512, 2048).to(TKL_TO_TORCH_DTYPE[dtype])
gemm_fx_kernel(dtype, A, B, output)

0 comments on commit f09dfab

Please sign in to comment.