From 036b037b815098eed4e5ccbcea1e8d006548d855 Mon Sep 17 00:00:00 2001 From: John Lawson Date: Wed, 17 Feb 2021 15:10:55 +0000 Subject: [PATCH 01/53] Add support for AMX tile instructions --- src/CMakeLists.txt | 2 + src/CodeGen_X86.cpp | 48 ++++- src/Expr.h | 4 + src/ExtractTileOperations.cpp | 293 ++++++++++++++++++++++++++++++ src/ExtractTileOperations.h | 20 ++ src/FuseGPUThreadLoops.cpp | 2 + src/IRPrinter.cpp | 3 + src/Lower.cpp | 10 + src/runtime/x86_avx512.ll | 38 ++++ test/performance/CMakeLists.txt | 1 + test/performance/tiled_matmul.cpp | 130 +++++++++++++ 11 files changed, 550 insertions(+), 1 deletion(-) create mode 100644 src/ExtractTileOperations.cpp create mode 100644 src/ExtractTileOperations.h create mode 100644 test/performance/tiled_matmul.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8bc68eeeb26c..84f71065f4ef 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -59,6 +59,7 @@ set(HEADER_FILES ExprUsesVar.h Extern.h ExternFuncArgument.h + ExtractTileOperations.h FastIntegerDivide.h FindCalls.h FindIntrinsics.h @@ -218,6 +219,7 @@ set(SOURCE_FILES EmulateFloat16Math.cpp Error.cpp Expr.cpp + ExtractTileOperations.cpp FastIntegerDivide.cpp FindCalls.cpp FindIntrinsics.cpp diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 2038dcce75c8..0aa64e089845 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -1,4 +1,5 @@ #include "CodeGen_Posix.h" + #include "ConciseCasts.h" #include "Debug.h" #include "IRMatch.h" @@ -79,15 +80,21 @@ class CodeGen_X86 : public CodeGen_Posix { void visit(const EQ *) override; void visit(const NE *) override; void visit(const Select *) override; + void visit(const Allocate *) override; + void visit(const Load *) override; + void visit(const Store *) override; void codegen_vector_reduce(const VectorReduce *, const Expr &init) override; // @} + +private: + Scope mem_type; }; CodeGen_X86::CodeGen_X86(Target t) : CodeGen_Posix(complete_x86_target(t)) { } -const int max_intrinsic_args = 4; +const int max_intrinsic_args = 6; struct x86Intrinsic { const char *intrin_name; @@ -184,6 +191,13 @@ const x86Intrinsic intrinsic_defs[] = { {"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids}, {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids}, {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids}, + + {"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids}, + {"tdpbssd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids}, + {"tilezero_i32", Int(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids}, + // CodeGen_LLVM cannot cope with returning Type() ie void*, and return type needs to be vector to trigger call_overloaded_intrin + {"tilestored64", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids}, + }; // clang-format on @@ -576,6 +590,38 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init CodeGen_Posix::codegen_vector_reduce(op, init); } +void CodeGen_X86::visit(const Allocate *op) { + ScopedBinding bind(mem_type, op->name, op->memory_type); + CodeGen_Posix::visit(op); +} + +void CodeGen_X86::visit(const Load *op) { + if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) { + const Ramp *ramp = op->index.as(); + internal_assert(ramp) << "Expected AMXTile to have index ramp\n"; + Value *ptr = codegen_buffer_pointer(op->name, op->type, ramp->base); + LoadInst *load = builder->CreateAlignedLoad(ptr, llvm::Align(op->type.bytes())); + add_tbaa_metadata(load, op->name, op->index); + value = load; + return; + } + CodeGen_Posix::visit(op); +} + +void CodeGen_X86::visit(const Store *op) { + if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) { + Value *val = codegen(op->value); + Halide::Type value_type = op->value.type(); + const Ramp *ramp = op->index.as(); + internal_assert(ramp) << "Expected AMXTile to have index ramp\n"; + Value *ptr = codegen_buffer_pointer(op->name, value_type, ramp->base); + StoreInst *store = builder->CreateAlignedStore(val, ptr, llvm::Align(value_type.bytes())); + add_tbaa_metadata(store, op->name, op->index); + return; + } + CodeGen_Posix::visit(op); +} + string CodeGen_X86::mcpu() const { if (target.has_feature(Target::AVX512_SapphireRapids)) { #if LLVM_VERSION >= 120 diff --git a/src/Expr.h b/src/Expr.h index c5472e766fa4..b70d608d290b 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -379,6 +379,10 @@ enum class MemoryType { * intermediate buffers. Necessary for vgather-vscatter instructions * on Hexagon */ VTCM, + + /** AMX Tile register for X86. Any data that would be used in an AMX matrix + * multiplication must first be loaded into an AMX tile register. */ + AMXTile, }; namespace Internal { diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp new file mode 100644 index 000000000000..e8edc3532775 --- /dev/null +++ b/src/ExtractTileOperations.cpp @@ -0,0 +1,293 @@ +#include "ExtractTileOperations.h" + +#include "IRMatch.h" // expr_match +#include "IRMutator.h" +#include "IROperator.h" // Expr + Expr +#include "Util.h" // ScopedValue + +namespace Halide { +namespace Internal { + +namespace { + +template +struct Tile { + bool result; + Expr base; + Expr stride[Dim]; + int extent[Dim]; +}; + +const auto wild_i32 = Variable::make(Int(32), "*"); +const auto wild_i32x = Variable::make(Int(32, 0), "*"); + +Tile<2> is_2d_tile_index(const Expr &e) { + // ramp(ramp(base, 1, 4), x4(stride), 4) + std::vector matches; + if (const auto *r1 = e.as()) { + if (const auto *r2 = r1->base.as()) { + auto ramp_2d_pattern = Ramp::make(Ramp::make(wild_i32, wild_i32, r2->lanes), Broadcast::make(wild_i32, r2->lanes), r1->lanes); + if (expr_match(ramp_2d_pattern, e, matches)) { + return {true, std::move(matches[0]), {std::move(matches[2]), std::move(matches[1])}, {r1->lanes, r2->lanes}}; + } + } + } + return {}; +} + +Tile<3> is_3d_tile_index(const Expr &e) { + std::vector matches; + auto add_sub_pattern = (wild_i32x + wild_i32x) - wild_i32x; + if (!expr_match(add_sub_pattern, e, matches)) { return {}; } + // ramp(x16(base), x16(stride), 4) + x16(ramp(idx, 1, 4)) y: 4, x: 4, r: 4 + // ramp(x10(base), x10(stride), 3) + x6(ramp(idx, 1, 5)) y: 2, x: 3, r: 5 + Expr first = std::move(matches[0]); + Expr second = std::move(matches[1]); + Expr adj = std::move(matches[2]); + const auto *r1 = first.as(); + const auto *b2 = second.as(); + if (!r1 && !b2) { + // Try switching the order + r1 = second.as(); + b2 = first.as(); + } + if (!r1 || !b2) { return {}; } + + const auto *b1 = r1->base.as(); + const auto *r2 = b2->value.as(); + + if (!b1 || !r2) { return {}; } + + int x_tile = r1->lanes; + int r_tile = r2->lanes; + int y_tile = b1->lanes / r_tile; + if (y_tile != b2->lanes / x_tile) { return {}; } + + auto pattern1 = Ramp::make(Broadcast::make(wild_i32, b1->lanes), Broadcast::make(wild_i32, b1->lanes), r1->lanes); + if (!expr_match(pattern1, first, matches)) { return {}; } + Expr base = std::move(matches[0]); + Expr x_stride = std::move(matches[1]); + + auto pattern2 = Broadcast::make(Ramp::make(wild_i32, wild_i32, r2->lanes), b2->lanes); + if (!expr_match(pattern2, second, matches)) { return {}; } + base += std::move(matches[0]); + Expr r_stride = std::move(matches[1]); + + auto pattern3 = Broadcast::make(wild_i32, b1->lanes * r1->lanes); + if (!expr_match(pattern3, adj, matches)) { return {}; } + base -= std::move(matches[0]); + + return {true, base, {x_stride, 0, r_stride}, {x_tile, y_tile, r_tile}}; +} + +struct NewMatmul { + bool result = false; + Stmt stmt; + int tile_x; + int tile_y; + int tile_r; +}; + +NewMatmul +convert_to_matmul(const Store *op, const std::string &new_name) { + // m[ramp(0, 1, S)] = VectorAdd(lhs[{XYR tile}] * xX(rhs[{YR tile}])) + m[ramp(0, 1, S)] + const auto wild_i8x = Variable::make(Int(8, 0), "*"); + const auto wild_i16x = Variable::make(Int(16, 0), "*"); + std::vector matches; + const auto pattern1 = wild_i32x + wild_i32x; + if (!expr_match(pattern1, op->value, matches)) { return {}; } + const auto *reduce = matches[0].as(); + const auto *load = matches[1].as(); + if (!reduce || reduce->op != VectorReduce::Add) { return {}; } + if (!load || load->name != op->name || !equal(load->index, op->index)) { return {}; } + + // FIXME: Add support for uint8 and bf16 for LLVM 13+ + auto pattern2 = cast(Int(32, 0), cast(Int(16, 0), wild_i8x) * wild_i16x); + if (!expr_match(pattern2, reduce->value, matches)) { return {}; } + const auto *lhs_load = matches[0].as(); + // FIXME: When tile_r is not 4 the broadcast is inside the index, not of the value + const auto *rhs_broadcast = matches[1].as(); + if (!lhs_load || !rhs_broadcast) { return {}; } + const auto *rhs_cast = rhs_broadcast->value.as(); + if (!rhs_cast || rhs_cast->value.type().element_of() != Int(8)) { return {}; } + const auto *rhs_load = rhs_cast->value.as(); + if (!rhs_load) { return {}; } + + const auto lhs_tile = is_3d_tile_index(lhs_load->index); + const auto rhs_tile = is_2d_tile_index(rhs_load->index); + // FIXME: When tile_r is not 4 the RHS load will be 4D (x, r/4, y, r%4) + if (!lhs_tile.result || !rhs_tile.result) { return {}; } + + const int tile_x = lhs_tile.extent[0]; + const int tile_y = lhs_tile.extent[1]; + const int tile_r = lhs_tile.extent[2]; + const int factor = reduce->value.type().lanes() / reduce->type.lanes(); + if (op->index.type().lanes() != tile_x * tile_y || + factor != tile_r || + tile_y != rhs_tile.extent[0] || + tile_r != rhs_tile.extent[1]) { + return {}; + } + + // {rows, colbytes, var, index} + auto lhs_var = Variable::make(Handle(), lhs_load->name); + auto lhs = Call::make(Int(8, 1024), "tile_load", {tile_x, tile_r, lhs_var, lhs_tile.base, lhs_tile.stride[0]}, Call::Intrinsic); + auto rhs_var = Variable::make(Handle(), rhs_load->name); + auto rhs = Call::make(Int(8, 1024), "tile_load", {1, tile_y * tile_r, rhs_var, rhs_tile.base, rhs_tile.stride[0]}, Call::Intrinsic); + + // {rows, colbytes, acc, out, lhs, rhs} + auto out = Load::make(Int(32, 256), new_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); + auto colbytes = tile_y * 32 / rhs_load->type.bits(); + auto matmul = Call::make(Int(32, 256), "tile_matmul", {tile_x, colbytes, tile_r, out, lhs, rhs}, Call::Intrinsic); + auto store = Store::make(new_name, matmul, Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); + return {true, std::move(store), tile_x, tile_y, tile_r}; +} + +Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const std::string &new_name) { + if (const auto *ramp = op->index.as()) { + if (const auto *bcast = op->value.as()) { + if (is_const_one(ramp->stride) && + is_const_zero(bcast->value) && + (bcast->lanes == tile_x * tile_y)) { + auto rows = Cast::make(Int(16), tile_x); + auto bytes = op->value.type().bytes(); + auto colbytes = Cast::make(Int(16), tile_y * bytes); + auto val = Call::make(Int(32, 256), "tile_zero", {rows, colbytes}, Call::Intrinsic); + auto store = Store::make(new_name, val, Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); + return store; + } + } + } + return {}; +} + +Stmt convert_to_tile_store(const Store *op, const std::string &amx_alloc, int tile_x, int tile_y) { + auto tile = is_2d_tile_index(op->index); + if (tile.result && tile.extent[0] == tile_x && tile.extent[1] == tile_y) { + auto out = Variable::make(Handle(), op->name); + auto tile_val = Load::make(Int(32, 256), amx_alloc, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); + auto bytes = op->value.type().bytes(); + internal_assert(bytes == 4) << "AMX store only supported for int32 and float32, not for " << op->value.type() << "\n"; + // {tile_x, tile_y, var, base, stride} + auto store = Call::make(Bool(2), "tile_store", {tile_x, tile_y * bytes, out, tile.base * bytes, tile.stride[0] * bytes, tile_val}, Call::Intrinsic); + return Evaluate::make(store); + } + return {}; +} + +class ExtractTileOperations : public IRMutator { + using IRMutator::visit; + + std::string tile_name; + std::string amx_alloc; + std::vector pending_stores; + bool is_valid = true; + bool in_allocate = false; + int found_tile_x = -1; + int found_tile_y = -1; + int found_tile_r = -1; + + Stmt visit(const Allocate *op) override { + if (op->type.is_int() && op->type.bits() == 32) { + if (in_allocate) { + // Found two possible tile allocations + // FIXME: Handle this better + is_valid = false; + return op; + } + amx_alloc = op->name + ".amx"; + tile_name = op->name; + ScopedValue old_in_alloc(in_allocate, true); + Stmt body = op->body; + + pending_stores.clear(); + body = mutate(body); + if (!is_valid) { + return op; + } + if (found_tile_x < 0 || found_tile_y < 0 || found_tile_r < 0) { + return op; + } + if (!pending_stores.empty()) { + // Really only need to go over the pending stores + body = mutate(body); + } + if (!is_valid) { + return op; + } + + return Allocate::make(amx_alloc, Int(32, 256), MemoryType::AMXTile, {1}, const_true(), body); + } + return IRMutator::visit(op); + } + + Stmt visit(const Free *op) override { + if (op->name != tile_name) { + return op; + } + return Free::make(amx_alloc); + } + + Stmt visit(const ProducerConsumer *op) override { + if (op->name != tile_name) { + return IRMutator::visit(op); + } + + auto body = mutate(op->body); + return ProducerConsumer::make(amx_alloc, op->is_producer, body); + } + + Stmt visit(const Store *op) override { + if (op->name != tile_name) { + const auto *load = op->value.as(); + if (!load || load->name != tile_name) { + return op; + } + auto store = convert_to_tile_store(op, amx_alloc, found_tile_x, found_tile_y); + if (store.defined()) { + return store; + } else { + // Found store of tile_name that is not a tile store. + is_valid = false; + return op; + } + } + + auto matmul = convert_to_matmul(op, amx_alloc); + if (matmul.result) { + if ((found_tile_x > 0 && matmul.tile_x != found_tile_x) || + (found_tile_r > 0 && matmul.tile_r != found_tile_r) || + (found_tile_y > 0 && matmul.tile_y != found_tile_y)) { + is_valid = false; + return op; + } + found_tile_x = matmul.tile_x; + found_tile_y = matmul.tile_y; + found_tile_r = matmul.tile_r; + return matmul.stmt; + } + + if (found_tile_x < 0 || found_tile_y < 0) { + pending_stores.emplace_back(op); + return op; + } + + auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_alloc); + if (zero.defined()) { + return zero; + } + + // Otherwise there is some other operation using the allocation, so we cannot use the AMX instructions + is_valid = false; + return op; + } +}; + +} // namespace + +Stmt extract_tile_operations(const Stmt &s) { + return ExtractTileOperations().mutate(s); +} + +} // namespace Internal +} // namespace Halide diff --git a/src/ExtractTileOperations.h b/src/ExtractTileOperations.h new file mode 100644 index 000000000000..d246bddc5a04 --- /dev/null +++ b/src/ExtractTileOperations.h @@ -0,0 +1,20 @@ +#ifndef HALIDE_EXTRACT_TILE_OPERATIONS_H +#define HALIDE_EXTRACT_TILE_OPERATIONS_H + +/** \file + * Defines the lowering pass that injects calls to tile intrinsics that support + * AMX instructions. + */ + +#include "Expr.h" + +namespace Halide { +namespace Internal { + +/** TODO */ +Stmt extract_tile_operations(const Stmt &s); + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/FuseGPUThreadLoops.cpp b/src/FuseGPUThreadLoops.cpp index 7fa67ac2192f..0c58e318a86c 100644 --- a/src/FuseGPUThreadLoops.cpp +++ b/src/FuseGPUThreadLoops.cpp @@ -1279,6 +1279,7 @@ class InjectThreadBarriers : public IRMutator { case MemoryType::Register: case MemoryType::LockedCache: case MemoryType::VTCM: + case MemoryType::AMXTile: break; } @@ -1303,6 +1304,7 @@ class InjectThreadBarriers : public IRMutator { case MemoryType::Register: case MemoryType::LockedCache: case MemoryType::VTCM: + case MemoryType::AMXTile: break; } diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index 62fdb705997b..38e2eeefd511 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -135,6 +135,9 @@ std::ostream &operator<<(std::ostream &out, const MemoryType &t) { case MemoryType::VTCM: out << "VTCM"; break; + case MemoryType::AMXTile: + out << "AMXTile"; + break; } return out; } diff --git a/src/Lower.cpp b/src/Lower.cpp index a7227b6ec76d..4a8ec2df34b5 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -22,6 +22,7 @@ #include "DebugToFile.h" #include "Deinterleave.h" #include "EarlyFree.h" +#include "ExtractTileOperations.h" #include "FindCalls.h" #include "FindIntrinsics.h" #include "FlattenNestedRamps.h" @@ -385,6 +386,15 @@ Module lower(const vector &output_funcs, s = lower_unsafe_promises(s, t); log("Lowering after lowering unsafe promises:", s); +#if LLVM_VERSION >= 12 + if (t.has_feature(Target::AVX512_SapphireRapids)) { + debug(1) << "Extracting tile operations...\n"; + s = extract_tile_operations(s); + debug(2) << "Lowering after extracting tile operations:\n" + << s << "\n\n"; + } +#endif + debug(1) << "Flattening nested ramps...\n"; s = flatten_nested_ramps(s); log("Lowering after flattening nested ramps:", s); diff --git a/src/runtime/x86_avx512.ll b/src/runtime/x86_avx512.ll index 904fabe9368e..5e9ace735bfd 100644 --- a/src/runtime/x86_avx512.ll +++ b/src/runtime/x86_avx512.ll @@ -90,3 +90,41 @@ define weak_odr <4 x i32> @dpwssdx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b ret <4 x i32> %3 } declare <4 x i32> @llvm.x86.avx512.vpdpwssd.128(<4 x i32>, <4 x i32>, <4 x i32>) + +define weak_odr <1024 x i8> @tileloadd64_i8(i16 %rows, i16 %colbytes, i8* %ptr, i64 %off, i64 %stride) nounwind alwaysinline { + %1 = getelementptr i8, i8* %ptr, i64 %off + %2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) + %3 = bitcast x86_amx %2 to <1024 x i8> + ret <1024 x i8> %3 +} +declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) + +define weak_odr <256 x i32> @tdpbssd(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind readnone alwaysinline { + %1 = bitcast <1024 x i8> %lhs to x86_amx + %2 = bitcast <1024 x i8> %rhs to x86_amx + %3 = bitcast <256 x i32> %out to x86_amx + %4 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %rows, i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone + %5 = bitcast x86_amx %4 to <256 x i32> + ret <256 x i32> %5 +} +declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) + +define weak_odr <2 x i1> @tilestored64(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline { + %1 = getelementptr i8, i8* %ptr, i64 %off + %2 = bitcast <256 x i32> %val to x86_amx + tail call void @llvm.x86.tilestored64.internal(i16 %rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) + ret <2 x i1> zeroinitializer +} +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) + +; NB: Even though this should be readnone, that will cause LLVM to try to +; generate a single zero tile, and copy it each time it is used. However the AMX +; registers cannot be copied, so this causes compilation failures: +; LLVM ERROR: Cannot emit physreg copy instruction +; renamable $tmm1 = COPY renamable $tmm0 +define weak_odr <256 x i32> @tilezero_i32(i16 %rows, i16 %colbytes) nounwind alwaysinline { + %1 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %rows, i16 %colbytes) + %2 = bitcast x86_amx %1 to <256 x i32> + ret <256 x i32> %2 +} +declare x86_amx @llvm.x86.tilezero.internal(i16, i16) diff --git a/test/performance/CMakeLists.txt b/test/performance/CMakeLists.txt index 65aa41da00f3..80f58a16afae 100644 --- a/test/performance/CMakeLists.txt +++ b/test/performance/CMakeLists.txt @@ -1,5 +1,6 @@ tests(GROUPS performance SOURCES + tiled_matmul.cpp async_gpu.cpp block_transpose.cpp boundary_conditions.cpp diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp new file mode 100644 index 000000000000..b3be78721f10 --- /dev/null +++ b/test/performance/tiled_matmul.cpp @@ -0,0 +1,130 @@ +#include "Halide.h" +#include "halide_benchmark.h" +#include "halide_test_dirs.h" +#include +#include + +using namespace Halide; + +#define FUSE 0 + +int main(int argc, char **argv) { + const int row = 16; + const int col = 16; + const int acc = 16; + + Var x("x"), y("y"); + ImageParam A(Int(8), 2, "lhs"); + ImageParam B(Int(8), 3, "rhs"); + + RDom r(0, acc); + + Func mm("matmul"); + mm(y, x) = cast(0); + mm(y, x) += cast(A(r.x, x)) * B(r.x % 4, y, r.x / 4); + + // Ensure all (x, y) tile sizes are the same so that loops are fused. + int tile_y = 8; + int tile_x = 6; + int tile_r = 4; + + // Schedule the reduction + Var rxi("rxi"), ryi("ryi"), rz("rz"); + RVar rri("rri"), rro("rro"); + mm.compute_at(mm.in(), y) + .update() + // Split into (x,y) tile + .tile(y, x, ryi, rxi, tile_y, tile_x, TailStrategy::GuardWithIf) + // Split reduction dim by tile_r + .split(r.x, rro, rri, tile_r) + // Reorder so that the (x,y) tile is inside the inner ro loop + .reorder({rri, ryi, rxi, rro, y, x}) + .atomic() + .vectorize(rri) + .vectorize(ryi) + .vectorize(rxi); + + // Schedule the initialization + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), y) + .tile(y, x, iyi, ixi, tile_y, tile_x) + .vectorize(iyi) + .vectorize(ixi); + + // Schedule the consumer + Var mmxi("mmxi"), mmyi("mmyi"), mmz("mmz"); + mm.in() + .tile(y, x, mmyi, mmxi, tile_y, tile_x) + .vectorize(mmyi) + .vectorize(mmxi); + + int count = 1; + Buffer a_buf(acc, row); + for (int iy = 0; iy < row; iy++) { + for (int ix = 0; ix < acc; ix++) { + a_buf(ix, iy) = count++; //rand() % 256 - 128; + } + } + A.set(a_buf); + + Buffer b_buf(4, col, acc / 4); + count = 1; + for (int iy = 0; iy < acc / 4; iy++) { + for (int ix = 0; ix < col; ix++) { + for (int ik = 0; ik < 4; ++ik) { + b_buf(ik, ix, iy) = count++; //rand() % 256 - 128; + } + } + } + B.set(b_buf); + + Buffer out(col, row); + + Func result = mm.in(); + + // Uncomment to check the asm + Target target = get_jit_target_from_environment(); + result.compile_to_llvm_assembly("matmul.ll", {A, B}, target); + //result.compile_to_assembly("matmul.s", {A, B}, target); + + auto time = Tools::benchmark(20, 20, [&]() { + result.realize(out); + }); + std::cout << "Exec time: " << time << "\n"; + + for (int i = 0; i < row; ++i) { + for (int j = 0; j < acc; ++j) { + std::cout << std::setw(4) << (int)a_buf(j, i) << " "; + } + std::cout << "\n"; + } + std::cout << "\n\n*\n\n"; + for (int i = 0; i < acc; ++i) { + for (int j = 0; j < col; ++j) { + std::cout << std::setw(4) << (int)b_buf(i % 4, j, i / 4) << " "; + } + std::cout << "\n"; + } + std::cout << "\n\n=\n\n"; + for (int i = 0; i < row; ++i) { + for (int j = 0; j < col; ++j) { + std::cout << std::setw(6) << out(j, i) << " "; + } + std::cout << "\n"; + } + + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { + int32_t val = 0; + for (int k = 0; k < acc; ++k) { + val += a_buf(k, j) * b_buf(k % 4, i, k / 4); + } + if (val != out(i, j)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n"; + return 1; + } + } + } + return 0; +} From 2bd34539cb0f3166bee42b57dcf86ad3822589ec Mon Sep 17 00:00:00 2001 From: John Lawson Date: Tue, 9 Mar 2021 15:42:28 +0000 Subject: [PATCH 02/53] Make AMX transform opt-in with memory type --- src/ExtractTileOperations.cpp | 4 +++- test/performance/tiled_matmul.cpp | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index e8edc3532775..5d2799b82655 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -188,7 +188,9 @@ class ExtractTileOperations : public IRMutator { int found_tile_r = -1; Stmt visit(const Allocate *op) override { - if (op->type.is_int() && op->type.bits() == 32) { + if (op->memory_type == MemoryType::AMXTile && + op->type.is_int() && + op->type.bits() == 32) { if (in_allocate) { // Found two possible tile allocations // FIXME: Handle this better diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index b3be78721f10..fe250857417e 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -32,6 +32,7 @@ int main(int argc, char **argv) { Var rxi("rxi"), ryi("ryi"), rz("rz"); RVar rri("rri"), rro("rro"); mm.compute_at(mm.in(), y) + .store_in(MemoryType::AMXTile) .update() // Split into (x,y) tile .tile(y, x, ryi, rxi, tile_y, tile_x, TailStrategy::GuardWithIf) From 34e1a4c8519621a4c187b0da7243140b1ede503d Mon Sep 17 00:00:00 2001 From: John Lawson Date: Wed, 10 Mar 2021 14:50:59 +0000 Subject: [PATCH 03/53] Clean up tiled_matmul test --- test/performance/tiled_matmul.cpp | 45 ++++++++----------------------- 1 file changed, 11 insertions(+), 34 deletions(-) diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index fe250857417e..92d545ff486a 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -1,13 +1,9 @@ #include "Halide.h" #include "halide_benchmark.h" -#include "halide_test_dirs.h" #include -#include using namespace Halide; -#define FUSE 0 - int main(int argc, char **argv) { const int row = 16; const int col = 16; @@ -15,6 +11,10 @@ int main(int argc, char **argv) { Var x("x"), y("y"); ImageParam A(Int(8), 2, "lhs"); + // NB the RHS matrix in AMX instructions should be tiled in "VNNI format", + // where instead of being (cols, rows) where rows are adjacent in memory it + // should be (4, cols, rows / 4) for int8, or (2, cols, rows / 2) for bf16. + // This means that the rows must always be divisible by 4 (or 2 for bf16). ImageParam B(Int(8), 3, "rhs"); RDom r(0, acc); @@ -29,7 +29,7 @@ int main(int argc, char **argv) { int tile_r = 4; // Schedule the reduction - Var rxi("rxi"), ryi("ryi"), rz("rz"); + Var rxi("rxi"), ryi("ryi"); RVar rri("rri"), rro("rro"); mm.compute_at(mm.in(), y) .store_in(MemoryType::AMXTile) @@ -53,27 +53,25 @@ int main(int argc, char **argv) { .vectorize(ixi); // Schedule the consumer - Var mmxi("mmxi"), mmyi("mmyi"), mmz("mmz"); + Var mmxi("mmxi"), mmyi("mmyi"); mm.in() .tile(y, x, mmyi, mmxi, tile_y, tile_x) .vectorize(mmyi) .vectorize(mmxi); - int count = 1; Buffer a_buf(acc, row); for (int iy = 0; iy < row; iy++) { for (int ix = 0; ix < acc; ix++) { - a_buf(ix, iy) = count++; //rand() % 256 - 128; + a_buf(ix, iy) = rand() % 256 - 128; } } A.set(a_buf); Buffer b_buf(4, col, acc / 4); - count = 1; for (int iy = 0; iy < acc / 4; iy++) { for (int ix = 0; ix < col; ix++) { for (int ik = 0; ik < 4; ++ik) { - b_buf(ik, ix, iy) = count++; //rand() % 256 - 128; + b_buf(ik, ix, iy) = rand() % 256 - 128; } } } @@ -84,36 +82,15 @@ int main(int argc, char **argv) { Func result = mm.in(); // Uncomment to check the asm - Target target = get_jit_target_from_environment(); - result.compile_to_llvm_assembly("matmul.ll", {A, B}, target); - //result.compile_to_assembly("matmul.s", {A, B}, target); + //Target target = get_jit_target_from_environment(); + //result.compile_to_llvm_assembly("tiled_matmul.ll", {A, B}, target); + //result.compile_to_assembly("tiled_matmul.s", {A, B}, target); auto time = Tools::benchmark(20, 20, [&]() { result.realize(out); }); std::cout << "Exec time: " << time << "\n"; - for (int i = 0; i < row; ++i) { - for (int j = 0; j < acc; ++j) { - std::cout << std::setw(4) << (int)a_buf(j, i) << " "; - } - std::cout << "\n"; - } - std::cout << "\n\n*\n\n"; - for (int i = 0; i < acc; ++i) { - for (int j = 0; j < col; ++j) { - std::cout << std::setw(4) << (int)b_buf(i % 4, j, i / 4) << " "; - } - std::cout << "\n"; - } - std::cout << "\n\n=\n\n"; - for (int i = 0; i < row; ++i) { - for (int j = 0; j < col; ++j) { - std::cout << std::setw(6) << out(j, i) << " "; - } - std::cout << "\n"; - } - for (int j = 0; j < row; ++j) { for (int i = 0; i < col; ++i) { int32_t val = 0; From dfeac5576edf1e9bc7911fa85b25cb51edddb033 Mon Sep 17 00:00:00 2001 From: John Lawson Date: Tue, 9 Mar 2021 16:27:01 +0000 Subject: [PATCH 04/53] Handle AMX intrinsic attributes better --- src/CodeGen_X86.cpp | 12 +++++++++--- src/runtime/x86_avx512.ll | 12 ++++++------ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 0aa64e089845..e0e34d163f04 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -102,6 +102,10 @@ struct x86Intrinsic { const char *name; halide_type_t arg_types[max_intrinsic_args]; Target::Feature feature = Target::FeatureEnd; + uint32_t flags = 0; + enum Options { + AccessesMemory = 1 << 0, + }; }; // clang-format off @@ -192,11 +196,11 @@ const x86Intrinsic intrinsic_defs[] = { {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids}, {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids}, - {"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids}, + {"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, {"tdpbssd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids}, {"tilezero_i32", Int(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids}, // CodeGen_LLVM cannot cope with returning Type() ie void*, and return type needs to be vector to trigger call_overloaded_intrin - {"tilestored64", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids}, + {"tilestored64", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, }; // clang-format on @@ -220,7 +224,9 @@ void CodeGen_X86::init_module() { } auto *fn = declare_intrin_overload(i.name, ret_type, i.intrin_name, std::move(arg_types)); - fn->addFnAttr(llvm::Attribute::ReadNone); + if((i.flags & x86Intrinsic::AccessesMemory) == 0) { + fn->addFnAttr(llvm::Attribute::ReadNone); + } fn->addFnAttr(llvm::Attribute::NoUnwind); } } diff --git a/src/runtime/x86_avx512.ll b/src/runtime/x86_avx512.ll index 5e9ace735bfd..7217caeb9f01 100644 --- a/src/runtime/x86_avx512.ll +++ b/src/runtime/x86_avx512.ll @@ -91,15 +91,15 @@ define weak_odr <4 x i32> @dpwssdx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b } declare <4 x i32> @llvm.x86.avx512.vpdpwssd.128(<4 x i32>, <4 x i32>, <4 x i32>) -define weak_odr <1024 x i8> @tileloadd64_i8(i16 %rows, i16 %colbytes, i8* %ptr, i64 %off, i64 %stride) nounwind alwaysinline { +define weak_odr <1024 x i8> @tileloadd64_i8(i16 %rows, i16 %colbytes, i8* %ptr, i64 %off, i64 %stride) nounwind alwaysinline readonly { %1 = getelementptr i8, i8* %ptr, i64 %off - %2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) + %2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) nounwind readonly %3 = bitcast x86_amx %2 to <1024 x i8> ret <1024 x i8> %3 } declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) -define weak_odr <256 x i32> @tdpbssd(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind readnone alwaysinline { +define weak_odr <256 x i32> @tdpbssd(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind alwaysinline readnone { %1 = bitcast <1024 x i8> %lhs to x86_amx %2 = bitcast <1024 x i8> %rhs to x86_amx %3 = bitcast <256 x i32> %out to x86_amx @@ -109,10 +109,10 @@ define weak_odr <256 x i32> @tdpbssd(i16 %rows, i16 %colbytes, i16 %acc, <256 x } declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) -define weak_odr <2 x i1> @tilestored64(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline { +define weak_odr <2 x i1> @tilestored64(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline writeonly { %1 = getelementptr i8, i8* %ptr, i64 %off %2 = bitcast <256 x i32> %val to x86_amx - tail call void @llvm.x86.tilestored64.internal(i16 %rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) + tail call void @llvm.x86.tilestored64.internal(i16 %rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) nounwind writeonly ret <2 x i1> zeroinitializer } declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) @@ -123,7 +123,7 @@ declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) ; LLVM ERROR: Cannot emit physreg copy instruction ; renamable $tmm1 = COPY renamable $tmm0 define weak_odr <256 x i32> @tilezero_i32(i16 %rows, i16 %colbytes) nounwind alwaysinline { - %1 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %rows, i16 %colbytes) + %1 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %rows, i16 %colbytes) nounwind %2 = bitcast x86_amx %1 to <256 x i32> ret <256 x i32> %2 } From a9f84ded53b33fe7bc390a875abe0c30df4eec6d Mon Sep 17 00:00:00 2001 From: John Lawson Date: Wed, 10 Mar 2021 15:16:24 +0000 Subject: [PATCH 05/53] Format --- src/CodeGen_X86.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index e0e34d163f04..a8c7cdb12f05 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -1,5 +1,4 @@ #include "CodeGen_Posix.h" - #include "ConciseCasts.h" #include "Debug.h" #include "IRMatch.h" @@ -224,8 +223,8 @@ void CodeGen_X86::init_module() { } auto *fn = declare_intrin_overload(i.name, ret_type, i.intrin_name, std::move(arg_types)); - if((i.flags & x86Intrinsic::AccessesMemory) == 0) { - fn->addFnAttr(llvm::Attribute::ReadNone); + if ((i.flags & x86Intrinsic::AccessesMemory) == 0) { + fn->addFnAttr(llvm::Attribute::ReadNone); } fn->addFnAttr(llvm::Attribute::NoUnwind); } From da04b0aaf147ef1ac42519403513345fa8cdd4b0 Mon Sep 17 00:00:00 2001 From: John Lawson Date: Thu, 11 Mar 2021 12:30:17 +0000 Subject: [PATCH 06/53] Fix test to behave like other tests --- test/performance/tiled_matmul.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index 92d545ff486a..7670ecb7d642 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -1,6 +1,9 @@ #include "Halide.h" #include "halide_benchmark.h" +#include "halide_test_dirs.h" + #include +#include using namespace Halide; @@ -83,8 +86,8 @@ int main(int argc, char **argv) { // Uncomment to check the asm //Target target = get_jit_target_from_environment(); - //result.compile_to_llvm_assembly("tiled_matmul.ll", {A, B}, target); - //result.compile_to_assembly("tiled_matmul.s", {A, B}, target); + //result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A, B}, target); + //result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); auto time = Tools::benchmark(20, 20, [&]() { result.realize(out); @@ -104,5 +107,6 @@ int main(int argc, char **argv) { } } } + std::cout << "Success!\n"; return 0; } From 992304052d6ee2b52f452e0d015b0a5eb3d840b5 Mon Sep 17 00:00:00 2001 From: John Lawson Date: Thu, 11 Mar 2021 13:21:27 +0000 Subject: [PATCH 07/53] Add doc and missing load check --- src/ExtractTileOperations.cpp | 9 +++++++++ src/ExtractTileOperations.h | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 5d2799b82655..641cfa3623ca 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -239,6 +239,15 @@ class ExtractTileOperations : public IRMutator { return ProducerConsumer::make(amx_alloc, op->is_producer, body); } + Expr visit(const Load* op) override { + if (op->name == tile_name) { + // Any tile load will be matched elsewhere, so a load here means that + // the AMX tile is used outside of a tile instruction. + is_valid = false; + } + return IRMutator::visit(op); + } + Stmt visit(const Store *op) override { if (op->name != tile_name) { const auto *load = op->value.as(); diff --git a/src/ExtractTileOperations.h b/src/ExtractTileOperations.h index d246bddc5a04..918e3b1b9940 100644 --- a/src/ExtractTileOperations.h +++ b/src/ExtractTileOperations.h @@ -11,7 +11,8 @@ namespace Halide { namespace Internal { -/** TODO */ +/** Rewrite any AMX tile operations that have been stored in the AMXTile memory + * type as intrinsic calls, to be used in the X86 backend. */ Stmt extract_tile_operations(const Stmt &s); } // namespace Internal From 1a1d10a0647194ab03e479b7613fa04b6d923bce Mon Sep 17 00:00:00 2001 From: John Lawson Date: Thu, 11 Mar 2021 14:50:58 +0000 Subject: [PATCH 08/53] Format --- src/ExtractTileOperations.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 641cfa3623ca..731f028c5835 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -239,11 +239,11 @@ class ExtractTileOperations : public IRMutator { return ProducerConsumer::make(amx_alloc, op->is_producer, body); } - Expr visit(const Load* op) override { + Expr visit(const Load *op) override { if (op->name == tile_name) { - // Any tile load will be matched elsewhere, so a load here means that - // the AMX tile is used outside of a tile instruction. - is_valid = false; + // Any tile load will be matched elsewhere, so a load here means that + // the AMX tile is used outside of a tile instruction. + is_valid = false; } return IRMutator::visit(op); } From d54fe249917a33e602f5d3b2fc92c0082b7262b5 Mon Sep 17 00:00:00 2001 From: John Lawson Date: Fri, 12 Mar 2021 12:41:34 +0000 Subject: [PATCH 09/53] Throw error if user requests AMX for invalid operation --- src/ExtractTileOperations.cpp | 84 +++++++++++++------------------ test/performance/tiled_matmul.cpp | 6 ++- 2 files changed, 39 insertions(+), 51 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 731f028c5835..62d877aac354 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -8,6 +8,9 @@ namespace Halide { namespace Internal { +using std::string; +using std::vector; + namespace { template @@ -23,7 +26,7 @@ const auto wild_i32x = Variable::make(Int(32, 0), "*"); Tile<2> is_2d_tile_index(const Expr &e) { // ramp(ramp(base, 1, 4), x4(stride), 4) - std::vector matches; + vector matches; if (const auto *r1 = e.as()) { if (const auto *r2 = r1->base.as()) { auto ramp_2d_pattern = Ramp::make(Ramp::make(wild_i32, wild_i32, r2->lanes), Broadcast::make(wild_i32, r2->lanes), r1->lanes); @@ -36,7 +39,7 @@ Tile<2> is_2d_tile_index(const Expr &e) { } Tile<3> is_3d_tile_index(const Expr &e) { - std::vector matches; + vector matches; auto add_sub_pattern = (wild_i32x + wild_i32x) - wild_i32x; if (!expr_match(add_sub_pattern, e, matches)) { return {}; } // ramp(x16(base), x16(stride), 4) + x16(ramp(idx, 1, 4)) y: 4, x: 4, r: 4 @@ -89,11 +92,11 @@ struct NewMatmul { }; NewMatmul -convert_to_matmul(const Store *op, const std::string &new_name) { +convert_to_matmul(const Store *op, const string &new_name) { // m[ramp(0, 1, S)] = VectorAdd(lhs[{XYR tile}] * xX(rhs[{YR tile}])) + m[ramp(0, 1, S)] const auto wild_i8x = Variable::make(Int(8, 0), "*"); const auto wild_i16x = Variable::make(Int(16, 0), "*"); - std::vector matches; + vector matches; const auto pattern1 = wild_i32x + wild_i32x; if (!expr_match(pattern1, op->value, matches)) { return {}; } const auto *reduce = matches[0].as(); @@ -143,7 +146,7 @@ convert_to_matmul(const Store *op, const std::string &new_name) { return {true, std::move(store), tile_x, tile_y, tile_r}; } -Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const std::string &new_name) { +Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_name) { if (const auto *ramp = op->index.as()) { if (const auto *bcast = op->value.as()) { if (is_const_one(ramp->stride) && @@ -161,11 +164,11 @@ Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const std::string return {}; } -Stmt convert_to_tile_store(const Store *op, const std::string &amx_alloc, int tile_x, int tile_y) { +Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, int tile_y) { auto tile = is_2d_tile_index(op->index); if (tile.result && tile.extent[0] == tile_x && tile.extent[1] == tile_y) { auto out = Variable::make(Handle(), op->name); - auto tile_val = Load::make(Int(32, 256), amx_alloc, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); + auto tile_val = Load::make(Int(32, 256), amx_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); auto bytes = op->value.type().bytes(); internal_assert(bytes == 4) << "AMX store only supported for int32 and float32, not for " << op->value.type() << "\n"; // {tile_x, tile_y, var, base, stride} @@ -178,10 +181,9 @@ Stmt convert_to_tile_store(const Store *op, const std::string &amx_alloc, int ti class ExtractTileOperations : public IRMutator { using IRMutator::visit; - std::string tile_name; - std::string amx_alloc; - std::vector pending_stores; - bool is_valid = true; + string tile_name; + string amx_name; + vector pending_stores; bool in_allocate = false; int found_tile_x = -1; int found_tile_y = -1; @@ -191,22 +193,15 @@ class ExtractTileOperations : public IRMutator { if (op->memory_type == MemoryType::AMXTile && op->type.is_int() && op->type.bits() == 32) { - if (in_allocate) { - // Found two possible tile allocations - // FIXME: Handle this better - is_valid = false; - return op; - } - amx_alloc = op->name + ".amx"; - tile_name = op->name; + // FIXME: Handle nested allocations better + user_assert(!in_allocate) << "Found two possible tile allocations for AMX allocation"; + ScopedValue old_amx_name(amx_name, op->name + ".amx"); + ScopedValue old_tile_name(tile_name, op->name); ScopedValue old_in_alloc(in_allocate, true); Stmt body = op->body; pending_stores.clear(); body = mutate(body); - if (!is_valid) { - return op; - } if (found_tile_x < 0 || found_tile_y < 0 || found_tile_r < 0) { return op; } @@ -214,11 +209,8 @@ class ExtractTileOperations : public IRMutator { // Really only need to go over the pending stores body = mutate(body); } - if (!is_valid) { - return op; - } - return Allocate::make(amx_alloc, Int(32, 256), MemoryType::AMXTile, {1}, const_true(), body); + return Allocate::make(amx_name, Int(32, 256), MemoryType::AMXTile, {1}, const_true(), body); } return IRMutator::visit(op); } @@ -227,7 +219,7 @@ class ExtractTileOperations : public IRMutator { if (op->name != tile_name) { return op; } - return Free::make(amx_alloc); + return Free::make(amx_name); } Stmt visit(const ProducerConsumer *op) override { @@ -236,15 +228,13 @@ class ExtractTileOperations : public IRMutator { } auto body = mutate(op->body); - return ProducerConsumer::make(amx_alloc, op->is_producer, body); + return ProducerConsumer::make(amx_name, op->is_producer, body); } Expr visit(const Load *op) override { - if (op->name == tile_name) { - // Any tile load will be matched elsewhere, so a load here means that - // the AMX tile is used outside of a tile instruction. - is_valid = false; - } + // Any tile load will be matched elsewhere, so a load here means that + // the AMX tile is used outside of a tile instruction. + user_assert(op->name != tile_name) << "AMX tile allocation used outside a tile instruction"; return IRMutator::visit(op); } @@ -254,24 +244,18 @@ class ExtractTileOperations : public IRMutator { if (!load || load->name != tile_name) { return op; } - auto store = convert_to_tile_store(op, amx_alloc, found_tile_x, found_tile_y); - if (store.defined()) { - return store; - } else { - // Found store of tile_name that is not a tile store. - is_valid = false; - return op; - } + auto store = convert_to_tile_store(op, amx_name, found_tile_x, found_tile_y); + user_assert(store.defined()) << "Store to AMX tile allocation of a non-tile value"; + return store; } - auto matmul = convert_to_matmul(op, amx_alloc); + auto matmul = convert_to_matmul(op, amx_name); if (matmul.result) { - if ((found_tile_x > 0 && matmul.tile_x != found_tile_x) || - (found_tile_r > 0 && matmul.tile_r != found_tile_r) || - (found_tile_y > 0 && matmul.tile_y != found_tile_y)) { - is_valid = false; - return op; - } + user_assert( + (found_tile_x < 0 || matmul.tile_x == found_tile_x) && + (found_tile_x < 0 || matmul.tile_x == found_tile_x) && + (found_tile_x < 0 || matmul.tile_x == found_tile_x)) + << "Found different tile sizes for AMX tile allocation"; found_tile_x = matmul.tile_x; found_tile_y = matmul.tile_y; found_tile_r = matmul.tile_r; @@ -283,13 +267,13 @@ class ExtractTileOperations : public IRMutator { return op; } - auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_alloc); + auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_name); if (zero.defined()) { return zero; } // Otherwise there is some other operation using the allocation, so we cannot use the AMX instructions - is_valid = false; + user_assert(false) << "Found non-tile operations for AMX tile allocation"; return op; } }; diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index 7670ecb7d642..fbcd4b292942 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -8,6 +8,11 @@ using namespace Halide; int main(int argc, char **argv) { + Target target = get_jit_target_from_environment(); + if (!target.has_feature(Target::AVX512_SapphireRapids)) { + std::cout << "[SKIP] The tiled matmul test is only designed to test AMX support.\n"; + return 0; + } const int row = 16; const int col = 16; const int acc = 16; @@ -85,7 +90,6 @@ int main(int argc, char **argv) { Func result = mm.in(); // Uncomment to check the asm - //Target target = get_jit_target_from_environment(); //result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A, B}, target); //result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); From 228beda1df5095361be4a9aca11de256653edcec Mon Sep 17 00:00:00 2001 From: John Lawson Date: Fri, 12 Mar 2021 14:43:59 +0000 Subject: [PATCH 10/53] Add Tile lowering pass to makefile --- Makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Makefile b/Makefile index 1e80eba0a4a5..30fc4f531273 100644 --- a/Makefile +++ b/Makefile @@ -454,6 +454,7 @@ SOURCE_FILES = \ EmulateFloat16Math.cpp \ Error.cpp \ Expr.cpp \ + ExtractTileOperations.cpp \ FastIntegerDivide.cpp \ FindCalls.cpp \ FindIntrinsics.cpp \ @@ -626,6 +627,7 @@ HEADER_FILES = \ ExprUsesVar.h \ Extern.h \ ExternFuncArgument.h \ + ExtractTileOperations.h \ FastIntegerDivide.h \ FindCalls.h \ FindIntrinsics.h \ From 673480a8ac72d8480f181f13cb986e7131b6d28a Mon Sep 17 00:00:00 2001 From: John Lawson Date: Fri, 12 Mar 2021 15:15:35 +0000 Subject: [PATCH 11/53] Use spaces in Makefile --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 30fc4f531273..ca4802069fa4 100644 --- a/Makefile +++ b/Makefile @@ -454,7 +454,7 @@ SOURCE_FILES = \ EmulateFloat16Math.cpp \ Error.cpp \ Expr.cpp \ - ExtractTileOperations.cpp \ + ExtractTileOperations.cpp \ FastIntegerDivide.cpp \ FindCalls.cpp \ FindIntrinsics.cpp \ @@ -627,7 +627,7 @@ HEADER_FILES = \ ExprUsesVar.h \ Extern.h \ ExternFuncArgument.h \ - ExtractTileOperations.h \ + ExtractTileOperations.h \ FastIntegerDivide.h \ FindCalls.h \ FindIntrinsics.h \ From f91d79f8c864aa4355219ad2a550fc0fff8c8499 Mon Sep 17 00:00:00 2001 From: Thales Sabino Date: Mon, 15 Mar 2021 17:55:30 +0000 Subject: [PATCH 12/53] Place AMX instrinsics into a separate module (x86_amx.ll) This will only be included if LLVM >= 12 is used to build Halide --- src/runtime/CMakeLists.txt | 5 +++++ src/runtime/x86_amx.ll | 37 +++++++++++++++++++++++++++++++++++++ src/runtime/x86_avx512.ll | 38 -------------------------------------- 3 files changed, 42 insertions(+), 38 deletions(-) create mode 100644 src/runtime/x86_amx.ll diff --git a/src/runtime/CMakeLists.txt b/src/runtime/CMakeLists.txt index 19614d1dc2d6..77b80c3cbb1b 100644 --- a/src/runtime/CMakeLists.txt +++ b/src/runtime/CMakeLists.txt @@ -115,6 +115,11 @@ set(RUNTIME_LL x86_sse41 ) +if (LLVM_PACKAGE_VERSION VERSION_GREATER_EQUAL 12.0) + # AMX instructions require LLVM 12 or newer + list(APPEND RUNTIME_LL x86_amx) +endif () + set(RUNTIME_BC compute_20 compute_30 diff --git a/src/runtime/x86_amx.ll b/src/runtime/x86_amx.ll new file mode 100644 index 000000000000..265d7ef2e381 --- /dev/null +++ b/src/runtime/x86_amx.ll @@ -0,0 +1,37 @@ +define weak_odr <1024 x i8> @tileloadd64_i8(i16 %rows, i16 %colbytes, i8* %ptr, i64 %off, i64 %stride) nounwind alwaysinline readonly { + %1 = getelementptr i8, i8* %ptr, i64 %off + %2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) nounwind readonly + %3 = bitcast x86_amx %2 to <1024 x i8> + ret <1024 x i8> %3 +} +declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) + +define weak_odr <256 x i32> @tdpbssd(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind alwaysinline readnone { + %1 = bitcast <1024 x i8> %lhs to x86_amx + %2 = bitcast <1024 x i8> %rhs to x86_amx + %3 = bitcast <256 x i32> %out to x86_amx + %4 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %rows, i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone + %5 = bitcast x86_amx %4 to <256 x i32> + ret <256 x i32> %5 +} +declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) + +define weak_odr <2 x i1> @tilestored64(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline writeonly { + %1 = getelementptr i8, i8* %ptr, i64 %off + %2 = bitcast <256 x i32> %val to x86_amx + tail call void @llvm.x86.tilestored64.internal(i16 %rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) nounwind writeonly + ret <2 x i1> zeroinitializer +} +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) + +; NB: Even though this should be readnone, that will cause LLVM to try to +; generate a single zero tile, and copy it each time it is used. However the AMX +; registers cannot be copied, so this causes compilation failures: +; LLVM ERROR: Cannot emit physreg copy instruction +; renamable $tmm1 = COPY renamable $tmm0 +define weak_odr <256 x i32> @tilezero_i32(i16 %rows, i16 %colbytes) nounwind alwaysinline { + %1 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %rows, i16 %colbytes) nounwind + %2 = bitcast x86_amx %1 to <256 x i32> + ret <256 x i32> %2 +} +declare x86_amx @llvm.x86.tilezero.internal(i16, i16) diff --git a/src/runtime/x86_avx512.ll b/src/runtime/x86_avx512.ll index 7217caeb9f01..904fabe9368e 100644 --- a/src/runtime/x86_avx512.ll +++ b/src/runtime/x86_avx512.ll @@ -90,41 +90,3 @@ define weak_odr <4 x i32> @dpwssdx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b ret <4 x i32> %3 } declare <4 x i32> @llvm.x86.avx512.vpdpwssd.128(<4 x i32>, <4 x i32>, <4 x i32>) - -define weak_odr <1024 x i8> @tileloadd64_i8(i16 %rows, i16 %colbytes, i8* %ptr, i64 %off, i64 %stride) nounwind alwaysinline readonly { - %1 = getelementptr i8, i8* %ptr, i64 %off - %2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) nounwind readonly - %3 = bitcast x86_amx %2 to <1024 x i8> - ret <1024 x i8> %3 -} -declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) - -define weak_odr <256 x i32> @tdpbssd(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind alwaysinline readnone { - %1 = bitcast <1024 x i8> %lhs to x86_amx - %2 = bitcast <1024 x i8> %rhs to x86_amx - %3 = bitcast <256 x i32> %out to x86_amx - %4 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %rows, i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone - %5 = bitcast x86_amx %4 to <256 x i32> - ret <256 x i32> %5 -} -declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) - -define weak_odr <2 x i1> @tilestored64(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline writeonly { - %1 = getelementptr i8, i8* %ptr, i64 %off - %2 = bitcast <256 x i32> %val to x86_amx - tail call void @llvm.x86.tilestored64.internal(i16 %rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) nounwind writeonly - ret <2 x i1> zeroinitializer -} -declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) - -; NB: Even though this should be readnone, that will cause LLVM to try to -; generate a single zero tile, and copy it each time it is used. However the AMX -; registers cannot be copied, so this causes compilation failures: -; LLVM ERROR: Cannot emit physreg copy instruction -; renamable $tmm1 = COPY renamable $tmm0 -define weak_odr <256 x i32> @tilezero_i32(i16 %rows, i16 %colbytes) nounwind alwaysinline { - %1 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %rows, i16 %colbytes) nounwind - %2 = bitcast x86_amx %1 to <256 x i32> - ret <256 x i32> %2 -} -declare x86_amx @llvm.x86.tilezero.internal(i16, i16) From 16a1c7b8536fa588ef09e13199351f3175d2a834 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Wed, 17 Mar 2021 12:56:47 -0700 Subject: [PATCH 13/53] Fix CreateAlignedLoad() call in CodeGen_X86 Recent changes in LLVM trunk made the previous calling convention deprecated (and thus compiling with warning/error) --- src/CodeGen_X86.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index a8c7cdb12f05..4a6a5be5ea20 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -605,7 +605,7 @@ void CodeGen_X86::visit(const Load *op) { const Ramp *ramp = op->index.as(); internal_assert(ramp) << "Expected AMXTile to have index ramp\n"; Value *ptr = codegen_buffer_pointer(op->name, op->type, ramp->base); - LoadInst *load = builder->CreateAlignedLoad(ptr, llvm::Align(op->type.bytes())); + LoadInst *load = builder->CreateAlignedLoad(ptr->getType()->getPointerElementType(), ptr, llvm::Align(op->type.bytes())); add_tbaa_metadata(load, op->name, op->index); value = load; return; From d0e512361118b120eb7b2c4f89aada954a0d406f Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Mon, 19 Apr 2021 18:01:13 +0100 Subject: [PATCH 14/53] fix exporting to module --- dependencies/llvm/CMakeLists.txt | 3 +++ src/LLVM_Runtime_Linker.cpp | 7 +++++++ src/runtime/CMakeLists.txt | 2 +- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/dependencies/llvm/CMakeLists.txt b/dependencies/llvm/CMakeLists.txt index 1e7923d03c05..bb93e721b534 100644 --- a/dependencies/llvm/CMakeLists.txt +++ b/dependencies/llvm/CMakeLists.txt @@ -16,6 +16,9 @@ find_package(Clang REQUIRED CONFIG HINTS "${LLVM_DIR}/../clang") message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") +# LLVM_PACKAGE_VERSION does not propagate to higher scopes +set(Halide_LLVM_VERSION ${LLVM_PACKAGE_VERSION} CACHE INTERNAL "Provided LLVM version") + if (LLVM_PACKAGE_VERSION VERSION_LESS 10.0) message(FATAL_ERROR "LLVM version must be 10.0 or newer") endif () diff --git a/src/LLVM_Runtime_Linker.cpp b/src/LLVM_Runtime_Linker.cpp index a3ef7b27454f..4b4c1eb5cfb5 100644 --- a/src/LLVM_Runtime_Linker.cpp +++ b/src/LLVM_Runtime_Linker.cpp @@ -230,6 +230,7 @@ DECLARE_NO_INITMOD(windows_d3d12compute_arm) #endif // WITH_D3D12 #ifdef WITH_X86 +DECLARE_LL_INITMOD(x86_amx) DECLARE_LL_INITMOD(x86_avx512) DECLARE_LL_INITMOD(x86_avx2) DECLARE_LL_INITMOD(x86_avx) @@ -237,6 +238,7 @@ DECLARE_LL_INITMOD(x86) DECLARE_LL_INITMOD(x86_sse41) DECLARE_CPP_INITMOD(x86_cpu_features) #else +DECLARE_NO_INITMOD(x86_amx) DECLARE_NO_INITMOD(x86_avx512) DECLARE_NO_INITMOD(x86_avx2) DECLARE_NO_INITMOD(x86_avx) @@ -1063,6 +1065,11 @@ std::unique_ptr get_initial_module_for_target(Target t, llvm::LLVM if (t.has_feature(Target::AVX512)) { modules.push_back(get_initmod_x86_avx512_ll(c)); } +#if LLVM_VERSION >= 120 + if (t.has_feature(Target::AVX512_SapphireRapids)) { + modules.push_back(get_initmod_x86_amx_ll(c)); + } +#endif if (t.has_feature(Target::Profile)) { user_assert(t.os != Target::WebAssemblyRuntime) << "The profiler cannot be used in a threadless environment."; modules.push_back(get_initmod_profiler_inlined(c, bits_64, debug)); diff --git a/src/runtime/CMakeLists.txt b/src/runtime/CMakeLists.txt index bf6e3f2c7e8c..ccda4a765f35 100644 --- a/src/runtime/CMakeLists.txt +++ b/src/runtime/CMakeLists.txt @@ -115,7 +115,7 @@ set(RUNTIME_LL x86_sse41 ) -if (LLVM_PACKAGE_VERSION VERSION_GREATER_EQUAL 12.0) +if (Halide_LLVM_VERSION VERSION_GREATER_EQUAL 12.0) # AMX instructions require LLVM 12 or newer list(APPEND RUNTIME_LL x86_amx) endif () From 91891208f14c5a5390ce363224a21ef0366a54b1 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 8 Apr 2021 12:14:06 +0100 Subject: [PATCH 15/53] add llvm funcs for su, us, uu amx variants --- src/runtime/x86_amx.ll | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/runtime/x86_amx.ll b/src/runtime/x86_amx.ll index 265d7ef2e381..024518ddb98b 100644 --- a/src/runtime/x86_amx.ll +++ b/src/runtime/x86_amx.ll @@ -16,6 +16,36 @@ define weak_odr <256 x i32> @tdpbssd(i16 %rows, i16 %colbytes, i16 %acc, <256 x } declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +define weak_odr <256 x i32> @tdpbsud(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind alwaysinline readnone { + %1 = bitcast <1024 x i8> %lhs to x86_amx + %2 = bitcast <1024 x i8> %rhs to x86_amx + %3 = bitcast <256 x i32> %out to x86_amx + %4 = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 %rows, i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone + %5 = bitcast x86_amx %4 to <256 x i32> + ret <256 x i32> %5 +} +declare x86_amx @llvm.x86.tdpbsud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) + +define weak_odr <256 x i32> @tdpbusd(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind alwaysinline readnone { + %1 = bitcast <1024 x i8> %lhs to x86_amx + %2 = bitcast <1024 x i8> %rhs to x86_amx + %3 = bitcast <256 x i32> %out to x86_amx + %4 = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 %rows, i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone + %5 = bitcast x86_amx %4 to <256 x i32> + ret <256 x i32> %5 +} +declare x86_amx @llvm.x86.tdpbusd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) + +define weak_odr <256 x i32> @tdpbuud(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind alwaysinline readnone { + %1 = bitcast <1024 x i8> %lhs to x86_amx + %2 = bitcast <1024 x i8> %rhs to x86_amx + %3 = bitcast <256 x i32> %out to x86_amx + %4 = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 %rows, i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone + %5 = bitcast x86_amx %4 to <256 x i32> + ret <256 x i32> %5 +} +declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) + define weak_odr <2 x i1> @tilestored64(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline writeonly { %1 = getelementptr i8, i8* %ptr, i64 %off %2 = bitcast <256 x i32> %val to x86_amx From 75c4262d3190770e074fd18ffa435fce3ddb624b Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 16 Apr 2021 14:54:06 +0100 Subject: [PATCH 16/53] add other amx intrinsics to intrinsic_defs --- src/CodeGen_X86.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 628f96d28e95..60a83e651996 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -206,7 +206,11 @@ const x86Intrinsic intrinsic_defs[] = { {"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids}, {"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, + {"tileloadd64_i8", UInt(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, {"tdpbssd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids}, + {"tdpbsud", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), UInt(8, 1024)}, Target::AVX512_SapphireRapids}, + {"tdpbusd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), UInt(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids}, + {"tdpbuud", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), UInt(8, 1024), UInt(8, 1024)}, Target::AVX512_SapphireRapids}, {"tilezero_i32", Int(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids}, // CodeGen_LLVM cannot cope with returning Type() ie void*, and return type needs to be vector to trigger call_overloaded_intrin {"tilestored64", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, From 09f7551a8994c3d4b7376481915bb8a83b83df9a Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 16 Apr 2021 14:55:26 +0100 Subject: [PATCH 17/53] match with unsigned 8 bit integers This matching happens for the left and right side, each determining whether that side is unsigned or signed. In the end the proper 1024 byte buffer is created with (un)signed. --- src/ExtractTileOperations.cpp | 39 +++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 62d877aac354..8ff0c937b643 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -95,6 +95,7 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name) { // m[ramp(0, 1, S)] = VectorAdd(lhs[{XYR tile}] * xX(rhs[{YR tile}])) + m[ramp(0, 1, S)] const auto wild_i8x = Variable::make(Int(8, 0), "*"); + const auto wild_u8x = Variable::make(UInt(8, 0), "*"); const auto wild_i16x = Variable::make(Int(16, 0), "*"); vector matches; const auto pattern1 = wild_i32x + wild_i32x; @@ -106,13 +107,31 @@ convert_to_matmul(const Store *op, const string &new_name) { // FIXME: Add support for uint8 and bf16 for LLVM 13+ auto pattern2 = cast(Int(32, 0), cast(Int(16, 0), wild_i8x) * wild_i16x); - if (!expr_match(pattern2, reduce->value, matches)) { return {}; } + auto pattern2_alt = cast(Int(32, 0), cast(Int(16, 0), wild_u8x) * wild_i16x); + + bool lhs_signed = false; + if (expr_match(pattern2, reduce->value, matches)) { + lhs_signed = true; + } else if (expr_match(pattern2_alt, reduce->value, matches)) { + lhs_signed = false; + } else { + return {}; + } + const auto *lhs_load = matches[0].as(); // FIXME: When tile_r is not 4 the broadcast is inside the index, not of the value const auto *rhs_broadcast = matches[1].as(); if (!lhs_load || !rhs_broadcast) { return {}; } const auto *rhs_cast = rhs_broadcast->value.as(); - if (!rhs_cast || rhs_cast->value.type().element_of() != Int(8)) { return {}; } + bool rhs_signed = false; + if (rhs_cast && rhs_cast->value.type().element_of() == Int(8)) { + rhs_signed = true; + } else if (rhs_cast && rhs_cast->value.type().element_of() == UInt(8)) { + rhs_signed = false; + } else { + return {}; + } + const auto *rhs_load = rhs_cast->value.as(); if (!rhs_load) { return {}; } @@ -134,9 +153,21 @@ convert_to_matmul(const Store *op, const string &new_name) { // {rows, colbytes, var, index} auto lhs_var = Variable::make(Handle(), lhs_load->name); - auto lhs = Call::make(Int(8, 1024), "tile_load", {tile_x, tile_r, lhs_var, lhs_tile.base, lhs_tile.stride[0]}, Call::Intrinsic); + auto lhs = [&]() { + if (lhs_signed) { + return Call::make(Int(8, 1024), "tile_load", {tile_x, tile_r, lhs_var, lhs_tile.base, lhs_tile.stride[0]}, Call::Intrinsic); + } else { + return Call::make(UInt(8, 1024), "tile_load", {tile_x, tile_r, lhs_var, lhs_tile.base, lhs_tile.stride[0]}, Call::Intrinsic); + } + }(); auto rhs_var = Variable::make(Handle(), rhs_load->name); - auto rhs = Call::make(Int(8, 1024), "tile_load", {1, tile_y * tile_r, rhs_var, rhs_tile.base, rhs_tile.stride[0]}, Call::Intrinsic); + auto rhs = [&]() { + if (rhs_signed) { + return Call::make(Int(8, 1024), "tile_load", {1, tile_y * tile_r, rhs_var, rhs_tile.base, rhs_tile.stride[0]}, Call::Intrinsic); + } else { + return Call::make(UInt(8, 1024), "tile_load", {1, tile_y * tile_r, rhs_var, rhs_tile.base, rhs_tile.stride[0]}, Call::Intrinsic); + } + }(); // {rows, colbytes, acc, out, lhs, rhs} auto out = Load::make(Int(32, 256), new_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); From 5dd1471f4d1105cf19f0eadb1c53a153c439d98e Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Tue, 27 Apr 2021 15:20:24 +0100 Subject: [PATCH 18/53] match for 32 bit integer and guard unsigned amx on llvm 13 --- src/ExtractTileOperations.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 8ff0c937b643..c449d0dcb890 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -96,7 +96,6 @@ convert_to_matmul(const Store *op, const string &new_name) { // m[ramp(0, 1, S)] = VectorAdd(lhs[{XYR tile}] * xX(rhs[{YR tile}])) + m[ramp(0, 1, S)] const auto wild_i8x = Variable::make(Int(8, 0), "*"); const auto wild_u8x = Variable::make(UInt(8, 0), "*"); - const auto wild_i16x = Variable::make(Int(16, 0), "*"); vector matches; const auto pattern1 = wild_i32x + wild_i32x; if (!expr_match(pattern1, op->value, matches)) { return {}; } @@ -105,14 +104,14 @@ convert_to_matmul(const Store *op, const string &new_name) { if (!reduce || reduce->op != VectorReduce::Add) { return {}; } if (!load || load->name != op->name || !equal(load->index, op->index)) { return {}; } - // FIXME: Add support for uint8 and bf16 for LLVM 13+ - auto pattern2 = cast(Int(32, 0), cast(Int(16, 0), wild_i8x) * wild_i16x); - auto pattern2_alt = cast(Int(32, 0), cast(Int(16, 0), wild_u8x) * wild_i16x); + // FIXME: Add support for bf16 for LLVM 13+ + auto pattern2 = cast(Int(32, 0), cast(Int(32, 0), wild_i8x) * wild_i32x); + auto pattern2_unsigned = cast(Int(32, 0), cast(Int(32, 0), wild_u8x) * wild_i32x); bool lhs_signed = false; if (expr_match(pattern2, reduce->value, matches)) { lhs_signed = true; - } else if (expr_match(pattern2_alt, reduce->value, matches)) { + } else if (expr_match(pattern2_unsigned, reduce->value, matches)) { lhs_signed = false; } else { return {}; @@ -151,6 +150,10 @@ convert_to_matmul(const Store *op, const string &new_name) { return {}; } +#if LLVM_VERSION < 130 + user_assert(lhs_signed && rhs_signed) << "LLVM 13 or above is required for unsigned AMX instructions"; +#endif + // {rows, colbytes, var, index} auto lhs_var = Variable::make(Handle(), lhs_load->name); auto lhs = [&]() { From 7e45c29ed06086851379be5d14049503b302b6bf Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 29 Apr 2021 13:27:05 +0100 Subject: [PATCH 19/53] adjust test to cover unsigned tile operations --- test/performance/tiled_matmul.cpp | 84 +++++++++++++++++++++++-------- 1 file changed, 64 insertions(+), 20 deletions(-) diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index fbcd4b292942..b0938ff29138 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -7,29 +7,70 @@ using namespace Halide; -int main(int argc, char **argv) { +struct make_uint_t { + template + auto operator()(Args &&... args) const { + return UInt(static_cast(args)...); + } +}; + +struct make_int_t { + template + auto operator()(Args &&... args) const { + return Int(static_cast(args)...); + } +}; + +template +void fill_buffer_a(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; iy++) { + for (int ix = 0; ix < acc; ix++) { + buf(ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } +} + +template +void fill_buffer_b(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc / 4; iy++) { + for (int ix = 0; ix < col; ix++) { + for (int ik = 0; ik < 4; ++ik) { + buf(ik, ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } + } +} + +template +bool matmul() { + auto lhs = std::conditional_t{}; + auto rhs = std::conditional_t{}; + + using LhsInt8 = std::conditional_t; + using RhsInt8 = std::conditional_t; + Target target = get_jit_target_from_environment(); if (!target.has_feature(Target::AVX512_SapphireRapids)) { std::cout << "[SKIP] The tiled matmul test is only designed to test AMX support.\n"; - return 0; + return true; } const int row = 16; const int col = 16; const int acc = 16; Var x("x"), y("y"); - ImageParam A(Int(8), 2, "lhs"); + ImageParam A(lhs(8), 2, "lhs"); // NB the RHS matrix in AMX instructions should be tiled in "VNNI format", // where instead of being (cols, rows) where rows are adjacent in memory it // should be (4, cols, rows / 4) for int8, or (2, cols, rows / 2) for bf16. // This means that the rows must always be divisible by 4 (or 2 for bf16). - ImageParam B(Int(8), 3, "rhs"); + ImageParam B(rhs(8), 3, "rhs"); RDom r(0, acc); Func mm("matmul"); mm(y, x) = cast(0); - mm(y, x) += cast(A(r.x, x)) * B(r.x % 4, y, r.x / 4); + mm(y, x) += cast(A(r.x, x)) * B(r.x % 4, y, r.x / 4); // Ensure all (x, y) tile sizes are the same so that loops are fused. int tile_y = 8; @@ -67,22 +108,12 @@ int main(int argc, char **argv) { .vectorize(mmyi) .vectorize(mmxi); - Buffer a_buf(acc, row); - for (int iy = 0; iy < row; iy++) { - for (int ix = 0; ix < acc; ix++) { - a_buf(ix, iy) = rand() % 256 - 128; - } - } + Buffer a_buf(acc, row); + fill_buffer_a(a_buf, row, acc); A.set(a_buf); - Buffer b_buf(4, col, acc / 4); - for (int iy = 0; iy < acc / 4; iy++) { - for (int ix = 0; ix < col; ix++) { - for (int ik = 0; ik < 4; ++ik) { - b_buf(ik, ix, iy) = rand() % 256 - 128; - } - } - } + Buffer b_buf(4, col, acc / 4); + fill_buffer_b(b_buf, col, acc); B.set(b_buf); Buffer out(col, row); @@ -107,10 +138,23 @@ int main(int argc, char **argv) { if (val != out(i, j)) { std::cerr << "Invalid result at " << i << ", " << j << "\n" << out(i, j) << " != " << val << "\n"; - return 1; + return false; } } } std::cout << "Success!\n"; + return true; +} + +auto matmul_ss = &matmul; +auto matmul_us = &matmul; +auto matmul_su = &matmul; +auto matmul_uu = &matmul; + +int main(int argc, char **argv) { + matmul_ss(); + matmul_us(); + matmul_su(); + matmul_uu(); return 0; } From 4ab681b2a5564cf089e29113780a1ddeab8369d8 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 29 Apr 2021 13:55:26 +0100 Subject: [PATCH 20/53] guard properly with llvm 12 --- src/Lower.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Lower.cpp b/src/Lower.cpp index 6f2231d57d85..509fdc35015d 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -386,12 +386,11 @@ void lower_impl(const vector &output_funcs, s = lower_unsafe_promises(s, t); log("Lowering after lowering unsafe promises:", s); -#if LLVM_VERSION >= 12 +#if LLVM_VERSION >= 120 if (t.has_feature(Target::AVX512_SapphireRapids)) { debug(1) << "Extracting tile operations...\n"; s = extract_tile_operations(s); - debug(2) << "Lowering after extracting tile operations:\n" - << s << "\n\n"; + log("Lowering after extracting tile operations:", s); } #endif From 6339ae724f15a416f3851a6a77e295785e60b11a Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 29 Apr 2021 13:56:08 +0100 Subject: [PATCH 21/53] create explicit error if failed to use tile operations --- src/ExtractTileOperations.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index c449d0dcb890..e6d50fe64f59 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -224,9 +224,9 @@ class ExtractTileOperations : public IRMutator { int found_tile_r = -1; Stmt visit(const Allocate *op) override { - if (op->memory_type == MemoryType::AMXTile && - op->type.is_int() && - op->type.bits() == 32) { + if (op->memory_type == MemoryType::AMXTile) { + user_assert(op->type.is_int() && op->type.bits() == 32) << "scheduled tile operations must yield 32-bit integers"; + // FIXME: Handle nested allocations better user_assert(!in_allocate) << "Found two possible tile allocations for AMX allocation"; ScopedValue old_amx_name(amx_name, op->name + ".amx"); From 525e11e8930addb5be79552efa8436956c415b40 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 29 Apr 2021 14:40:42 +0100 Subject: [PATCH 22/53] pass types as template params rather than boolean This makes the intention clearer --- test/performance/tiled_matmul.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index b0938ff29138..d9d1f31ab586 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -41,13 +41,13 @@ void fill_buffer_b(Buffer &buf, int col, int acc) { } } -template +template bool matmul() { - auto lhs = std::conditional_t{}; - auto rhs = std::conditional_t{}; + constexpr bool lhs_signed = std::is_signed::value; + constexpr bool rhs_signed = std::is_signed::value; - using LhsInt8 = std::conditional_t; - using RhsInt8 = std::conditional_t; + auto lhs = std::conditional_t{}; + auto rhs = std::conditional_t{}; Target target = get_jit_target_from_environment(); if (!target.has_feature(Target::AVX512_SapphireRapids)) { @@ -146,10 +146,10 @@ bool matmul() { return true; } -auto matmul_ss = &matmul; -auto matmul_us = &matmul; -auto matmul_su = &matmul; -auto matmul_uu = &matmul; +auto matmul_ss = &matmul; +auto matmul_us = &matmul; +auto matmul_su = &matmul; +auto matmul_uu = &matmul; int main(int argc, char **argv) { matmul_ss(); From 5a21484036a9ffdce3d46cac9a75e3da5b738f2f Mon Sep 17 00:00:00 2001 From: Thales Sabino Date: Fri, 30 Apr 2021 13:27:54 +0100 Subject: [PATCH 23/53] clang-format patch --- test/performance/tiled_matmul.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index d9d1f31ab586..68e30b75ed1c 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -9,14 +9,14 @@ using namespace Halide; struct make_uint_t { template - auto operator()(Args &&... args) const { + auto operator()(Args &&...args) const { return UInt(static_cast(args)...); } }; struct make_int_t { template - auto operator()(Args &&... args) const { + auto operator()(Args &&...args) const { return Int(static_cast(args)...); } }; From 7950614494d8138b8e766f406353cd5bd385f170 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 30 Apr 2021 15:05:23 +0100 Subject: [PATCH 24/53] add x86_amx to makefile's runtime components --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index 33b9dfe4eb77..efb1be5063f4 100644 --- a/Makefile +++ b/Makefile @@ -843,6 +843,7 @@ RUNTIME_LL_COMPONENTS = \ x86_avx \ x86_avx2 \ x86_avx512 \ + x86_amx \ x86_sse41 RUNTIME_EXPORTED_INCLUDES = $(INCLUDE_DIR)/HalideRuntime.h \ From 4a6c10c7755fc8e5873dfaf705a7fc2524011249 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 30 Apr 2021 16:03:25 +0100 Subject: [PATCH 25/53] make tiled_matmul compatible with c++11 --- test/performance/tiled_matmul.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index 68e30b75ed1c..c0ddca425c65 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -9,14 +9,14 @@ using namespace Halide; struct make_uint_t { template - auto operator()(Args &&...args) const { + Type operator()(Args &&...args) const { return UInt(static_cast(args)...); } }; struct make_int_t { template - auto operator()(Args &&...args) const { + Type operator()(Args &&...args) const { return Int(static_cast(args)...); } }; @@ -46,8 +46,8 @@ bool matmul() { constexpr bool lhs_signed = std::is_signed::value; constexpr bool rhs_signed = std::is_signed::value; - auto lhs = std::conditional_t{}; - auto rhs = std::conditional_t{}; + auto lhs = typename std::conditional::type{}; + auto rhs = typename std::conditional::type{}; Target target = get_jit_target_from_environment(); if (!target.has_feature(Target::AVX512_SapphireRapids)) { From f985644eed8436f3048c7e3a45da8bee120dffe0 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Mon, 3 May 2021 15:29:50 +0100 Subject: [PATCH 26/53] add mattrs required for amx --- src/CodeGen_X86.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 212318ac5885..049133ca6eec 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -703,7 +703,7 @@ string CodeGen_X86::mattrs() const { } if (target.has_feature(Target::AVX512_SapphireRapids)) { #if LLVM_VERSION >= 120 - features += ",+avx512bf16,+avx512vnni"; + features += ",+avx512bf16,+avx512vnni,+amx-int8,+amx-bf16"; #else user_error << "AVX512 SapphireRapids requires LLVM 12 or later."; #endif From e3f1ef6b30348b74baa58ec10124cb9d37e21f17 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Wed, 5 May 2021 15:41:02 +0100 Subject: [PATCH 27/53] fix formatting issues --- src/ExtractTileOperations.cpp | 56 +++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index e6d50fe64f59..976d0262f806 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -41,7 +41,9 @@ Tile<2> is_2d_tile_index(const Expr &e) { Tile<3> is_3d_tile_index(const Expr &e) { vector matches; auto add_sub_pattern = (wild_i32x + wild_i32x) - wild_i32x; - if (!expr_match(add_sub_pattern, e, matches)) { return {}; } + if (!expr_match(add_sub_pattern, e, matches)) { + return {}; + } // ramp(x16(base), x16(stride), 4) + x16(ramp(idx, 1, 4)) y: 4, x: 4, r: 4 // ramp(x10(base), x10(stride), 3) + x6(ramp(idx, 1, 5)) y: 2, x: 3, r: 5 Expr first = std::move(matches[0]); @@ -54,30 +56,42 @@ Tile<3> is_3d_tile_index(const Expr &e) { r1 = second.as(); b2 = first.as(); } - if (!r1 || !b2) { return {}; } + if (!r1 || !b2) { + return {}; + } const auto *b1 = r1->base.as(); const auto *r2 = b2->value.as(); - if (!b1 || !r2) { return {}; } + if (!b1 || !r2) { + return {}; + } int x_tile = r1->lanes; int r_tile = r2->lanes; int y_tile = b1->lanes / r_tile; - if (y_tile != b2->lanes / x_tile) { return {}; } + if (y_tile != b2->lanes / x_tile) { + return {}; + } auto pattern1 = Ramp::make(Broadcast::make(wild_i32, b1->lanes), Broadcast::make(wild_i32, b1->lanes), r1->lanes); - if (!expr_match(pattern1, first, matches)) { return {}; } + if (!expr_match(pattern1, first, matches)) { + return {}; + } Expr base = std::move(matches[0]); Expr x_stride = std::move(matches[1]); auto pattern2 = Broadcast::make(Ramp::make(wild_i32, wild_i32, r2->lanes), b2->lanes); - if (!expr_match(pattern2, second, matches)) { return {}; } + if (!expr_match(pattern2, second, matches)) { + return {}; + } base += std::move(matches[0]); Expr r_stride = std::move(matches[1]); auto pattern3 = Broadcast::make(wild_i32, b1->lanes * r1->lanes); - if (!expr_match(pattern3, adj, matches)) { return {}; } + if (!expr_match(pattern3, adj, matches)) { + return {}; + } base -= std::move(matches[0]); return {true, base, {x_stride, 0, r_stride}, {x_tile, y_tile, r_tile}}; @@ -91,18 +105,23 @@ struct NewMatmul { int tile_r; }; -NewMatmul -convert_to_matmul(const Store *op, const string &new_name) { +NewMatmul convert_to_matmul(const Store *op, const string &new_name) { // m[ramp(0, 1, S)] = VectorAdd(lhs[{XYR tile}] * xX(rhs[{YR tile}])) + m[ramp(0, 1, S)] const auto wild_i8x = Variable::make(Int(8, 0), "*"); const auto wild_u8x = Variable::make(UInt(8, 0), "*"); vector matches; const auto pattern1 = wild_i32x + wild_i32x; - if (!expr_match(pattern1, op->value, matches)) { return {}; } + if (!expr_match(pattern1, op->value, matches)) { + return {}; + } const auto *reduce = matches[0].as(); const auto *load = matches[1].as(); - if (!reduce || reduce->op != VectorReduce::Add) { return {}; } - if (!load || load->name != op->name || !equal(load->index, op->index)) { return {}; } + if (!reduce || reduce->op != VectorReduce::Add) { + return {}; + } + if (!load || load->name != op->name || !equal(load->index, op->index)) { + return {}; + } // FIXME: Add support for bf16 for LLVM 13+ auto pattern2 = cast(Int(32, 0), cast(Int(32, 0), wild_i8x) * wild_i32x); @@ -120,7 +139,9 @@ convert_to_matmul(const Store *op, const string &new_name) { const auto *lhs_load = matches[0].as(); // FIXME: When tile_r is not 4 the broadcast is inside the index, not of the value const auto *rhs_broadcast = matches[1].as(); - if (!lhs_load || !rhs_broadcast) { return {}; } + if (!lhs_load || !rhs_broadcast) { + return {}; + } const auto *rhs_cast = rhs_broadcast->value.as(); bool rhs_signed = false; if (rhs_cast && rhs_cast->value.type().element_of() == Int(8)) { @@ -132,12 +153,16 @@ convert_to_matmul(const Store *op, const string &new_name) { } const auto *rhs_load = rhs_cast->value.as(); - if (!rhs_load) { return {}; } + if (!rhs_load) { + return {}; + } const auto lhs_tile = is_3d_tile_index(lhs_load->index); const auto rhs_tile = is_2d_tile_index(rhs_load->index); // FIXME: When tile_r is not 4 the RHS load will be 4D (x, r/4, y, r%4) - if (!lhs_tile.result || !rhs_tile.result) { return {}; } + if (!lhs_tile.result || !rhs_tile.result) { + return {}; + } const int tile_x = lhs_tile.extent[0]; const int tile_y = lhs_tile.extent[1]; @@ -317,6 +342,5 @@ class ExtractTileOperations : public IRMutator { Stmt extract_tile_operations(const Stmt &s) { return ExtractTileOperations().mutate(s); } - } // namespace Internal } // namespace Halide From 57b608023502a08446e6daac5502688360c6a93f Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Mon, 10 May 2021 15:32:27 +0100 Subject: [PATCH 28/53] remove outdated FIXME comments --- src/ExtractTileOperations.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 976d0262f806..e1eea9ccd7ce 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -137,7 +137,6 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name) { } const auto *lhs_load = matches[0].as(); - // FIXME: When tile_r is not 4 the broadcast is inside the index, not of the value const auto *rhs_broadcast = matches[1].as(); if (!lhs_load || !rhs_broadcast) { return {}; @@ -159,7 +158,6 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name) { const auto lhs_tile = is_3d_tile_index(lhs_load->index); const auto rhs_tile = is_2d_tile_index(rhs_load->index); - // FIXME: When tile_r is not 4 the RHS load will be 4D (x, r/4, y, r%4) if (!lhs_tile.result || !rhs_tile.result) { return {}; } @@ -252,7 +250,6 @@ class ExtractTileOperations : public IRMutator { if (op->memory_type == MemoryType::AMXTile) { user_assert(op->type.is_int() && op->type.bits() == 32) << "scheduled tile operations must yield 32-bit integers"; - // FIXME: Handle nested allocations better user_assert(!in_allocate) << "Found two possible tile allocations for AMX allocation"; ScopedValue old_amx_name(amx_name, op->name + ".amx"); ScopedValue old_tile_name(tile_name, op->name); From 55098f3f993be92ff3ae339a896e70c744208774 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Mon, 10 May 2021 16:32:19 +0100 Subject: [PATCH 29/53] add bf16 tile operations to the runtime --- src/CodeGen_X86.cpp | 5 ++++- src/runtime/x86_amx.ll | 19 ++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 7bd644177b25..17ea4f01700a 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -212,13 +212,16 @@ const x86Intrinsic intrinsic_defs[] = { {"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, {"tileloadd64_i8", UInt(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, + {"tileloadd64_bf16", BFloat(16, 512), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, {"tdpbssd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids}, {"tdpbsud", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), UInt(8, 1024)}, Target::AVX512_SapphireRapids}, {"tdpbusd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), UInt(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids}, {"tdpbuud", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), UInt(8, 1024), UInt(8, 1024)}, Target::AVX512_SapphireRapids}, + {"tdpbf16ps", Float(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Float(32, 256), BFloat(16, 512), BFloat(16, 512)}, Target::AVX512_SapphireRapids}, {"tilezero_i32", Int(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids}, // CodeGen_LLVM cannot cope with returning Type() ie void*, and return type needs to be vector to trigger call_overloaded_intrin - {"tilestored64", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, + {"tilestored64_i32", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, + {"tilestored64_f32", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Float(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, }; // clang-format on diff --git a/src/runtime/x86_amx.ll b/src/runtime/x86_amx.ll index 024518ddb98b..85c51d20a2ee 100644 --- a/src/runtime/x86_amx.ll +++ b/src/runtime/x86_amx.ll @@ -46,7 +46,17 @@ define weak_odr <256 x i32> @tdpbuud(i16 %rows, i16 %colbytes, i16 %acc, <256 x } declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) -define weak_odr <2 x i1> @tilestored64(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline writeonly { +define weak_odr <256 x float> @tdpbf16ps(i16 %rows, i16 %colbytes, i16 %acc, <256 x float> %out, <512 x bfloat> %lhs, <512 x bfloat> %rhs) nounwind alwaysinline readnone { + %1 = bitcast <512 x bfloat> %lhs to x86_amx + %2 = bitcast <512 x bfloat> %rhs to x86_amx + %3 = bitcast <256 x float> %out to x86_amx + %4 = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %rows, i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone + %5 = bitcast x86_amx %4 to <256 x float> + ret <256 x float> %5 +} +declare x86_amx @llvm.x86.tdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) + +define weak_odr <2 x i1> @tilestored64_i32(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline writeonly { %1 = getelementptr i8, i8* %ptr, i64 %off %2 = bitcast <256 x i32> %val to x86_amx tail call void @llvm.x86.tilestored64.internal(i16 %rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) nounwind writeonly @@ -54,6 +64,13 @@ define weak_odr <2 x i1> @tilestored64(i16 %rows, i16 %cols, i8* %ptr, i64 %off, } declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) +define weak_odr <2 x i1> @tilestored64_f32(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x float> %val) nounwind alwaysinline writeonly { + %1 = getelementptr i8, i8* %ptr, i64 %off + %2 = bitcast <256 x float> %val to x86_amx + tail call void @llvm.x86.tilestored64.internal(i16 %rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) nounwind writeonly + ret <2 x i1> zeroinitializer +} + ; NB: Even though this should be readnone, that will cause LLVM to try to ; generate a single zero tile, and copy it each time it is used. However the AMX ; registers cannot be copied, so this causes compilation failures: From 9f078f0b6ff2f92e7aecc7d6a4bfe5ead89f7485 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Tue, 11 May 2021 17:25:05 +0100 Subject: [PATCH 30/53] create a schedule that should map to amx --- test/performance/tiled_matmul.cpp | 65 +++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 4 deletions(-) diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index c0ddca425c65..f6b3d3052741 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -151,10 +151,67 @@ auto matmul_us = &matmul; auto matmul_su = &matmul; auto matmul_uu = &matmul; +void tiled_matmul_bf16() { + // lhs: 32x16, rhs: 16x32 + const int row = 32; + const int col = 32; + const int acc = 16; + + Var x("x"), y("y"); + ImageParam A(Float(32), 2, "lhs"); + ImageParam B(Float(32), 3, "rhs"); + + RDom r(0, acc, "racc"); + + Func mm("matmul"); + mm(x, y) = cast(0); + mm(x, y) += cast(A(r.x, y)) * B(r.x % 2, x, r.x / 2); + + int tile_x = 8; + int tile_y = 8; + int tile_r = 2; + + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + + mm.compute_at(mm.in(), x) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r.x, rro, rri, tile_r) + .reorder({rri, rxi, ryi, rro, x, y}); + + //Var ixi("ixi"), iyi("iyi"); + //mm.compute_at(mm.in(), x); + // .tile(x, y, ixi, iyi, tile_x, tile_y) + // .vectorize(ixi) + // .vectorize(iyi); + + //Var mmxi("mmxi"), mmyi("mmyi"); + //mm.in() + // .tile(x, y, mmxi, mmyi, tile_x, tile_y) + // .vectorize(mmxi) + // .vectorize(mmyi); + + + Func result = mm.in(); + result.print_loop_nest(); + + result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, get_jit_target_from_environment()); + +} + int main(int argc, char **argv) { - matmul_ss(); - matmul_us(); - matmul_su(); - matmul_uu(); + Target target = get_jit_target_from_environment(); + if (!target.has_feature(Target::AVX512_SapphireRapids)) { + std::cout << "[SKIP] The tiled matmul test is only designed to test AMX support.\n"; + return 0; + } + + //matmul_ss(); + //matmul_us(); + //matmul_su(); + //matmul_uu(); + + tiled_matmul_bf16(); return 0; } From e73702d0faf78fe5bfe7bb9419d1fd10ff0f136a Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Wed, 12 May 2021 13:18:36 +0100 Subject: [PATCH 31/53] create full amx-bf16 schedule --- test/performance/tiled_matmul.cpp | 64 +++++++++++++++---------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index f6b3d3052741..9a5337ab716a 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -49,11 +49,6 @@ bool matmul() { auto lhs = typename std::conditional::type{}; auto rhs = typename std::conditional::type{}; - Target target = get_jit_target_from_environment(); - if (!target.has_feature(Target::AVX512_SapphireRapids)) { - std::cout << "[SKIP] The tiled matmul test is only designed to test AMX support.\n"; - return true; - } const int row = 16; const int col = 16; const int acc = 16; @@ -81,32 +76,32 @@ bool matmul() { Var rxi("rxi"), ryi("ryi"); RVar rri("rri"), rro("rro"); mm.compute_at(mm.in(), y) - .store_in(MemoryType::AMXTile) + //.store_in(MemoryType::AMXTile) .update() // Split into (x,y) tile .tile(y, x, ryi, rxi, tile_y, tile_x, TailStrategy::GuardWithIf) // Split reduction dim by tile_r .split(r.x, rro, rri, tile_r) // Reorder so that the (x,y) tile is inside the inner ro loop - .reorder({rri, ryi, rxi, rro, y, x}) - .atomic() - .vectorize(rri) - .vectorize(ryi) - .vectorize(rxi); + .reorder({rri, ryi, rxi, rro, y, x}); + //.atomic() + //.vectorize(rri) + //.vectorize(ryi) + //.vectorize(rxi); // Schedule the initialization Var ixi("ixi"), iyi("iyi"); mm.compute_at(mm.in(), y) - .tile(y, x, iyi, ixi, tile_y, tile_x) - .vectorize(iyi) - .vectorize(ixi); + .tile(y, x, iyi, ixi, tile_y, tile_x); + // .vectorize(iyi) + // .vectorize(ixi); // Schedule the consumer Var mmxi("mmxi"), mmyi("mmyi"); mm.in() - .tile(y, x, mmyi, mmxi, tile_y, tile_x) - .vectorize(mmyi) - .vectorize(mmxi); + .tile(y, x, mmyi, mmxi, tile_y, tile_x); + // .vectorize(mmyi) + // .vectorize(mmxi); Buffer a_buf(acc, row); fill_buffer_a(a_buf, row, acc); @@ -151,7 +146,7 @@ auto matmul_us = &matmul; auto matmul_su = &matmul; auto matmul_uu = &matmul; -void tiled_matmul_bf16() { +void matmul_bf16() { // lhs: 32x16, rhs: 16x32 const int row = 32; const int col = 32; @@ -175,22 +170,27 @@ void tiled_matmul_bf16() { RVar rri("rri"), rro("rro"); mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) .update() .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) .split(r.x, rro, rri, tile_r) - .reorder({rri, rxi, ryi, rro, x, y}); - - //Var ixi("ixi"), iyi("iyi"); - //mm.compute_at(mm.in(), x); - // .tile(x, y, ixi, iyi, tile_x, tile_y) - // .vectorize(ixi) - // .vectorize(iyi); - - //Var mmxi("mmxi"), mmyi("mmyi"); - //mm.in() - // .tile(x, y, mmxi, mmyi, tile_x, tile_y) - // .vectorize(mmxi) - // .vectorize(mmyi); + .reorder({rri, rxi, ryi, rro, x, y}) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); Func result = mm.in(); @@ -212,6 +212,6 @@ int main(int argc, char **argv) { //matmul_su(); //matmul_uu(); - tiled_matmul_bf16(); + matmul_bf16(); return 0; } From 16217b4a22b4ac8c7cc35217fd81dc48b8f7b302 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Wed, 12 May 2021 13:18:53 +0100 Subject: [PATCH 32/53] allow amx operations to yield f32s --- src/ExtractTileOperations.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index e1eea9ccd7ce..12b2f4889769 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -248,7 +248,10 @@ class ExtractTileOperations : public IRMutator { Stmt visit(const Allocate *op) override { if (op->memory_type == MemoryType::AMXTile) { - user_assert(op->type.is_int() && op->type.bits() == 32) << "scheduled tile operations must yield 32-bit integers"; + user_assert( + (op->type.is_int() && op->type.bits() == 32) || + (op->type.is_float() && op->type.bits() == 32)) << + "scheduled tile operations must yield 32-bit integers or 32-bit floats"; user_assert(!in_allocate) << "Found two possible tile allocations for AMX allocation"; ScopedValue old_amx_name(amx_name, op->name + ".amx"); From c7226103a3075c52f445210de0812fc7d85e7103 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 13 May 2021 16:29:01 +0100 Subject: [PATCH 33/53] accept 32 bit float stores --- src/ExtractTileOperations.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 12b2f4889769..065019e4f663 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -250,8 +250,8 @@ class ExtractTileOperations : public IRMutator { if (op->memory_type == MemoryType::AMXTile) { user_assert( (op->type.is_int() && op->type.bits() == 32) || - (op->type.is_float() && op->type.bits() == 32)) << - "scheduled tile operations must yield 32-bit integers or 32-bit floats"; + (op->type.is_float() && op->type.bits() == 32)) + << "scheduled tile operations must yield 32-bit integers or 32-bit floats"; user_assert(!in_allocate) << "Found two possible tile allocations for AMX allocation"; ScopedValue old_amx_name(amx_name, op->name + ".amx"); From 66885f163cc358b27d41bc5bc128979becd81485 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Wed, 19 May 2021 15:06:37 +0100 Subject: [PATCH 34/53] add support for bf16 --- src/ExtractTileOperations.cpp | 154 +++++++++++++++++++++++++----- test/performance/tiled_matmul.cpp | 70 ++++++++++---- 2 files changed, 185 insertions(+), 39 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 065019e4f663..ebc4470ad95b 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -21,6 +21,11 @@ struct Tile { int extent[Dim]; }; +enum class AMXOpType { + Int8, + Bf16, +}; + const auto wild_i32 = Variable::make(Int(32), "*"); const auto wild_i32x = Variable::make(Int(32, 0), "*"); @@ -105,15 +110,27 @@ struct NewMatmul { int tile_r; }; -NewMatmul convert_to_matmul(const Store *op, const string &new_name) { +NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_type) { // m[ramp(0, 1, S)] = VectorAdd(lhs[{XYR tile}] * xX(rhs[{YR tile}])) + m[ramp(0, 1, S)] const auto wild_i8x = Variable::make(Int(8, 0), "*"); const auto wild_u8x = Variable::make(UInt(8, 0), "*"); + const auto wild_bf16x = Variable::make(BFloat(16, 0), "*"); + const auto wild_f32x = Variable::make(Float(32, 0), "*"); + vector matches; + if (op_type == AMXOpType::Int8) { const auto pattern1 = wild_i32x + wild_i32x; if (!expr_match(pattern1, op->value, matches)) { return {}; } + } else // AMXOpType::Bf16 + { + const auto pattern1 = wild_f32x + wild_f32x; + if (!expr_match(pattern1, op->value, matches)) { + return {}; + } + } + const auto *reduce = matches[0].as(); const auto *load = matches[1].as(); if (!reduce || reduce->op != VectorReduce::Add) { @@ -123,11 +140,11 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name) { return {}; } - // FIXME: Add support for bf16 for LLVM 13+ + bool lhs_signed = false; + if (op_type == AMXOpType::Int8) { auto pattern2 = cast(Int(32, 0), cast(Int(32, 0), wild_i8x) * wild_i32x); auto pattern2_unsigned = cast(Int(32, 0), cast(Int(32, 0), wild_u8x) * wild_i32x); - bool lhs_signed = false; if (expr_match(pattern2, reduce->value, matches)) { lhs_signed = true; } else if (expr_match(pattern2_unsigned, reduce->value, matches)) { @@ -135,6 +152,13 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name) { } else { return {}; } + } else { + auto pattern2 = cast(Float(32, 0), cast(Float(32, 0), wild_bf16x) * wild_f32x); + + if (!expr_match(pattern2, reduce->value, matches)) { + return {}; + } + } const auto *lhs_load = matches[0].as(); const auto *rhs_broadcast = matches[1].as(); @@ -143,10 +167,19 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name) { } const auto *rhs_cast = rhs_broadcast->value.as(); bool rhs_signed = false; - if (rhs_cast && rhs_cast->value.type().element_of() == Int(8)) { + if (rhs_cast) { + if (op_type == AMXOpType::Int8) { + if (rhs_cast->value.type().element_of() == Int(8)) { rhs_signed = true; - } else if (rhs_cast && rhs_cast->value.type().element_of() == UInt(8)) { + } else if (rhs_cast->value.type().element_of() == UInt(8)) { rhs_signed = false; + } else { + user_assert(false) << "Expected rhs cast of i8/u8"; + } + } else // AMXOpType::Bf16 + { + user_assert(rhs_cast->value.type().element_of() == BFloat(16)) << "Expected rhs cast of bf16"; + } } else { return {}; } @@ -174,36 +207,75 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name) { } #if LLVM_VERSION < 130 - user_assert(lhs_signed && rhs_signed) << "LLVM 13 or above is required for unsigned AMX instructions"; + user_assert(op_type != AMXOpType::Bf16 && lhs_signed && rhs_signed) << "LLVM 13 or above is required for unsigned or float AMX instructions"; #endif // {rows, colbytes, var, index} auto lhs_var = Variable::make(Handle(), lhs_load->name); - auto lhs = [&]() { + auto lhs_type = [&]() { + switch (op_type) { + case AMXOpType::Int8: if (lhs_signed) { - return Call::make(Int(8, 1024), "tile_load", {tile_x, tile_r, lhs_var, lhs_tile.base, lhs_tile.stride[0]}, Call::Intrinsic); + return Int(8, 1024); } else { - return Call::make(UInt(8, 1024), "tile_load", {tile_x, tile_r, lhs_var, lhs_tile.base, lhs_tile.stride[0]}, Call::Intrinsic); + return UInt(8, 1024); + } + case AMXOpType::Bf16: + return BFloat(16, 512); + default: + return Type(); } }(); + int element_width = 0; + switch (op_type) { + case AMXOpType::Int8: + element_width = 1; + break; + case AMXOpType::Bf16: + element_width = 2; + break; + } + auto lhs = Call::make(lhs_type, "tile_load", {tile_x, tile_r * element_width, lhs_var, lhs_tile.base, lhs_tile.stride[0]}, Call::Intrinsic); + auto rhs_var = Variable::make(Handle(), rhs_load->name); - auto rhs = [&]() { + auto rhs_type = [&]() -> Type { + switch (op_type) { + case AMXOpType::Int8: if (rhs_signed) { - return Call::make(Int(8, 1024), "tile_load", {1, tile_y * tile_r, rhs_var, rhs_tile.base, rhs_tile.stride[0]}, Call::Intrinsic); + return Int(8, 1024); } else { - return Call::make(UInt(8, 1024), "tile_load", {1, tile_y * tile_r, rhs_var, rhs_tile.base, rhs_tile.stride[0]}, Call::Intrinsic); + return UInt(8, 1024); + } + case AMXOpType::Bf16: + return BFloat(16, 512); + default: + return Type(); + } + }(); + auto rhs = Call::make(rhs_type, "tile_load", {1, tile_y * tile_r, rhs_var, rhs_tile.base, rhs_tile.stride[0]}, Call::Intrinsic); + + auto res_type = [&]() { + switch (op_type) { + case AMXOpType::Int8: + return Int(32, 256); + case AMXOpType::Bf16: + return Float(32, 256); + default: + return Type(); } }(); // {rows, colbytes, acc, out, lhs, rhs} - auto out = Load::make(Int(32, 256), new_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); + auto out = Load::make(res_type, new_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); + auto colbytes = tile_y * 32 / rhs_load->type.bits(); - auto matmul = Call::make(Int(32, 256), "tile_matmul", {tile_x, colbytes, tile_r, out, lhs, rhs}, Call::Intrinsic); + auto matmul = + Call::make(res_type, "tile_matmul", {tile_x, colbytes, tile_r, out, lhs, rhs}, Call::Intrinsic); auto store = Store::make(new_name, matmul, Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); return {true, std::move(store), tile_x, tile_y, tile_r}; } -Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_name) { +Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_name, AMXOpType op_type) { if (const auto *ramp = op->index.as()) { if (const auto *bcast = op->value.as()) { if (is_const_one(ramp->stride) && @@ -212,7 +284,17 @@ Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_ auto rows = Cast::make(Int(16), tile_x); auto bytes = op->value.type().bytes(); auto colbytes = Cast::make(Int(16), tile_y * bytes); - auto val = Call::make(Int(32, 256), "tile_zero", {rows, colbytes}, Call::Intrinsic); + auto type = [&]() { + switch (op_type) { + case AMXOpType::Int8: + return Int(32, 256); + case AMXOpType::Bf16: + return Float(32, 256); + default: + return Type(); + } + }(); + auto val = Call::make(type, "tile_zero", {rows, colbytes}, Call::Intrinsic); auto store = Store::make(new_name, val, Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); return store; } @@ -221,11 +303,21 @@ Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_ return {}; } -Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, int tile_y) { +Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, int tile_y, AMXOpType op_type) { auto tile = is_2d_tile_index(op->index); if (tile.result && tile.extent[0] == tile_x && tile.extent[1] == tile_y) { auto out = Variable::make(Handle(), op->name); - auto tile_val = Load::make(Int(32, 256), amx_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); + auto tile_type = [&]() { + switch (op_type) { + case AMXOpType::Int8: + return Int(32, 256); + case AMXOpType::Bf16: + return Float(32, 256); + default: + return Type(); + } + }(); + auto tile_val = Load::make(tile_type, amx_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); auto bytes = op->value.type().bytes(); internal_assert(bytes == 4) << "AMX store only supported for int32 and float32, not for " << op->value.type() << "\n"; // {tile_x, tile_y, var, base, stride} @@ -245,6 +337,7 @@ class ExtractTileOperations : public IRMutator { int found_tile_x = -1; int found_tile_y = -1; int found_tile_r = -1; + AMXOpType op_type; Stmt visit(const Allocate *op) override { if (op->memory_type == MemoryType::AMXTile) { @@ -253,6 +346,12 @@ class ExtractTileOperations : public IRMutator { (op->type.is_float() && op->type.bits() == 32)) << "scheduled tile operations must yield 32-bit integers or 32-bit floats"; + if (op->type.is_int() && op->type.bits() == 32) { + op_type = AMXOpType::Int8; + } else { + op_type = AMXOpType::Bf16; + } + user_assert(!in_allocate) << "Found two possible tile allocations for AMX allocation"; ScopedValue old_amx_name(amx_name, op->name + ".amx"); ScopedValue old_tile_name(tile_name, op->name); @@ -269,7 +368,18 @@ class ExtractTileOperations : public IRMutator { body = mutate(body); } - return Allocate::make(amx_name, Int(32, 256), MemoryType::AMXTile, {1}, const_true(), body); + auto alloc_type = [&]() { + switch(op_type) { + case AMXOpType::Int8: + return Int(32, 256); + case AMXOpType::Bf16: + return Float(32, 256); + default: + return Type(); + } + }(); + + return Allocate::make(amx_name, alloc_type, MemoryType::AMXTile, {1}, const_true(), body); } return IRMutator::visit(op); } @@ -303,12 +413,12 @@ class ExtractTileOperations : public IRMutator { if (!load || load->name != tile_name) { return op; } - auto store = convert_to_tile_store(op, amx_name, found_tile_x, found_tile_y); + auto store = convert_to_tile_store(op, amx_name, found_tile_x, found_tile_y, op_type); user_assert(store.defined()) << "Store to AMX tile allocation of a non-tile value"; return store; } - auto matmul = convert_to_matmul(op, amx_name); + auto matmul = convert_to_matmul(op, amx_name, op_type); if (matmul.result) { user_assert( (found_tile_x < 0 || matmul.tile_x == found_tile_x) && @@ -326,7 +436,7 @@ class ExtractTileOperations : public IRMutator { return op; } - auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_name); + auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_name, op_type); if (zero.defined()) { return zero; } diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index 9a5337ab716a..ba6338f54eeb 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -7,6 +7,27 @@ using namespace Halide; +void fill_buffer_a_bf16(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; ++iy) { + for (int ix = 0; ix < acc; ++ix) { + // value between 0 and 100 + bfloat16_t val = bfloat16_t(((float)rand() / (float)(RAND_MAX)) * 100.f); + buf(ix, iy) = val; + } + } +} + +void fill_buffer_b_bf16(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc / 2; ++iy) { + for (int ix = 0; ix < col; ++ix) { + for (int ik = 0; ik < 2; ++ik) { + bfloat16_t val = bfloat16_t(((float)rand() / (float)(RAND_MAX)) * 100.f); + buf(ik, ix, iy) = val; + } + } + } +} + struct make_uint_t { template Type operator()(Args &&...args) const { @@ -76,32 +97,32 @@ bool matmul() { Var rxi("rxi"), ryi("ryi"); RVar rri("rri"), rro("rro"); mm.compute_at(mm.in(), y) - //.store_in(MemoryType::AMXTile) + .store_in(MemoryType::AMXTile) .update() // Split into (x,y) tile .tile(y, x, ryi, rxi, tile_y, tile_x, TailStrategy::GuardWithIf) // Split reduction dim by tile_r .split(r.x, rro, rri, tile_r) // Reorder so that the (x,y) tile is inside the inner ro loop - .reorder({rri, ryi, rxi, rro, y, x}); - //.atomic() - //.vectorize(rri) - //.vectorize(ryi) - //.vectorize(rxi); + .reorder({rri, ryi, rxi, rro, y, x}) + .atomic() + .vectorize(rri) + .vectorize(ryi) + .vectorize(rxi); // Schedule the initialization Var ixi("ixi"), iyi("iyi"); mm.compute_at(mm.in(), y) - .tile(y, x, iyi, ixi, tile_y, tile_x); - // .vectorize(iyi) - // .vectorize(ixi); + .tile(y, x, iyi, ixi, tile_y, tile_x) + .vectorize(iyi) + .vectorize(ixi); // Schedule the consumer Var mmxi("mmxi"), mmyi("mmyi"); mm.in() - .tile(y, x, mmyi, mmxi, tile_y, tile_x); - // .vectorize(mmyi) - // .vectorize(mmxi); + .tile(y, x, mmyi, mmxi, tile_y, tile_x) + .vectorize(mmyi) + .vectorize(mmxi); Buffer a_buf(acc, row); fill_buffer_a(a_buf, row, acc); @@ -153,10 +174,10 @@ void matmul_bf16() { const int acc = 16; Var x("x"), y("y"); - ImageParam A(Float(32), 2, "lhs"); - ImageParam B(Float(32), 3, "rhs"); + ImageParam A(BFloat(16), 2, "lhs"); + ImageParam B(BFloat(16), 3, "rhs"); - RDom r(0, acc, "racc"); + RDom r(0, acc, "acc"); Func mm("matmul"); mm(x, y) = cast(0); @@ -186,18 +207,33 @@ void matmul_bf16() { .vectorize(ixi) .vectorize(iyi); + // schedule the consumer Var mmxi("mmxi"), mmyi("mmyi"); mm.in() .tile(x, y, mmxi, mmyi, tile_x, tile_y) .vectorize(mmxi) .vectorize(mmyi); - Func result = mm.in(); result.print_loop_nest(); - result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, get_jit_target_from_environment()); + Buffer a_buf(acc, row); + fill_buffer_a_bf16(a_buf, row, acc); + A.set(a_buf); + + Buffer b_buf(2, col, acc / 2); + fill_buffer_b_bf16(b_buf, col, acc); + B.set(b_buf); + + Buffer out(col, row); + // Uncomment to check the asm + //result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, target); + //result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); + + auto time = Tools::benchmark(20, 20, [&]() { + result.realize(out); + }); } int main(int argc, char **argv) { From 97fb022ae7914fec2abde1c92528fffb04389cb6 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Wed, 19 May 2021 15:07:23 +0100 Subject: [PATCH 35/53] add missing bf16 intrinsics --- src/CodeGen_X86.cpp | 1 + src/runtime/x86_amx.ll | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 17ea4f01700a..1d1b7312e3fc 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -219,6 +219,7 @@ const x86Intrinsic intrinsic_defs[] = { {"tdpbuud", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), UInt(8, 1024), UInt(8, 1024)}, Target::AVX512_SapphireRapids}, {"tdpbf16ps", Float(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Float(32, 256), BFloat(16, 512), BFloat(16, 512)}, Target::AVX512_SapphireRapids}, {"tilezero_i32", Int(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids}, + {"tilezero_f32", Float(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids}, // CodeGen_LLVM cannot cope with returning Type() ie void*, and return type needs to be vector to trigger call_overloaded_intrin {"tilestored64_i32", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, {"tilestored64_f32", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Float(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, diff --git a/src/runtime/x86_amx.ll b/src/runtime/x86_amx.ll index 85c51d20a2ee..e4fd4179d636 100644 --- a/src/runtime/x86_amx.ll +++ b/src/runtime/x86_amx.ll @@ -6,6 +6,13 @@ define weak_odr <1024 x i8> @tileloadd64_i8(i16 %rows, i16 %colbytes, i8* %ptr, } declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) +define weak_odr <512 x i16> @tileloadd64_bf16(i16 %rows, i16 %colbytes, i8* %ptr, i64 %off, i64 %stride) nounwind alwaysinline readonly { + %1 = getelementptr i8, i8* %ptr, i64 %off + %2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) nounwind readonly + %3 = bitcast x86_amx %2 to <512 x i16> + ret <512 x i16> %3 +} + define weak_odr <256 x i32> @tdpbssd(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind alwaysinline readnone { %1 = bitcast <1024 x i8> %lhs to x86_amx %2 = bitcast <1024 x i8> %rhs to x86_amx @@ -46,9 +53,9 @@ define weak_odr <256 x i32> @tdpbuud(i16 %rows, i16 %colbytes, i16 %acc, <256 x } declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) -define weak_odr <256 x float> @tdpbf16ps(i16 %rows, i16 %colbytes, i16 %acc, <256 x float> %out, <512 x bfloat> %lhs, <512 x bfloat> %rhs) nounwind alwaysinline readnone { - %1 = bitcast <512 x bfloat> %lhs to x86_amx - %2 = bitcast <512 x bfloat> %rhs to x86_amx +define weak_odr <256 x float> @tdpbf16ps(i16 %rows, i16 %colbytes, i16 %acc, <256 x float> %out, <512 x i16> %lhs, <512 x i16> %rhs) nounwind alwaysinline readnone { + %1 = bitcast <512 x i16> %lhs to x86_amx + %2 = bitcast <512 x i16> %rhs to x86_amx %3 = bitcast <256 x float> %out to x86_amx %4 = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %rows, i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone %5 = bitcast x86_amx %4 to <256 x float> @@ -81,4 +88,10 @@ define weak_odr <256 x i32> @tilezero_i32(i16 %rows, i16 %colbytes) nounwind alw %2 = bitcast x86_amx %1 to <256 x i32> ret <256 x i32> %2 } + +define weak_odr <256 x float> @tilezero_f32(i16 %rows, i16 %colbytes) nounwind alwaysinline { + %1 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %rows, i16 %colbytes) nounwind + %2 = bitcast x86_amx %1 to <256 x float> + ret <256 x float> %2 +} declare x86_amx @llvm.x86.tilezero.internal(i16, i16) From f6ba7391c5e6b9c103fbf3f31dda51604f4c059a Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 21 May 2021 14:40:59 +0100 Subject: [PATCH 36/53] fix striding error when loading matrix --- src/ExtractTileOperations.cpp | 142 ++++++++++------------------------ 1 file changed, 43 insertions(+), 99 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index ebc4470ad95b..78a422becc6b 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -119,10 +119,10 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o vector matches; if (op_type == AMXOpType::Int8) { - const auto pattern1 = wild_i32x + wild_i32x; - if (!expr_match(pattern1, op->value, matches)) { - return {}; - } + const auto pattern1 = wild_i32x + wild_i32x; + if (!expr_match(pattern1, op->value, matches)) { + return {}; + } } else // AMXOpType::Bf16 { const auto pattern1 = wild_f32x + wild_f32x; @@ -140,18 +140,13 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o return {}; } - bool lhs_signed = false; if (op_type == AMXOpType::Int8) { - auto pattern2 = cast(Int(32, 0), cast(Int(32, 0), wild_i8x) * wild_i32x); - auto pattern2_unsigned = cast(Int(32, 0), cast(Int(32, 0), wild_u8x) * wild_i32x); + auto pattern2 = cast(Int(32, 0), cast(Int(32, 0), wild_i8x) * wild_i32x); + auto pattern2_unsigned = cast(Int(32, 0), cast(Int(32, 0), wild_u8x) * wild_i32x); - if (expr_match(pattern2, reduce->value, matches)) { - lhs_signed = true; - } else if (expr_match(pattern2_unsigned, reduce->value, matches)) { - lhs_signed = false; - } else { - return {}; - } + if (!(expr_match(pattern2, reduce->value, matches) || expr_match(pattern2_unsigned, reduce->value, matches))) { + return {}; + } } else { auto pattern2 = cast(Float(32, 0), cast(Float(32, 0), wild_bf16x) * wild_f32x); @@ -166,14 +161,9 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o return {}; } const auto *rhs_cast = rhs_broadcast->value.as(); - bool rhs_signed = false; if (rhs_cast) { if (op_type == AMXOpType::Int8) { - if (rhs_cast->value.type().element_of() == Int(8)) { - rhs_signed = true; - } else if (rhs_cast->value.type().element_of() == UInt(8)) { - rhs_signed = false; - } else { + if (!(rhs_cast->value.type().element_of() == Int(8) || rhs_cast->value.type().element_of() == UInt(8))) { user_assert(false) << "Expected rhs cast of i8/u8"; } } else // AMXOpType::Bf16 @@ -207,52 +197,22 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o } #if LLVM_VERSION < 130 - user_assert(op_type != AMXOpType::Bf16 && lhs_signed && rhs_signed) << "LLVM 13 or above is required for unsigned or float AMX instructions"; + user_assert(op_type != AMXOpType::Bf16 && + lhs_load->type.is_int() && rhs_cast->value.type().is_int()) + << "LLVM 13 or above is required for unsigned or float AMX instructions"; #endif // {rows, colbytes, var, index} auto lhs_var = Variable::make(Handle(), lhs_load->name); - auto lhs_type = [&]() { - switch (op_type) { - case AMXOpType::Int8: - if (lhs_signed) { - return Int(8, 1024); - } else { - return UInt(8, 1024); - } - case AMXOpType::Bf16: - return BFloat(16, 512); - default: - return Type(); - } - }(); - int element_width = 0; - switch (op_type) { - case AMXOpType::Int8: - element_width = 1; - break; - case AMXOpType::Bf16: - element_width = 2; - break; - } - auto lhs = Call::make(lhs_type, "tile_load", {tile_x, tile_r * element_width, lhs_var, lhs_tile.base, lhs_tile.stride[0]}, Call::Intrinsic); + const auto &lhs_load_type = lhs_load->type; + int element_width = lhs_load_type.bytes(); + auto lhs_type = lhs_load_type.with_lanes(1024 / element_width); + auto lhs = Call::make(lhs_type, "tile_load", {tile_x, tile_r * element_width, lhs_var, lhs_tile.base * element_width, print(lhs_tile.stride[0] * element_width, " <- lhs load stride")}, Call::Intrinsic); auto rhs_var = Variable::make(Handle(), rhs_load->name); - auto rhs_type = [&]() -> Type { - switch (op_type) { - case AMXOpType::Int8: - if (rhs_signed) { - return Int(8, 1024); - } else { - return UInt(8, 1024); - } - case AMXOpType::Bf16: - return BFloat(16, 512); - default: - return Type(); - } - }(); - auto rhs = Call::make(rhs_type, "tile_load", {1, tile_y * tile_r, rhs_var, rhs_tile.base, rhs_tile.stride[0]}, Call::Intrinsic); + const auto &rhs_load_type = rhs_load->type; + auto rhs_type = rhs_load_type.with_lanes(1024 / element_width); + auto rhs = Call::make(rhs_type, "tile_load", {1, tile_y * tile_r * element_width, rhs_var, rhs_tile.base * element_width, print(rhs_tile.stride[0] * tile_y * element_width, " <- rhs load stride, ", rhs_tile.stride[1], " <- rhs stride 1")}, Call::Intrinsic); auto res_type = [&]() { switch (op_type) { @@ -275,7 +235,7 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o return {true, std::move(store), tile_x, tile_y, tile_r}; } -Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_name, AMXOpType op_type) { +Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_name) { if (const auto *ramp = op->index.as()) { if (const auto *bcast = op->value.as()) { if (is_const_one(ramp->stride) && @@ -284,18 +244,11 @@ Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_ auto rows = Cast::make(Int(16), tile_x); auto bytes = op->value.type().bytes(); auto colbytes = Cast::make(Int(16), tile_y * bytes); - auto type = [&]() { - switch (op_type) { - case AMXOpType::Int8: - return Int(32, 256); - case AMXOpType::Bf16: - return Float(32, 256); - default: - return Type(); - } - }(); - auto val = Call::make(type, "tile_zero", {rows, colbytes}, Call::Intrinsic); - auto store = Store::make(new_name, val, Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); + const auto &store_type = op->value.type(); + // will be f32 or i32 + auto tile_zero_type = store_type.with_lanes(1024 / store_type.bytes()); + auto val = Call::make(std::move(tile_zero_type), "tile_zero", {rows, colbytes}, Call::Intrinsic); + auto store = Store::make(new_name, std::move(val), Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); return store; } } @@ -303,26 +256,17 @@ Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_ return {}; } -Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, int tile_y, AMXOpType op_type) { +Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, int tile_y) { auto tile = is_2d_tile_index(op->index); if (tile.result && tile.extent[0] == tile_x && tile.extent[1] == tile_y) { auto out = Variable::make(Handle(), op->name); - auto tile_type = [&]() { - switch (op_type) { - case AMXOpType::Int8: - return Int(32, 256); - case AMXOpType::Bf16: - return Float(32, 256); - default: - return Type(); - } - }(); + auto tile_type = op->value.type().with_lanes(256); auto tile_val = Load::make(tile_type, amx_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); auto bytes = op->value.type().bytes(); - internal_assert(bytes == 4) << "AMX store only supported for int32 and float32, not for " << op->value.type() << "\n"; + internal_assert(bytes == 4) << "AMX store only supported for int32 and float32 output, not for " << op->value.type() << "\n"; // {tile_x, tile_y, var, base, stride} - auto store = Call::make(Bool(2), "tile_store", {tile_x, tile_y * bytes, out, tile.base * bytes, tile.stride[0] * bytes, tile_val}, Call::Intrinsic); - return Evaluate::make(store); + auto store = Call::make(Bool(2), "tile_store", {tile_x, tile_y * bytes, std::move(out), tile.base * bytes, tile.stride[0] * bytes, std::move(tile_val)}, Call::Intrinsic); + return Evaluate::make(std::move(store)); } return {}; } @@ -342,7 +286,7 @@ class ExtractTileOperations : public IRMutator { Stmt visit(const Allocate *op) override { if (op->memory_type == MemoryType::AMXTile) { user_assert( - (op->type.is_int() && op->type.bits() == 32) || + (op->type.is_int() && op->type.bits() == 32) || (op->type.is_float() && op->type.bits() == 32)) << "scheduled tile operations must yield 32-bit integers or 32-bit floats"; @@ -369,13 +313,13 @@ class ExtractTileOperations : public IRMutator { } auto alloc_type = [&]() { - switch(op_type) { - case AMXOpType::Int8: - return Int(32, 256); - case AMXOpType::Bf16: - return Float(32, 256); - default: - return Type(); + switch (op_type) { + case AMXOpType::Int8: + return Int(32, 256); + case AMXOpType::Bf16: + return Float(32, 256); + default: + return Type(); } }(); @@ -413,7 +357,7 @@ class ExtractTileOperations : public IRMutator { if (!load || load->name != tile_name) { return op; } - auto store = convert_to_tile_store(op, amx_name, found_tile_x, found_tile_y, op_type); + auto store = convert_to_tile_store(op, amx_name, found_tile_x, found_tile_y); user_assert(store.defined()) << "Store to AMX tile allocation of a non-tile value"; return store; } @@ -422,8 +366,8 @@ class ExtractTileOperations : public IRMutator { if (matmul.result) { user_assert( (found_tile_x < 0 || matmul.tile_x == found_tile_x) && - (found_tile_x < 0 || matmul.tile_x == found_tile_x) && - (found_tile_x < 0 || matmul.tile_x == found_tile_x)) + (found_tile_y < 0 || matmul.tile_y == found_tile_y) && + (found_tile_r < 0 || matmul.tile_r == found_tile_r)) << "Found different tile sizes for AMX tile allocation"; found_tile_x = matmul.tile_x; found_tile_y = matmul.tile_y; @@ -436,7 +380,7 @@ class ExtractTileOperations : public IRMutator { return op; } - auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_name, op_type); + auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_name); if (zero.defined()) { return zero; } From c7278b615f3edd920a5c567ba6e0a3f4c82abddd Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 21 May 2021 17:28:39 +0100 Subject: [PATCH 37/53] add checks to verify bf16 result --- test/performance/tiled_matmul.cpp | 45 ++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index ba6338f54eeb..08332be589f6 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -63,7 +63,10 @@ void fill_buffer_b(Buffer &buf, int col, int acc) { } template -bool matmul() { +bool matmul(Halide::Target target) { + // used for compiling to llvm IR or asm + (void)target; + constexpr bool lhs_signed = std::is_signed::value; constexpr bool rhs_signed = std::is_signed::value; @@ -167,7 +170,13 @@ auto matmul_us = &matmul; auto matmul_su = &matmul; auto matmul_uu = &matmul; -void matmul_bf16() { +bool equal_eps(float lhs, float rhs, float eps) { + return std::abs(lhs - rhs) < eps; +} + +bool matmul_bf16(Halide::Target target) { + (void)target; + // lhs: 32x16, rhs: 16x32 const int row = 32; const int col = 32; @@ -181,7 +190,7 @@ void matmul_bf16() { Func mm("matmul"); mm(x, y) = cast(0); - mm(x, y) += cast(A(r.x, y)) * B(r.x % 2, x, r.x / 2); + mm(x, y) += cast(cast(A(r.x, y))) * cast(B(r.x % 2, x, r.x / 2)); int tile_x = 8; int tile_y = 8; @@ -215,7 +224,7 @@ void matmul_bf16() { .vectorize(mmyi); Func result = mm.in(); - result.print_loop_nest(); + //result.print_loop_nest(); Buffer a_buf(acc, row); fill_buffer_a_bf16(a_buf, row, acc); @@ -234,6 +243,24 @@ void matmul_bf16() { auto time = Tools::benchmark(20, 20, [&]() { result.realize(out); }); + + std::cout << "Exec time: " << time << "\n"; + + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { + float val = 0.f; + for (int k = 0; k < acc; ++k) { + val += static_cast(a_buf(k, j)) * static_cast(b_buf(k % 2, i, k / 2)); + } + if (!equal_eps(val, out(i, j), 0.01f)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n"; + return false; + } + } + } + std::cout << "Success!\n"; + return true; } int main(int argc, char **argv) { @@ -243,11 +270,11 @@ int main(int argc, char **argv) { return 0; } - //matmul_ss(); - //matmul_us(); - //matmul_su(); - //matmul_uu(); + matmul_ss(target); + matmul_us(target); + matmul_su(target); + matmul_uu(target); - matmul_bf16(); + matmul_bf16(target); return 0; } From 5e81a72e8ff890621e738cc555612f6b915d4986 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 21 May 2021 17:29:01 +0100 Subject: [PATCH 38/53] fix scaling of col_bytes on matmul call --- src/ExtractTileOperations.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 78a422becc6b..5733a643a8ef 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -207,12 +207,12 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o const auto &lhs_load_type = lhs_load->type; int element_width = lhs_load_type.bytes(); auto lhs_type = lhs_load_type.with_lanes(1024 / element_width); - auto lhs = Call::make(lhs_type, "tile_load", {tile_x, tile_r * element_width, lhs_var, lhs_tile.base * element_width, print(lhs_tile.stride[0] * element_width, " <- lhs load stride")}, Call::Intrinsic); + auto lhs = Call::make(lhs_type, "tile_load", {tile_x, tile_r * element_width, lhs_var, lhs_tile.base * element_width, lhs_tile.stride[0] * element_width}, Call::Intrinsic); auto rhs_var = Variable::make(Handle(), rhs_load->name); const auto &rhs_load_type = rhs_load->type; auto rhs_type = rhs_load_type.with_lanes(1024 / element_width); - auto rhs = Call::make(rhs_type, "tile_load", {1, tile_y * tile_r * element_width, rhs_var, rhs_tile.base * element_width, print(rhs_tile.stride[0] * tile_y * element_width, " <- rhs load stride, ", rhs_tile.stride[1], " <- rhs stride 1")}, Call::Intrinsic); + auto rhs = Call::make(rhs_type, "tile_load", {1, tile_y * tile_r * element_width, rhs_var, rhs_tile.base * element_width, rhs_tile.stride[0] * tile_y * element_width}, Call::Intrinsic); auto res_type = [&]() { switch (op_type) { @@ -228,7 +228,8 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o // {rows, colbytes, acc, out, lhs, rhs} auto out = Load::make(res_type, new_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); - auto colbytes = tile_y * 32 / rhs_load->type.bits(); + // 4 bytes for i32, f32 + auto colbytes = tile_y * 4; auto matmul = Call::make(res_type, "tile_matmul", {tile_x, colbytes, tile_r, out, lhs, rhs}, Call::Intrinsic); auto store = Store::make(new_name, matmul, Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); From ea74fe23ff99095cc5f6cba89f33e745537650c8 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 8 Jul 2021 12:30:06 +0100 Subject: [PATCH 39/53] move brace to previous line --- src/ExtractTileOperations.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 5733a643a8ef..99fa4ffeef79 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -123,8 +123,7 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o if (!expr_match(pattern1, op->value, matches)) { return {}; } - } else // AMXOpType::Bf16 - { + } else { // AMXOpType::Bf16 const auto pattern1 = wild_f32x + wild_f32x; if (!expr_match(pattern1, op->value, matches)) { return {}; From a854dc988e4d9a136c4af6a7213d6f3ff0c537f6 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 9 Jul 2021 16:30:53 +0100 Subject: [PATCH 40/53] derive result type using a function rather than lambda --- src/ExtractTileOperations.cpp | 37 ++++++++++++++--------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 99fa4ffeef79..c951390c711e 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -26,6 +26,18 @@ enum class AMXOpType { Bf16, }; +/// returns the appropriate `Halide::Type` for the given operation type +Type amx_op_type_result_type(AMXOpType op_ty) { + switch (op_ty) { + case AMXOpType::Int8: + return Int(32, 256); + case AMXOpType::Bf16: + return Float(32, 256); + default: + return Type(); + } +} + const auto wild_i32 = Variable::make(Int(32), "*"); const auto wild_i32x = Variable::make(Int(32, 0), "*"); @@ -123,7 +135,7 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o if (!expr_match(pattern1, op->value, matches)) { return {}; } - } else { // AMXOpType::Bf16 + } else { // AMXOpType::Bf16 const auto pattern1 = wild_f32x + wild_f32x; if (!expr_match(pattern1, op->value, matches)) { return {}; @@ -212,17 +224,7 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o const auto &rhs_load_type = rhs_load->type; auto rhs_type = rhs_load_type.with_lanes(1024 / element_width); auto rhs = Call::make(rhs_type, "tile_load", {1, tile_y * tile_r * element_width, rhs_var, rhs_tile.base * element_width, rhs_tile.stride[0] * tile_y * element_width}, Call::Intrinsic); - - auto res_type = [&]() { - switch (op_type) { - case AMXOpType::Int8: - return Int(32, 256); - case AMXOpType::Bf16: - return Float(32, 256); - default: - return Type(); - } - }(); + auto res_type = amx_op_type_result_type(op_type); // {rows, colbytes, acc, out, lhs, rhs} auto out = Load::make(res_type, new_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); @@ -312,16 +314,7 @@ class ExtractTileOperations : public IRMutator { body = mutate(body); } - auto alloc_type = [&]() { - switch (op_type) { - case AMXOpType::Int8: - return Int(32, 256); - case AMXOpType::Bf16: - return Float(32, 256); - default: - return Type(); - } - }(); + auto alloc_type = amx_op_type_result_type(op_type); return Allocate::make(amx_name, alloc_type, MemoryType::AMXTile, {1}, const_true(), body); } From 26014d220c45de787fb3fb016c8d694119ad7f6a Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 9 Jul 2021 17:01:00 +0100 Subject: [PATCH 41/53] run clang tidy and format --- src/ExtractTileOperations.cpp | 2 +- test/performance/tiled_matmul.cpp | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index c951390c711e..da2e7effb6b8 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -249,7 +249,7 @@ Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_ const auto &store_type = op->value.type(); // will be f32 or i32 auto tile_zero_type = store_type.with_lanes(1024 / store_type.bytes()); - auto val = Call::make(std::move(tile_zero_type), "tile_zero", {rows, colbytes}, Call::Intrinsic); + auto val = Call::make(tile_zero_type, "tile_zero", {rows, colbytes}, Call::Intrinsic); auto store = Store::make(new_name, std::move(val), Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); return store; } diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index 08332be589f6..d71a010502b1 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -209,7 +209,7 @@ bool matmul_bf16(Halide::Target target) { .vectorize(rri) .vectorize(rxi) .vectorize(ryi); - + Var ixi("ixi"), iyi("iyi"); mm.compute_at(mm.in(), x) .tile(x, y, ixi, iyi, tile_x, tile_y) @@ -222,7 +222,7 @@ bool matmul_bf16(Halide::Target target) { .tile(x, y, mmxi, mmyi, tile_x, tile_y) .vectorize(mmxi) .vectorize(mmyi); - + Func result = mm.in(); //result.print_loop_nest(); @@ -274,7 +274,7 @@ int main(int argc, char **argv) { matmul_us(target); matmul_su(target); matmul_uu(target); - + matmul_bf16(target); return 0; } From 34557cb48722fadeb03a3febbfbd61883d902144 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 2 Sep 2021 12:09:09 +0100 Subject: [PATCH 42/53] have tile_store return i32 --- src/CodeGen_LLVM.cpp | 9 ++++----- src/CodeGen_X86.cpp | 5 ++--- src/ExtractTileOperations.cpp | 2 +- src/runtime/x86_amx.ll | 8 ++++---- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 915238f23a40..537e7156cef5 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -2529,12 +2529,11 @@ void CodeGen_LLVM::visit(const Call *op) { internal_assert(op->is_extern() || op->is_intrinsic()) << "Can only codegen extern calls and intrinsics\n"; - if (op->type.is_vector()) { - value = call_overloaded_intrin(op->type, op->name, op->args); - if (value) { - return; - } + value = call_overloaded_intrin(op->type, op->name, op->args); + if (value) { + return; } + // Some call nodes are actually injected at various stages as a // cue for llvm to generate particular ops. In general these are diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 1d1b7312e3fc..99b6f2eb98fc 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -220,9 +220,8 @@ const x86Intrinsic intrinsic_defs[] = { {"tdpbf16ps", Float(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Float(32, 256), BFloat(16, 512), BFloat(16, 512)}, Target::AVX512_SapphireRapids}, {"tilezero_i32", Int(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids}, {"tilezero_f32", Float(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids}, - // CodeGen_LLVM cannot cope with returning Type() ie void*, and return type needs to be vector to trigger call_overloaded_intrin - {"tilestored64_i32", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, - {"tilestored64_f32", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Float(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, + {"tilestored64_i32", Int(32), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, + {"tilestored64_f32", Int(32), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Float(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, }; // clang-format on diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index da2e7effb6b8..b8916b9982c5 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -267,7 +267,7 @@ Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, auto bytes = op->value.type().bytes(); internal_assert(bytes == 4) << "AMX store only supported for int32 and float32 output, not for " << op->value.type() << "\n"; // {tile_x, tile_y, var, base, stride} - auto store = Call::make(Bool(2), "tile_store", {tile_x, tile_y * bytes, std::move(out), tile.base * bytes, tile.stride[0] * bytes, std::move(tile_val)}, Call::Intrinsic); + auto store = Call::make(Int(32), "tile_store", {tile_x, tile_y * bytes, std::move(out), tile.base * bytes, tile.stride[0] * bytes, std::move(tile_val)}, Call::Intrinsic); return Evaluate::make(std::move(store)); } return {}; diff --git a/src/runtime/x86_amx.ll b/src/runtime/x86_amx.ll index e4fd4179d636..6c7f4659f0c8 100644 --- a/src/runtime/x86_amx.ll +++ b/src/runtime/x86_amx.ll @@ -63,19 +63,19 @@ define weak_odr <256 x float> @tdpbf16ps(i16 %rows, i16 %colbytes, i16 %acc, <25 } declare x86_amx @llvm.x86.tdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) -define weak_odr <2 x i1> @tilestored64_i32(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline writeonly { +define weak_odr i32 @tilestored64_i32(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline writeonly { %1 = getelementptr i8, i8* %ptr, i64 %off %2 = bitcast <256 x i32> %val to x86_amx tail call void @llvm.x86.tilestored64.internal(i16 %rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) nounwind writeonly - ret <2 x i1> zeroinitializer + ret i32 zeroinitializer ; return 0 since Halide has no void return value } declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) -define weak_odr <2 x i1> @tilestored64_f32(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x float> %val) nounwind alwaysinline writeonly { +define weak_odr i32 @tilestored64_f32(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x float> %val) nounwind alwaysinline writeonly { %1 = getelementptr i8, i8* %ptr, i64 %off %2 = bitcast <256 x float> %val to x86_amx tail call void @llvm.x86.tilestored64.internal(i16 %rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) nounwind writeonly - ret <2 x i1> zeroinitializer + ret i32 zeroinitializer } ; NB: Even though this should be readnone, that will cause LLVM to try to From 5ad06e0907d39d3cc1073ec6e4124e2342f1f4c2 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Wed, 8 Sep 2021 13:34:29 +0100 Subject: [PATCH 43/53] make is_3d_tile_index robust to indexing changes --- src/ExtractTileOperations.cpp | 44 ++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index b8916b9982c5..27f6bd600c52 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -57,15 +57,28 @@ Tile<2> is_2d_tile_index(const Expr &e) { Tile<3> is_3d_tile_index(const Expr &e) { vector matches; - auto add_sub_pattern = (wild_i32x + wild_i32x) - wild_i32x; - if (!expr_match(add_sub_pattern, e, matches)) { + + // there could be a sub node + const Sub* sub = e.as(); + + const Add* add = nullptr; + + if (sub) { + add = sub->a.as(); + } + else { + add = e.as(); + } + + if (!add) { return {}; } - // ramp(x16(base), x16(stride), 4) + x16(ramp(idx, 1, 4)) y: 4, x: 4, r: 4 - // ramp(x10(base), x10(stride), 3) + x6(ramp(idx, 1, 5)) y: 2, x: 3, r: 5 - Expr first = std::move(matches[0]); - Expr second = std::move(matches[1]); - Expr adj = std::move(matches[2]); + + auto& first = add->a; + auto& second = add->b; + + // ramp(x[x*r](base), x[x*r](stride), x) + x[x*y](ramp(idx, 1, r)) + const auto *r1 = first.as(); const auto *b2 = second.as(); if (!r1 && !b2) { @@ -105,11 +118,20 @@ Tile<3> is_3d_tile_index(const Expr &e) { base += std::move(matches[0]); Expr r_stride = std::move(matches[1]); - auto pattern3 = Broadcast::make(wild_i32, b1->lanes * r1->lanes); - if (!expr_match(pattern3, adj, matches)) { - return {}; + if (sub) { + Expr adj = sub->b; + const Broadcast* bcast = adj.as(); + + if (!bcast) { + return {}; + } + + if (bcast->lanes != b1->lanes * r1->lanes) { + return {}; + } + + base -= bcast->value; } - base -= std::move(matches[0]); return {true, base, {x_stride, 0, r_stride}, {x_tile, y_tile, r_tile}}; } From 7cab15523dfda64751bea4afe647491a5c586aeb Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Wed, 8 Sep 2021 17:11:01 +0100 Subject: [PATCH 44/53] apply formatting suggestions --- src/CodeGen_LLVM.cpp | 1 - src/ExtractTileOperations.cpp | 13 ++++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 46492e6e105f..af2950e0f6b0 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -2568,7 +2568,6 @@ void CodeGen_LLVM::visit(const Call *op) { if (value) { return; } - // Some call nodes are actually injected at various stages as a // cue for llvm to generate particular ops. In general these are diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 27f6bd600c52..c236292c6bfa 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -59,14 +59,13 @@ Tile<3> is_3d_tile_index(const Expr &e) { vector matches; // there could be a sub node - const Sub* sub = e.as(); + const Sub *sub = e.as(); - const Add* add = nullptr; + const Add *add = nullptr; if (sub) { add = sub->a.as(); - } - else { + } else { add = e.as(); } @@ -74,8 +73,8 @@ Tile<3> is_3d_tile_index(const Expr &e) { return {}; } - auto& first = add->a; - auto& second = add->b; + auto &first = add->a; + auto &second = add->b; // ramp(x[x*r](base), x[x*r](stride), x) + x[x*y](ramp(idx, 1, r)) @@ -120,7 +119,7 @@ Tile<3> is_3d_tile_index(const Expr &e) { if (sub) { Expr adj = sub->b; - const Broadcast* bcast = adj.as(); + const Broadcast *bcast = adj.as(); if (!bcast) { return {}; From 8f835441c278d43c6a9d0e2d7c7f43d1a78aa498 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 9 Sep 2021 10:54:26 +0100 Subject: [PATCH 45/53] both first and second can be const qualified --- src/ExtractTileOperations.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index c236292c6bfa..97d240566008 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -73,8 +73,8 @@ Tile<3> is_3d_tile_index(const Expr &e) { return {}; } - auto &first = add->a; - auto &second = add->b; + const auto &first = add->a; + const auto &second = add->b; // ramp(x[x*r](base), x[x*r](stride), x) + x[x*y](ramp(idx, 1, r)) From b1e1452409ca93b04e90a33681f396a49ec4face Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 9 Sep 2021 13:07:37 +0100 Subject: [PATCH 46/53] remove trailing whitespace in unformatted section --- src/CodeGen_X86.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 99b6f2eb98fc..32c8abbdacbc 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -200,7 +200,7 @@ const x86Intrinsic intrinsic_defs[] = { {"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids}, {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids}, - {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids}, + {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids}, {"dpbusdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_SapphireRapids}, {"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_SapphireRapids}, From 14df0bcfb9ee406c6a7e1f6fe082bcc170c2a846 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Tue, 21 Sep 2021 13:15:27 +0100 Subject: [PATCH 47/53] make requested style changes --- src/ExtractTileOperations.cpp | 40 ++++++++++++++----------------- test/performance/tiled_matmul.cpp | 1 - 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 97d240566008..4bbbc6a53595 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -1,9 +1,9 @@ #include "ExtractTileOperations.h" -#include "IRMatch.h" // expr_match +#include "IRMatch.h" #include "IRMutator.h" -#include "IROperator.h" // Expr + Expr -#include "Util.h" // ScopedValue +#include "IROperator.h" +#include "Util.h" namespace Halide { namespace Internal { @@ -23,7 +23,7 @@ struct Tile { enum class AMXOpType { Int8, - Bf16, + Bfloat16, }; /// returns the appropriate `Halide::Type` for the given operation type @@ -31,17 +31,17 @@ Type amx_op_type_result_type(AMXOpType op_ty) { switch (op_ty) { case AMXOpType::Int8: return Int(32, 256); - case AMXOpType::Bf16: + case AMXOpType::Bfloat16: return Float(32, 256); default: - return Type(); + internal_error << "Unexpected"; } } const auto wild_i32 = Variable::make(Int(32), "*"); const auto wild_i32x = Variable::make(Int(32, 0), "*"); -Tile<2> is_2d_tile_index(const Expr &e) { +Tile<2> get_2d_tile_index(const Expr &e) { // ramp(ramp(base, 1, 4), x4(stride), 4) vector matches; if (const auto *r1 = e.as()) { @@ -55,12 +55,11 @@ Tile<2> is_2d_tile_index(const Expr &e) { return {}; } -Tile<3> is_3d_tile_index(const Expr &e) { +Tile<3> get_3d_tile_index(const Expr &e) { vector matches; // there could be a sub node const Sub *sub = e.as(); - const Add *add = nullptr; if (sub) { @@ -156,7 +155,7 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o if (!expr_match(pattern1, op->value, matches)) { return {}; } - } else { // AMXOpType::Bf16 + } else { // AMXOpType::Bfloat16 const auto pattern1 = wild_f32x + wild_f32x; if (!expr_match(pattern1, op->value, matches)) { return {}; @@ -198,8 +197,7 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o if (!(rhs_cast->value.type().element_of() == Int(8) || rhs_cast->value.type().element_of() == UInt(8))) { user_assert(false) << "Expected rhs cast of i8/u8"; } - } else // AMXOpType::Bf16 - { + } else { // AMXOpType::Bfloat16 user_assert(rhs_cast->value.type().element_of() == BFloat(16)) << "Expected rhs cast of bf16"; } } else { @@ -211,8 +209,8 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o return {}; } - const auto lhs_tile = is_3d_tile_index(lhs_load->index); - const auto rhs_tile = is_2d_tile_index(rhs_load->index); + const auto lhs_tile = get_3d_tile_index(lhs_load->index); + const auto rhs_tile = get_2d_tile_index(rhs_load->index); if (!lhs_tile.result || !rhs_tile.result) { return {}; } @@ -229,7 +227,7 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o } #if LLVM_VERSION < 130 - user_assert(op_type != AMXOpType::Bf16 && + user_assert(op_type != AMXOpType::Bfloat16 && lhs_load->type.is_int() && rhs_cast->value.type().is_int()) << "LLVM 13 or above is required for unsigned or float AMX instructions"; #endif @@ -252,8 +250,7 @@ NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType o // 4 bytes for i32, f32 auto colbytes = tile_y * 4; - auto matmul = - Call::make(res_type, "tile_matmul", {tile_x, colbytes, tile_r, out, lhs, rhs}, Call::Intrinsic); + auto matmul = Call::make(res_type, "tile_matmul", {tile_x, colbytes, tile_r, out, lhs, rhs}, Call::Intrinsic); auto store = Store::make(new_name, matmul, Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); return {true, std::move(store), tile_x, tile_y, tile_r}; } @@ -280,7 +277,7 @@ Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_ } Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, int tile_y) { - auto tile = is_2d_tile_index(op->index); + auto tile = get_2d_tile_index(op->index); if (tile.result && tile.extent[0] == tile_x && tile.extent[1] == tile_y) { auto out = Variable::make(Handle(), op->name); auto tile_type = op->value.type().with_lanes(256); @@ -316,10 +313,10 @@ class ExtractTileOperations : public IRMutator { if (op->type.is_int() && op->type.bits() == 32) { op_type = AMXOpType::Int8; } else { - op_type = AMXOpType::Bf16; + op_type = AMXOpType::Bfloat16; } - user_assert(!in_allocate) << "Found two possible tile allocations for AMX allocation"; + user_assert(!in_allocate) << "Already in AMX allocation: " << amx_name; ScopedValue old_amx_name(amx_name, op->name + ".amx"); ScopedValue old_tile_name(tile_name, op->name); ScopedValue old_in_alloc(in_allocate, true); @@ -336,7 +333,6 @@ class ExtractTileOperations : public IRMutator { } auto alloc_type = amx_op_type_result_type(op_type); - return Allocate::make(amx_name, alloc_type, MemoryType::AMXTile, {1}, const_true(), body); } return IRMutator::visit(op); @@ -355,7 +351,7 @@ class ExtractTileOperations : public IRMutator { } auto body = mutate(op->body); - return ProducerConsumer::make(amx_name, op->is_producer, body); + return ProducerConsumer::make(amx_name, op->is_producer, std::move(body)); } Expr visit(const Load *op) override { diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index d71a010502b1..b8d51954ee83 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -224,7 +224,6 @@ bool matmul_bf16(Halide::Target target) { .vectorize(mmyi); Func result = mm.in(); - //result.print_loop_nest(); Buffer a_buf(acc, row); fill_buffer_a_bf16(a_buf, row, acc); From 8b63d77f0ea6daf186dabed7972697f91a0bc980 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 1 Oct 2021 14:31:25 +0100 Subject: [PATCH 48/53] rename NewMatmul -> Matmul --- src/ExtractTileOperations.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 4bbbc6a53595..283ed050d2cb 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -134,7 +134,7 @@ Tile<3> get_3d_tile_index(const Expr &e) { return {true, base, {x_stride, 0, r_stride}, {x_tile, y_tile, r_tile}}; } -struct NewMatmul { +struct Matmul { bool result = false; Stmt stmt; int tile_x; @@ -142,7 +142,7 @@ struct NewMatmul { int tile_r; }; -NewMatmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_type) { +Matmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_type) { // m[ramp(0, 1, S)] = VectorAdd(lhs[{XYR tile}] * xX(rhs[{YR tile}])) + m[ramp(0, 1, S)] const auto wild_i8x = Variable::make(Int(8, 0), "*"); const auto wild_u8x = Variable::make(UInt(8, 0), "*"); From 6a5eeaa80a2e2216448019e29d67d690d20ed96e Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 1 Oct 2021 14:32:04 +0100 Subject: [PATCH 49/53] fix warning about missing return value --- src/ExtractTileOperations.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 283ed050d2cb..493978e086eb 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -35,6 +35,7 @@ Type amx_op_type_result_type(AMXOpType op_ty) { return Float(32, 256); default: internal_error << "Unexpected"; + return Type(); } } From cc7c97da22fd596fa242ae51ace23354ce177599 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 1 Oct 2021 14:42:34 +0100 Subject: [PATCH 50/53] use get_1d_tile_index to handle special case When using `Buffer` instead of `ImageParam` the `Ramp` expression generated is 1D instead of 2D, therefore we recognize this with a special case. The lanes are still matched against the dimensions of the LHS 3d tile lanes. --- src/ExtractTileOperations.cpp | 45 ++++++++++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 493978e086eb..08df3ff7e39f 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -42,6 +42,14 @@ Type amx_op_type_result_type(AMXOpType op_ty) { const auto wild_i32 = Variable::make(Int(32), "*"); const auto wild_i32x = Variable::make(Int(32, 0), "*"); +Tile<1> get_1d_tile_index(const Expr &e) { + if (const auto *r1 = e.as()) { + return {true, r1->base, {r1->stride}, {r1->lanes}}; + } + + return {}; +} + Tile<2> get_2d_tile_index(const Expr &e) { // ramp(ramp(base, 1, 4), x4(stride), 4) vector matches; @@ -211,8 +219,8 @@ Matmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_t } const auto lhs_tile = get_3d_tile_index(lhs_load->index); - const auto rhs_tile = get_2d_tile_index(rhs_load->index); - if (!lhs_tile.result || !rhs_tile.result) { + + if (!lhs_tile.result) { return {}; } @@ -220,10 +228,35 @@ Matmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_t const int tile_y = lhs_tile.extent[1]; const int tile_r = lhs_tile.extent[2]; const int factor = reduce->value.type().lanes() / reduce->type.lanes(); + + Expr rhs_base; + Expr rhs_stride; + + const auto rhs_tile2 = get_2d_tile_index(rhs_load->index); + if (!rhs_tile2.result) { + const auto rhs_tile1 = get_1d_tile_index(rhs_load->index); + + if (!rhs_tile1.result) { + return {}; + } + + if (rhs_tile1.extent[0] != tile_y * tile_r) { + return {}; + } + + rhs_base = rhs_tile1.base; + rhs_stride = rhs_tile1.stride[0]; + } else { + if (tile_y != rhs_tile2.extent[0] || tile_r != rhs_tile2.extent[1]) { + return {}; + } + + rhs_base = rhs_tile2.base; + rhs_stride = rhs_tile2.stride[0]; + } + if (op->index.type().lanes() != tile_x * tile_y || - factor != tile_r || - tile_y != rhs_tile.extent[0] || - tile_r != rhs_tile.extent[1]) { + factor != tile_r) { return {}; } @@ -243,7 +276,7 @@ Matmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_t auto rhs_var = Variable::make(Handle(), rhs_load->name); const auto &rhs_load_type = rhs_load->type; auto rhs_type = rhs_load_type.with_lanes(1024 / element_width); - auto rhs = Call::make(rhs_type, "tile_load", {1, tile_y * tile_r * element_width, rhs_var, rhs_tile.base * element_width, rhs_tile.stride[0] * tile_y * element_width}, Call::Intrinsic); + auto rhs = Call::make(rhs_type, "tile_load", {1, tile_y * tile_r * element_width, rhs_var, rhs_base * element_width, rhs_stride * tile_y * element_width}, Call::Intrinsic); auto res_type = amx_op_type_result_type(op_type); // {rows, colbytes, acc, out, lhs, rhs} From 014f0c648abb1fc8647245ff47285801537ce49d Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 1 Oct 2021 15:33:35 +0100 Subject: [PATCH 51/53] add correctness test for AMX instructions --- test/correctness/CMakeLists.txt | 1 + test/correctness/tiled_matmul.cpp | 262 ++++++++++++++++++++++++++++++ 2 files changed, 263 insertions(+) create mode 100644 test/correctness/tiled_matmul.cpp diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 370fe711c663..743ece0565ff 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -309,6 +309,7 @@ tests(GROUPS correctness strided_load.cpp target.cpp thread_safety.cpp + tiled_matmul.cpp tracing.cpp tracing_bounds.cpp tracing_broadcast.cpp diff --git a/test/correctness/tiled_matmul.cpp b/test/correctness/tiled_matmul.cpp new file mode 100644 index 000000000000..22b8cf212cd1 --- /dev/null +++ b/test/correctness/tiled_matmul.cpp @@ -0,0 +1,262 @@ +#include "Halide.h" +#include + +using namespace Halide; + +void fill_buffer_a_bf16(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; ++iy) { + for (int ix = 0; ix < acc; ++ix) { + // value between 0 and 100 + bfloat16_t val = bfloat16_t(((float)rand() / (float)(RAND_MAX)) * 100.f); + buf(ix, iy) = val; + } + } +} + +void fill_buffer_b_bf16(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc / 2; ++iy) { + for (int ix = 0; ix < col; ++ix) { + for (int ik = 0; ik < 2; ++ik) { + bfloat16_t val = bfloat16_t(((float)rand() / (float)(RAND_MAX)) * 100.f); + buf(ik, ix, iy) = val; + } + } + } +} + +template +void fill_buffer_a(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; iy++) { + for (int ix = 0; ix < acc; ix++) { + buf(ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } +} + +template +void fill_buffer_b(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc / 4; iy++) { + for (int ix = 0; ix < col; ix++) { + for (int ik = 0; ik < 4; ++ik) { + buf(ik, ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } + } +} + +bool equal_eps(float lhs, float rhs, float eps) { + return std::abs(lhs - rhs) < eps; +} + +struct make_uint_t { + template + Type operator()(Args &&...args) const { + return UInt(static_cast(args)...); + } +}; + +struct make_int_t { + template + Type operator()(Args &&...args) const { + return Int(static_cast(args)...); + } +}; + +template +bool matmul() { + constexpr bool lhs_signed = std::is_signed::value; + constexpr bool rhs_signed = std::is_signed::value; + + auto lhs = typename std::conditional::type{}; + auto rhs = typename std::conditional::type{}; + + constexpr int row = 16; + constexpr int col = 16; + constexpr int acc = 16; + + Buffer A_buf(acc, row); + Buffer B_buf(4, col, acc / 4); + + Var x("x"), y("y"); + RDom r(0, acc); + + Func mm("matmul"); + mm(x, y) = cast(0); + mm(x, y) += cast(A_buf(r, y)) * cast(B_buf(r % 4, x, r / 4)); + + constexpr int tile_x = 8; + constexpr int tile_y = 8; + constexpr int tile_r = 4; + + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + + mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r, rro, rri, tile_r) + .reorder(rri, rxi, ryi, rro, x, y) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + // schedule the consumer + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); + + Func result = mm.in(); + + fill_buffer_a(A_buf, row, acc); + fill_buffer_b(B_buf, col, acc); + + Buffer out(col, row); + + result.realize(out); + + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { + int32_t val = 0; + for (int k = 0; k < acc; ++k) { + val += static_cast(A_buf(k, j)) * static_cast(B_buf(k % 4, i, k / 4)); + } + if (val != out(i, j)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n"; + return false; + } + } + } + + return true; +} + +bool matmul_bf16() { + // lhs: 32x16, rhs: 16x32 + const int row = 32; + const int col = 32; + const int acc = 16; + + Var x("x"), y("y"); + Buffer A(acc, row); + Buffer B(2, col, acc / 2); + + RDom r(0, acc, "acc"); + + Func mm("matmul"); + mm(x, y) = cast(0); + mm(x, y) += cast(cast(A(r.x, y))) * cast(B(r.x % 2, x, r.x / 2)); + + int tile_x = 8; + int tile_y = 8; + int tile_r = 2; + + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + + mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r.x, rro, rri, tile_r) + .reorder({rri, rxi, ryi, rro, x, y}) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + // schedule the consumer + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); + + Func result = mm.in(); + + fill_buffer_a_bf16(A, row, acc); + fill_buffer_b_bf16(B, col, acc); + + Buffer out(col, row); + + // Uncomment to check the asm + //result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, target); + //result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); + + result.realize(out); + + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { + float val = 0.f; + for (int k = 0; k < acc; ++k) { + val += static_cast(A(k, j)) * static_cast(B(k % 2, i, k / 2)); + } + if (!equal_eps(val, out(i, j), 0.01f)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n"; + return false; + } + } + } + + return true; +} + +auto matmul_ss = &matmul; +auto matmul_us = &matmul; +auto matmul_su = &matmul; +auto matmul_uu = &matmul; + +int main(int argc, char **argv) { + Target t = get_jit_target_from_environment(); + if (!t.has_feature(Target::AVX512_SapphireRapids)) { + printf("[SKIP] No AMX support\n"); + return 0; + } + + printf("Running AMX matmul (signed/signed)\n"); + if (!matmul_ss()) { + return -1; + } + + // llvm >= 13.0 is required for unsigned and float AMX instructions + if (Halide::Internal::get_llvm_version() >= 130) { + printf("Running AMX matmul (signed/unsigned)\n"); + if (!matmul_su()) { + return -1; + } + + printf("Running AMX matmul (unsigned/signed)\n"); + if (!matmul_us()) { + return -1; + } + + printf("Running AMX matmul (unsigned/unsigned)\n"); + if (!matmul_uu()) { + return -1; + } + + printf("Running AMX matmul (bf16)\n"); + if (!matmul_bf16()) { + return -1; + } + } + printf("Success!\n"); + return 0; +} \ No newline at end of file From 655dbdfc3250c0ec4e435b0039ffd4a29129e617 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 1 Oct 2021 16:03:08 +0100 Subject: [PATCH 52/53] correctness part has been separated out --- test/performance/tiled_matmul.cpp | 33 +++++-------------------------- 1 file changed, 5 insertions(+), 28 deletions(-) diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index b8d51954ee83..2fd90683bd38 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -147,20 +147,6 @@ bool matmul(Halide::Target target) { result.realize(out); }); std::cout << "Exec time: " << time << "\n"; - - for (int j = 0; j < row; ++j) { - for (int i = 0; i < col; ++i) { - int32_t val = 0; - for (int k = 0; k < acc; ++k) { - val += a_buf(k, j) * b_buf(k % 4, i, k / 4); - } - if (val != out(i, j)) { - std::cerr << "Invalid result at " << i << ", " << j << "\n" - << out(i, j) << " != " << val << "\n"; - return false; - } - } - } std::cout << "Success!\n"; return true; } @@ -244,20 +230,6 @@ bool matmul_bf16(Halide::Target target) { }); std::cout << "Exec time: " << time << "\n"; - - for (int j = 0; j < row; ++j) { - for (int i = 0; i < col; ++i) { - float val = 0.f; - for (int k = 0; k < acc; ++k) { - val += static_cast(a_buf(k, j)) * static_cast(b_buf(k % 2, i, k / 2)); - } - if (!equal_eps(val, out(i, j), 0.01f)) { - std::cerr << "Invalid result at " << i << ", " << j << "\n" - << out(i, j) << " != " << val << "\n"; - return false; - } - } - } std::cout << "Success!\n"; return true; } @@ -269,11 +241,16 @@ int main(int argc, char **argv) { return 0; } + printf("Running AMX (signed/signed)\n"); matmul_ss(target); + printf("Running AMX (unsigned/signed)\n"); matmul_us(target); + printf("Running AMX (signed/unsigned)\n"); matmul_su(target); + printf("Running AMX (unsigned/unsigned)\n"); matmul_uu(target); + printf("Running AMX (bf16)\n"); matmul_bf16(target); return 0; } From abc660be1aa34d87863543f706399e4d82967693 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Wed, 6 Oct 2021 13:15:57 +0100 Subject: [PATCH 53/53] remove unused variables --- test/correctness/tiled_matmul.cpp | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/test/correctness/tiled_matmul.cpp b/test/correctness/tiled_matmul.cpp index 22b8cf212cd1..7fbeedef3ecc 100644 --- a/test/correctness/tiled_matmul.cpp +++ b/test/correctness/tiled_matmul.cpp @@ -64,18 +64,12 @@ struct make_int_t { template bool matmul() { - constexpr bool lhs_signed = std::is_signed::value; - constexpr bool rhs_signed = std::is_signed::value; - - auto lhs = typename std::conditional::type{}; - auto rhs = typename std::conditional::type{}; - constexpr int row = 16; constexpr int col = 16; constexpr int acc = 16; - Buffer A_buf(acc, row); - Buffer B_buf(4, col, acc / 4); + Buffer A_buf(acc, row); + Buffer B_buf(4, col, acc / 4); Var x("x"), y("y"); RDom r(0, acc); @@ -226,13 +220,15 @@ auto matmul_uu = &matmul; int main(int argc, char **argv) { Target t = get_jit_target_from_environment(); if (!t.has_feature(Target::AVX512_SapphireRapids)) { - printf("[SKIP] No AMX support\n"); + printf("[SKIP] No AMX target enabled\n"); return 0; } printf("Running AMX matmul (signed/signed)\n"); if (!matmul_ss()) { return -1; + } else { + printf("Success!\n"); } // llvm >= 13.0 is required for unsigned and float AMX instructions @@ -240,23 +236,30 @@ int main(int argc, char **argv) { printf("Running AMX matmul (signed/unsigned)\n"); if (!matmul_su()) { return -1; + } else { + printf("Success!\n"); } printf("Running AMX matmul (unsigned/signed)\n"); if (!matmul_us()) { return -1; + } else { + printf("Success!\n"); } printf("Running AMX matmul (unsigned/unsigned)\n"); if (!matmul_uu()) { return -1; + } else { + printf("Success!\n"); } printf("Running AMX matmul (bf16)\n"); if (!matmul_bf16()) { return -1; + } else { + printf("Success!\n"); } } - printf("Success!\n"); return 0; } \ No newline at end of file