Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for AMX instructions #5818

Merged
merged 68 commits into from
Oct 21, 2021
Merged
Show file tree
Hide file tree
Changes from 61 commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
036b037
Add support for AMX tile instructions
jwlawson Feb 17, 2021
2bd3453
Make AMX transform opt-in with memory type
jwlawson Mar 9, 2021
34e1a4c
Clean up tiled_matmul test
jwlawson Mar 10, 2021
dfeac55
Handle AMX intrinsic attributes better
jwlawson Mar 9, 2021
a9f84de
Format
jwlawson Mar 10, 2021
da04b0a
Fix test to behave like other tests
jwlawson Mar 11, 2021
9923040
Add doc and missing load check
jwlawson Mar 11, 2021
1a1d10a
Format
jwlawson Mar 11, 2021
d54fe24
Throw error if user requests AMX for invalid operation
jwlawson Mar 12, 2021
228beda
Add Tile lowering pass to makefile
jwlawson Mar 12, 2021
673480a
Use spaces in Makefile
jwlawson Mar 12, 2021
f91d79f
Place AMX instrinsics into a separate module (x86_amx.ll)
Mar 15, 2021
9fc4366
Merge branch 'master' into pr/5818
steven-johnson Mar 17, 2021
16a1c7b
Fix CreateAlignedLoad() call in CodeGen_X86
steven-johnson Mar 17, 2021
b3b1dfa
Merge branch 'master' into tile_matmul
Mar 22, 2021
03f894a
Merge branch 'master' into tile_matmul
Mar 24, 2021
97890dd
Merge branch 'master' into tile_matmul
Mar 31, 2021
293ca90
Merge branch 'master' into tile_matmul
Apr 1, 2021
4b950e5
Merge branch 'master' into tile_matmul
Apr 7, 2021
d0e5123
fix exporting to module
frengels Apr 19, 2021
9189120
add llvm funcs for su, us, uu amx variants
frengels Apr 8, 2021
75c4262
add other amx intrinsics to intrinsic_defs
frengels Apr 16, 2021
09f7551
match with unsigned 8 bit integers
frengels Apr 16, 2021
5dd1471
match for 32 bit integer and guard unsigned amx on llvm 13
frengels Apr 27, 2021
7e45c29
adjust test to cover unsigned tile operations
frengels Apr 29, 2021
4ab681b
guard properly with llvm 12
frengels Apr 29, 2021
6339ae7
create explicit error if failed to use tile operations
frengels Apr 29, 2021
525e11e
pass types as template params rather than boolean
frengels Apr 29, 2021
5a21484
clang-format patch
Apr 30, 2021
7950614
add x86_amx to makefile's runtime components
frengels Apr 30, 2021
4a6c10c
make tiled_matmul compatible with c++11
frengels Apr 30, 2021
02c375a
Merge remote-tracking branch 'upstream/master' into tile_matmul
frengels Apr 30, 2021
f985644
add mattrs required for amx
frengels May 3, 2021
7df3d5e
Merge pull request #3 from frengels/tile_matmul
mcleary May 3, 2021
e3f1ef6
fix formatting issues
frengels May 5, 2021
57b6080
remove outdated FIXME comments
frengels May 10, 2021
ef9d544
Merge pull request #4 from frengels/tile_matmul
mcleary May 10, 2021
9c6bfbc
Merge branch 'master' into tile_matmul
May 11, 2021
b5c46d2
Merge branch 'master' into tile_matmul
May 13, 2021
5af3dae
Merge remote-tracking branch 'upstream/master' into tile_matmul
frengels May 24, 2021
55098f3
add bf16 tile operations to the runtime
frengels May 10, 2021
9f078f0
create a schedule that should map to amx
frengels May 11, 2021
e73702d
create full amx-bf16 schedule
frengels May 12, 2021
16217b4
allow amx operations to yield f32s
frengels May 12, 2021
c722610
accept 32 bit float stores
frengels May 13, 2021
66885f1
add support for bf16
frengels May 19, 2021
97fb022
add missing bf16 intrinsics
frengels May 19, 2021
f6ba739
fix striding error when loading matrix
frengels May 21, 2021
c7278b6
add checks to verify bf16 result
frengels May 21, 2021
5e81a72
fix scaling of col_bytes on matmul call
frengels May 21, 2021
ea74fe2
move brace to previous line
frengels Jul 8, 2021
a854dc9
derive result type using a function rather than lambda
frengels Jul 9, 2021
d820737
Merge remote-tracking branch 'upstream/master' into tile_matmul_bf16
frengels Jul 9, 2021
26014d2
run clang tidy and format
frengels Jul 9, 2021
34557cb
have tile_store return i32
frengels Sep 2, 2021
5ad06e0
make is_3d_tile_index robust to indexing changes
frengels Sep 8, 2021
f0f9f3e
Merge remote-tracking branch 'upstream/master' into tile_matmul_bf16
frengels Sep 8, 2021
7cab155
apply formatting suggestions
frengels Sep 8, 2021
8f83544
both first and second can be const qualified
frengels Sep 9, 2021
b1e1452
remove trailing whitespace in unformatted section
frengels Sep 9, 2021
95d38f0
Merge branch 'master' into pr/5818
steven-johnson Sep 17, 2021
14df0bc
make requested style changes
frengels Sep 21, 2021
8b63d77
rename NewMatmul -> Matmul
frengels Oct 1, 2021
6a5eeaa
fix warning about missing return value
frengels Oct 1, 2021
cc7c97d
use get_1d_tile_index to handle special case
frengels Oct 1, 2021
014f0c6
add correctness test for AMX instructions
frengels Oct 1, 2021
655dbdf
correctness part has been separated out
frengels Oct 1, 2021
abc660b
remove unused variables
frengels Oct 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ SOURCE_FILES = \
EmulateFloat16Math.cpp \
Error.cpp \
Expr.cpp \
ExtractTileOperations.cpp \
FastIntegerDivide.cpp \
FindCalls.cpp \
FindIntrinsics.cpp \
Expand Down Expand Up @@ -627,6 +628,7 @@ HEADER_FILES = \
ExprUsesVar.h \
Extern.h \
ExternFuncArgument.h \
ExtractTileOperations.h \
FastIntegerDivide.h \
FindCalls.h \
FindIntrinsics.h \
Expand Down Expand Up @@ -841,6 +843,7 @@ RUNTIME_LL_COMPONENTS = \
x86_avx \
x86_avx2 \
x86_avx512 \
x86_amx \
x86_sse41

