Support batched mask
hjjq committed Apr 13, 2023
1 parent 95d3618 commit 0739a1b
Showing 2 changed files with 23 additions and 27 deletions.
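
As a quick illustration of what the batched mask support enables, here is a hedged usage sketch (the hidet.ops.attention entry point, the hidet.randn / hidet.zeros signatures, and the transposed k layout are assumptions drawn from the kernel signatures in the diff below, not part of this commit):

    import hidet

    bs, num_heads, seq_len, head_dim = 2, 8, 1024, 64
    q = hidet.randn([bs, num_heads, seq_len, head_dim], dtype='float16', device='cuda')
    k = hidet.randn([bs, num_heads, head_dim, seq_len], dtype='float16', device='cuda')  # k stored transposed
    v = hidet.randn([bs, num_heads, seq_len, head_dim], dtype='float16', device='cuda')

    # Additive mask with one slice per batch element; it broadcasts over the head
    # dimension, which the old check (all leading dims must be 1) would have rejected.
    mask = hidet.zeros([bs, 1, seq_len, seq_len], dtype='float16', device='cuda')

    out = hidet.ops.attention(q, k, v, mask=mask)  # expected shape [bs, num_heads, seq_len, head_dim]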
20 changes: 10 additions & 10 deletions python/hidet/graph/ops/definitions/attention/attention.py
@@ -7,11 +7,12 @@
from hidet.ir import primitives as prim
from hidet.ir.primitives import active_mask, shfl_down_sync
from hidet.graph.ops.definitions.utils import tune
-from hidet.lang import f16, f32, i32, u32, boolean, spatial, repeat, tensor
+from hidet.lang import f16, f32, i32, u32, spatial, repeat, tensor
from hidet.lang import attr, grid, tensor_pointer, view, col_spatial
from hidet.lang.cuda import blockIdx, threadIdx, syncthreads, dynamic_shared_memory, register_tensor
from hidet.graph.ops.definitions.utils import Task, Operator, Tensor, TensorNode, compute, input_like
from hidet.graph.ops.definitions.utils import broadcast_shape, broadcast_shapes, broadcast_indices
+from hidet.graph.ops.definitions.utils import can_broadcast
from hidet.utils.py import cdiv, prod
from .attention_mask import AttnMaskAddOp

@@ -349,10 +350,6 @@ def copy_k_g2s(k: f16[k_head + [d_size, n_size]], smem_k: smem_k_type, offset_j:
src_size = 0 if (i >= d_size or offset_j + j >= n_size) else min(n_size - j, 8)
if threadIdx.x < k_g2s_layout.num_workers and i < smem_k_type.shape[0]:
cp_async(~smem_k[i, j], ~gmem_k[i, j], cp_size=16, src_size=src_size * 2, cache_level='global')
-# cp_async_wait_all()
-# syncthreads()
-# for i, j in k_norm_layout.on(threadIdx.x):
-#     smem_k[i, j] = smem_k[i, j] / prim.sqrt(f16(d_size))

@hidet.script
def copy_v_g2s(v: f16[v_head + [n_size, d_size]], smem_v: smem_v_type, offset_j: i32):
@@ -676,12 +673,15 @@ def attention(q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None) ->
if mask is None:
return AttnOp(q, k, v).get_output(0)

+q_shape = q.shape
+k_shape = k.shape
mask_shape = mask.shape
seq_len = q.shape[-2]
-if mask_shape[-1] != seq_len or not all(s == 1 for s in mask_shape[:-1]):

+q_head, k_head = (q_shape[:-2], k_shape[:-2])
+qk_head = broadcast_shape(q_head, k_head)
+qk_shape = qk_head + [seq_len, seq_len]
+if not can_broadcast(mask_shape, qk_shape):
raise ValueError("Invalid mask dimension: {}".format(mask_shape))

-if mask.dtype == boolean:
-    return AttnMaskWhereOp(q, k, v, mask).get_output(0)
-else:
-    return AttnMaskAddOp(q, k, v, mask).get_output(0)
+return AttnMaskAddOp(q, k, v, mask).get_output(0)
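
The rewritten check above accepts any mask whose shape broadcasts to qk_head + [seq_len, seq_len]. can_broadcast is the helper imported from hidet.graph.ops.definitions.utils; a minimal plain-Python sketch of the numpy-style rule it is assumed to implement:

    def can_broadcast_sketch(src_shape, dst_shape):
        # src broadcasts to dst if, aligned from the right, every src dim is 1 or equal
        if len(src_shape) > len(dst_shape):
            return False
        for s, d in zip(reversed(src_shape), reversed(dst_shape)):
            if s != 1 and s != d:
                return False
        return True

    # With qk_head = [bs, num_heads] and seq_len = 1024, all of these masks pass:
    #   [1, 1, 1, 1024]              (the only form accepted before this change)
    #   [bs, 1, 1024, 1024]          (per-batch mask)
    #   [bs, num_heads, 1024, 1024]  (fully batched mask)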
30 changes: 13 additions & 17 deletions python/hidet/graph/ops/definitions/attention/attention_mask.py
@@ -24,13 +24,7 @@ def __init__(self, name: str, q: TensorNode, k: TensorNode, v: TensorNode, mask:
n_size = q_shape[-2]
d_size = q_shape[-1]
o_shape = broadcast_shapes([q_shape[:-2], k_shape[:-2], v_shape[:-2]]) + [n_size, d_size]
-o_head, q_head, k_head, v_head, mask_head = (
-    o_shape[:-2],
-    q_shape[:-2],
-    k_shape[:-2],
-    v_shape[:-2],
-    mask_shape[:-2],
-)
+o_head, q_head, k_head, v_head = (o_shape[:-2], q_shape[:-2], k_shape[:-2], v_shape[:-2])
qk_head = broadcast_shape(q_head, k_head)
mask_shape = mask.const_shape()

@@ -45,13 +39,14 @@ def __init__(self, name: str, q: TensorNode, k: TensorNode, v: TensorNode, mask:
),
)

+qk_shape = qk.const_shape()

qk_masked = compute(
name='qk_masked',
-shape=qk_head + [n_size, n_size],
-fcompute=lambda *indices: mask[mask_head + [0, indices[-1]]] + qk[indices],
+shape=qk_shape,
+fcompute=lambda *indices: mask[broadcast_indices(indices, mask_shape, qk_shape)] + qk[indices],
)

-qk_shape = qk.const_shape()
axis = len(qk_shape) - 1
axis_extent = qk_shape[axis]
reduced_shape = qk_shape[:axis] + qk_shape[axis + 1 :]
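
qk_masked now indexes the mask through broadcast_indices, so any broadcastable mask shape lines up with the qk score shape. broadcast_indices comes from hidet.graph.ops.definitions.utils; the sketch below shows the index remapping it is assumed to perform (the real helper builds symbolic index expressions rather than Python lists):

    def broadcast_indices_sketch(out_indices, in_shape, out_shape):
        # Map output indices back to input indices under broadcasting:
        # drop the extra leading dims and clamp broadcast (size-1) dims to 0.
        pad = len(out_shape) - len(in_shape)
        return [0 if extent == 1 else idx
                for extent, idx in zip(in_shape, out_indices[pad:])]

    # e.g. indexing a [1, 1, 2048, 2048] mask from a [2, 16, 2048, 2048] score tensor:
    # broadcast_indices_sketch([1, 7, 100, 200], [1, 1, 2048, 2048], [2, 16, 2048, 2048])
    # -> [0, 0, 100, 200]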
@@ -378,10 +373,6 @@ def copy_k_g2s(k: f16[k_head + [d_size, n_size]], smem_k: smem_k_type, offset_j:
src_size = 0 if (i >= d_size or offset_j + j >= n_size) else min(n_size - j, 8)
if threadIdx.x < k_g2s_layout.num_workers and i < smem_k_type.shape[0]:
cp_async(~smem_k[i, j], ~gmem_k[i, j], cp_size=16, src_size=src_size * 2, cache_level='global')
-# cp_async_wait_all()
-# syncthreads()
-# for i, j in k_norm_layout.on(threadIdx.x):
-#     smem_k[i, j] = smem_k[i, j] / prim.sqrt(f16(d_size))

@hidet.script
def copy_v_g2s(v: f16[v_head + [n_size, d_size]], smem_v: smem_v_type, offset_j: i32):
@@ -635,13 +626,18 @@ def attn_kernel(
for mma_i, mma_j in grid(mmas_per_warp_m, mmas_per_warp_n):
warp_mma(~regs_q[mma_i, 0], ~regs_k[mma_k, mma_j, 0], ~regs_acc[mma_i, mma_j, 0])
# Apply Masking
+qk_head_index = list(spatial(*qk_head).map(blockIdx.y))
for mma_i, mma_j in grid(mmas_per_warp_m, mmas_per_warp_n):
wi, wj, wk = spatial(warp_count_m, warp_count_n, warp_count_k).on(warp_id)[0]
p = 0
-for _, tj in mma_config.c_store_map.on(lane_id):
-# delta_m = wi * warp_elems_m + mma_i * mma_m + i
+for ti, tj in mma_config.c_store_map.on(lane_id):
+delta_m = wi * warp_elems_m + mma_i * mma_m + ti
delta_n = wj * warp_elems_n + mma_j * mma_n + tj
-regs_acc[mma_i, mma_j, p] += mask[0, 0, 0, delta_n]
+regs_acc[mma_i, mma_j, p] += mask[
+    broadcast_indices(
+        qk_head_index + [delta_m, delta_n], mask_shape, qk_head + [n_size, n_size]
+    )
+]
p += 1
qk_softmax_reduce(smem_qk, smem_mij, smem_lij, regs_acc)
# Load Oi into Smem
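
On the kernel side, each thread block recovers its head indices from blockIdx.y with spatial(*qk_head).map(...), adds the per-thread tile offsets (delta_m, delta_n), and remaps the full index through broadcast_indices before reading the mask. A plain-Python sketch of the flat-index decomposition that spatial(...).map is assumed to perform here (row-major, last dimension fastest):

    def spatial_map_sketch(flat_index, extents):
        # Decompose a flat block index into per-dimension indices, last dim fastest.
        indices = []
        for extent in reversed(extents):
            indices.append(flat_index % extent)
            flat_index //= extent
        return list(reversed(indices))

    # e.g. qk_head = [2, 8] (batch 2, 8 heads): blockIdx.y = 11 -> [1, 3],
    # so that block adds mask[1, 3, delta_m, delta_n] to its scores
    # (or mask[0, 0, ...] where the mask dims are broadcast).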
