Closed
Description
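```python
with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128 * 2) as (bx, by, bz):
    # assumed: by packs (batch, head) as bid * heads + hid
    bid = by // heads
    hid = by % heads
    # block-level shared memory tensor and register-level (fragment) tensor
    Q_shared = T.alloc_shared([block_M, dim], dtype)
    acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
    T.copy(Q[bid, 0, hid, :], Q_shared[0, :])
    # desired: print both tensors from one thread of one block
    # (threadIdx.x stands for the CUDA thread index)
    if bx == 0 and by == 0 and bz == 0 and threadIdx.x == 0:
        print(Q_shared)
        print(acc_s)
```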
For debugging, is it possible to print intermediate values like this?
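The snippet uses `bid` and `hid` without defining them, and relies on a guard to restrict printing to a single thread. The host-side Python sketch below (plain Python, not TileLang code) illustrates both points under an assumption: that the fused `heads * batch` grid dimension packs indices as `by = bid * heads + hid`. The helper names `decode_block_y` and `printing_threads` are hypothetical and exist only for this illustration.

```python
from typing import List, Tuple


def decode_block_y(by: int, heads: int) -> Tuple[int, int]:
    """Split the fused heads * batch grid index into (batch index, head index).

    Assumption: by = bid * heads + hid.
    """
    bid, hid = divmod(by, heads)
    return bid, hid


def printing_threads(grid_x: int, grid_y: int, grid_z: int, threads: int) -> List[Tuple[int, int, int, int]]:
    """Return every (bx, by, bz, tx) in a toy launch that passes the print guard."""
    return [
        (bx, by, bz, tx)
        for bx in range(grid_x)
        for by in range(grid_y)
        for bz in range(grid_z)
        for tx in range(threads)
        if bx == 0 and by == 0 and bz == 0 and tx == 0
    ]


if __name__ == "__main__":
    batch, heads = 2, 3
    # Each (bid, hid) pair appears exactly once across the fused y dimension.
    print([decode_block_y(by, heads) for by in range(heads * batch)])
    # In a small toy grid, only a single thread satisfies the guard.
    print(printing_threads(grid_x=2, grid_y=heads * batch, grid_z=2, threads=4))
```

Because the guard matches exactly one `(block, thread)` combination, each tensor would be printed once per launch rather than once per thread.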