RUNTIME_EXPORTED_INCLUDES = $(INCLUDE_DIR)/HalideRuntime.h \
Expand Down
3 changes: 3 additions & 0 deletions dependencies/llvm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
message(STATUS "Using ClangConfig.cmake in: ${Clang_DIR}")

# LLVM_PACKAGE_VERSION does not propagate to higher scopes
set(Halide_LLVM_VERSION ${LLVM_PACKAGE_VERSION} CACHE INTERNAL "Provided LLVM version")

alexreinking marked this conversation as resolved.
Show resolved Hide resolved
if (LLVM_PACKAGE_VERSION VERSION_LESS 11.0)
message(FATAL_ERROR "LLVM version must be 11.0 or newer")
endif ()
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ set(HEADER_FILES
ExprUsesVar.h
Extern.h
ExternFuncArgument.h
ExtractTileOperations.h
FastIntegerDivide.h
FindCalls.h
FindIntrinsics.h
Expand Down Expand Up @@ -219,6 +220,7 @@ set(SOURCE_FILES
EmulateFloat16Math.cpp
Error.cpp
Expr.cpp
ExtractTileOperations.cpp
FastIntegerDivide.cpp
FindCalls.cpp
FindIntrinsics.cpp
Expand Down
8 changes: 3 additions & 5 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2564,11 +2564,9 @@ void CodeGen_LLVM::visit(const Call *op) {
internal_assert(op->is_extern() || op->is_intrinsic())
<< "Can only codegen extern calls and intrinsics\n";

if (op->type.is_vector()) {
value = call_overloaded_intrin(op->type, op->name, op->args);
if (value) {
return;
}
value = call_overloaded_intrin(op->type, op->name, op->args);
if (value) {
return;
}

// Some call nodes are actually injected at various stages as a
Expand Down
63 changes: 60 additions & 3 deletions src/CodeGen_X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,22 +79,32 @@ class CodeGen_X86 : public CodeGen_Posix {
void visit(const EQ *) override;
void visit(const NE *) override;
void visit(const Select *) override;
void visit(const Allocate *) override;
void visit(const Load *) override;
void visit(const Store *) override;
void codegen_vector_reduce(const VectorReduce *, const Expr &init) override;
// @}

private:
Scope<MemoryType> mem_type;
};

CodeGen_X86::CodeGen_X86(Target t)
: CodeGen_Posix(complete_x86_target(t)) {
}

const int max_intrinsic_args = 4;
const int max_intrinsic_args = 6;

struct x86Intrinsic {
const char *intrin_name;
halide_type_t ret_type;
const char *name;
halide_type_t arg_types[max_intrinsic_args];
Target::Feature feature = Target::FeatureEnd;
uint32_t flags = 0;
enum Options {
AccessesMemory = 1 << 0,
};
};

// clang-format off
Expand Down Expand Up @@ -199,6 +209,19 @@ const x86Intrinsic intrinsic_defs[] = {
{"dpwssdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids},
{"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids},
{"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},

{"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
{"tileloadd64_i8", UInt(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
{"tileloadd64_bf16", BFloat(16, 512), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
{"tdpbssd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids},
{"tdpbsud", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), UInt(8, 1024)}, Target::AVX512_SapphireRapids},
{"tdpbusd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), UInt(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids},
{"tdpbuud", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), UInt(8, 1024), UInt(8, 1024)}, Target::AVX512_SapphireRapids},
{"tdpbf16ps", Float(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Float(32, 256), BFloat(16, 512), BFloat(16, 512)}, Target::AVX512_SapphireRapids},
{"tilezero_i32", Int(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids},
{"tilezero_f32", Float(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids},
{"tilestored64_i32", Int(32), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
{"tilestored64_f32", Int(32), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Float(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
};
// clang-format on

Expand All @@ -221,7 +244,9 @@ void CodeGen_X86::init_module() {
}

auto *fn = declare_intrin_overload(i.name, ret_type, i.intrin_name, std::move(arg_types));
fn->addFnAttr(llvm::Attribute::ReadNone);
if ((i.flags & x86Intrinsic::AccessesMemory) == 0) {
fn->addFnAttr(llvm::Attribute::ReadNone);
}
fn->addFnAttr(llvm::Attribute::NoUnwind);
}
}
Expand Down Expand Up @@ -584,6 +609,38 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
CodeGen_Posix::codegen_vector_reduce(op, init);
}

void CodeGen_X86::visit(const Allocate *op) {
ScopedBinding<MemoryType> bind(mem_type, op->name, op->memory_type);
CodeGen_Posix::visit(op);
}

void CodeGen_X86::visit(const Load *op) {
if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) {
const Ramp *ramp = op->index.as<Ramp>();
internal_assert(ramp) << "Expected AMXTile to have index ramp\n";
Value *ptr = codegen_buffer_pointer(op->name, op->type, ramp->base);
LoadInst *load = builder->CreateAlignedLoad(ptr->getType()->getPointerElementType(), ptr, llvm::Align(op->type.bytes()));
add_tbaa_metadata(load, op->name, op->index);
value = load;
return;
}
CodeGen_Posix::visit(op);
}

void CodeGen_X86::visit(const Store *op) {
if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) {
Value *val = codegen(op->value);
Halide::Type value_type = op->value.type();
const Ramp *ramp = op->index.as<Ramp>();
internal_assert(ramp) << "Expected AMXTile to have index ramp\n";
Value *ptr = codegen_buffer_pointer(op->name, value_type, ramp->base);
StoreInst *store = builder->CreateAlignedStore(val, ptr, llvm::Align(value_type.bytes()));
add_tbaa_metadata(store, op->name, op->index);
return;
}
CodeGen_Posix::visit(op);
}

string CodeGen_X86::mcpu() const {
if (target.has_feature(Target::AVX512_SapphireRapids)) {
#if LLVM_VERSION >= 120
Expand Down Expand Up @@ -644,7 +701,7 @@ string CodeGen_X86::mattrs() const {
}
if (target.has_feature(Target::AVX512_SapphireRapids)) {
#if LLVM_VERSION >= 120
features += ",+avx512bf16,+avx512vnni";
features += ",+avx512bf16,+avx512vnni,+amx-int8,+amx-bf16";
#else
user_error << "AVX512 SapphireRapids requires LLVM 12 or later.";
#endif
Expand Down
4 changes: 4 additions & 0 deletions src/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,10 @@ enum class MemoryType {
* intermediate buffers. Necessary for vgather-vscatter instructions
* on Hexagon */
VTCM,

/** AMX Tile register for X86. Any data that would be used in an AMX matrix
* multiplication must first be loaded into an AMX tile register. */
AMXTile,
};

namespace Internal {
Expand Down
Loading