Skip to content

Commit

Permalink
extend swizzle rule to int8.
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Feb 7, 2024
1 parent 2b39d03 commit 3891f2a
Showing 1 changed file with 29 additions and 18 deletions.
47 changes: 29 additions & 18 deletions src/tir/transforms/inject_permuted_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,20 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
using IRMutatorWithAnalyzer::VisitExpr_;
using IRMutatorWithAnalyzer::VisitStmt_;

Array<PrimExpr> PermuteIndices(PrimExpr row_idx, PrimExpr col_idx, int row_size) {
Array<PrimExpr> PermuteIndices(PrimExpr row_idx, PrimExpr col_idx, int row_size,
DataType dtype = DataType::Int(32)) {
ICHECK(permute_);
// Index after vectorizing by 8
PrimExpr col_idx_outer = floordiv(col_idx, VECTORIZE_FACTOR),
col_idx_inner = floormod(col_idx, VECTORIZE_FACTOR);
PrimExpr col_idx_outer = floordiv(col_idx, BANK_SIZE_BYTES / dtype.bits()),
col_idx_inner = floormod(col_idx, BANK_SIZE_BYTES / dtype.bits());
PrimExpr new_col_idx_outer;
if (row_size % 64 == 0) {
// use transaction bits to support diverse dtype.
// for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 32 bits = 512 bits
// for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits
int coalescent_bits = dtype.bits() * row_size;
// permutation on 4 banks, each bank has 32 bits
int bank_elems = BANK_SIZE_BYTES / dtype.bits();
if (coalescent_bits % 1024 == 0) {
// Use 8 * 8 permuted layout
// Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
// Every row below corresponds to 32 banks
Expand All @@ -76,10 +83,10 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
// 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
// 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
// 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
auto row_idx_sub = floormod(row_idx, 8);
auto row_idx_sub = floormod(row_idx, bank_elems);
new_col_idx_outer = col_idx_outer ^ row_idx_sub;
} else {
ICHECK(row_size % 32 == 0);
ICHECK(coalescent_bits % 512 == 0);
// Use 8 * 4 permuted layout
// Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
// Every row below corresponds to 16 banks
Expand All @@ -96,10 +103,12 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
// 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
// 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
// 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
auto row_idx_sub = floormod(row_idx, 8);
new_col_idx_outer = col_idx_outer ^ floordiv(row_idx_sub, 2);
auto row_idx_sub = floormod(row_idx, bank_elems);
// Interleave elems per byte
int interleave_elems = 32 / dtype.bits();
new_col_idx_outer = col_idx_outer ^ floordiv(row_idx_sub, interleave_elems);
}
return {row_idx, analyzer_->Simplify(new_col_idx_outer * 8 + col_idx_inner)};
return {row_idx, analyzer_->Simplify(new_col_idx_outer * bank_elems + col_idx_inner)};
}

static bool CheckAnnotation(ObjectRef annotation) {
Expand Down Expand Up @@ -162,14 +171,14 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
return buffer_row_size;
}

Array<PrimExpr> HandleBufferIndices(Buffer buffer, Array<PrimExpr> indices) {
Array<PrimExpr> HandleBufferIndices(Buffer buffer, Array<PrimExpr> indices, DataType dtype) {
auto buffer_row_size = CheckAndGetBufferRowSize(buffer);

// Mutate the last two indices
auto indices_size = indices.size();
PrimExpr row_idx = indices[indices_size - 2];
PrimExpr col_idx = indices[indices_size - 1];
auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size);
auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size, dtype);
indices.Set(indices_size - 2, new_indices[0]);
indices.Set(indices_size - 1, new_indices[1]);
return indices;
Expand All @@ -180,7 +189,6 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
// We assume the shape of the shared memory is [..., row_size, col_size],
// where row_size is divisible by 64, or divisible by 32 and col_size is divisible by 2.
auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));

if (!permute_ || store->buffer->shape.size() < 2) {
return store;
}
Expand All @@ -191,7 +199,8 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
}

auto store_node = store.CopyOnWrite();
store_node->indices = HandleBufferIndices(store_node->buffer, store_node->indices);
store_node->indices =
HandleBufferIndices(store_node->buffer, store_node->indices, store->buffer->dtype);
return store;
}

Expand All @@ -209,11 +218,13 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
}

auto load_node = load.CopyOnWrite();
load_node->indices = HandleBufferIndices(load_node->buffer, load_node->indices);
load_node->indices =
HandleBufferIndices(load_node->buffer, load_node->indices, load->buffer->dtype);
return load;
}

PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr, Optional<PrimExpr> offset = NullOpt) {
PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr, Optional<PrimExpr> offset = NullOpt,
DataType dtype = DataType::Int(32)) {
// The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and accumulate it to
// smem_offset
CHECK(access_ptr->IsInstance<CallNode>())
Expand All @@ -233,7 +244,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
PrimExpr row_idx = floordiv(smem_offset, buffer_row_size);
PrimExpr col_idx = floormod(smem_offset, buffer_row_size);

auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size);
auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size, dtype);
auto new_offset = analyzer_->Simplify(new_indices[0] * buffer_row_size + new_indices[1]);

auto new_access_ptr = access_ptr_call.CopyOnWrite();
Expand All @@ -258,7 +269,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
// smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
auto access_ptr = call->args[5];
PrimExpr smem_offset = call->args[6];
auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, smem_offset);
auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype);
auto new_call = call.CopyOnWrite();
new_call->args.Set(5, new_access_ptr);
new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
Expand All @@ -267,7 +278,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
// TODO(yixin): mma_store is not fully tested yet
// because we will directly store result to Buffer instead of calling mma_store now
auto access_ptr = call->args[2];
auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr);
auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, NullOpt, call->dtype);
auto new_call = call.CopyOnWrite();
new_call->args.Set(2, new_access_ptr);
return call;
Expand Down

0 comments on commit 3891f2a

Please sign in to comment.