Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 33 additions & 51 deletions tilelang/intrinsics/mma_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,60 +92,42 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j):
return (i * 2 + j // 16, j % 16)


def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str], swizzle_bytes=None):
ana = arith.Analyzer()
BANK_SIZE_BYTES = 128
if isinstance(dtype, str):
dtype = DataType(dtype)
col_idx_outer, col_idx_inner = col_idx // (BANK_SIZE_BYTES // dtype.bits), col_idx % (
BANK_SIZE_BYTES // dtype.bits)
# use transaction bits to support diverse dtype.
# for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 32 bits = 512 bits
# for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits
coalescent_bits = dtype.bits * row_size
# permutation on 4 banks, each bank has 32 bits
bank_elems = BANK_SIZE_BYTES // dtype.bits
new_col_idx_outer = None

if coalescent_bits % 1024 == 0:
# Use 8 * 8 permuted layout
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
# Every row below corresponds to 32 banks
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
# 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4
# 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3
# 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
# 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
# 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
row_idx_sub = row_idx % bank_elems
new_col_idx_outer = col_idx_outer ^ row_idx_sub
else:
assert coalescent_bits % 512 == 0
# Use 8 * 4 permuted layout
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
# Every row below corresponds to 16 banks
# 0 1 2 3 ==> 0 1 2 3
# 0 1 2 3 ==> 0 1 2 3
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 3 2 1 0
# 0 1 2 3 ==> 3 2 1 0
# View with 8 elements per row:
# 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3
# 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
# 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
# 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
row_idx_sub = row_idx % bank_elems
# Interleave elems per byte
interleave_elems = 32 // dtype.bits
new_col_idx_outer = col_idx_outer ^ (row_idx_sub // interleave_elems)

assert (new_col_idx_outer is not None), f"Unsupported dtype {dtype} with {coalescent_bits} bits"
return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner)
row_bytes = dtype.bits * row_size // 8
assert row_bytes % 32 == 0, "Row size must be multiple of 32B."
if swizzle_bytes is None:
swizzle_bytes = min(128, row_bytes)
# 128B swizzle
# Use 8 * 8 permuted layout
# Every number below corresponds to 16B
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
# 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4
# 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3
# 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
# 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
# 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
# 64B swizzle
# Use 8 * 4 permuted layout
# Every number below corresponds to 16B
# 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3
# 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
# 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
# 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
# 32B swizzle
# Use 8 * 2 permuted layout
# Every number below corresponds to 16B
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
elem_per_16B = 128 // dtype.bits
col_idx_16B = col_idx // elem_per_16B
col_idx_in_16B = col_idx % elem_per_16B
new_col_idx_16B = col_idx_16B ^ (row_idx % (swizzle_bytes // 16))
return row_idx, ana.simplify(new_col_idx_16B * elem_per_16B + col_idx_in_16B)


def make_mma_swizzle_layout(shared_buf, is_smooth: bool = False):
Expand Down
Loading