Add support for AMX instructions #5818

Merged on Oct 21, 2021 (68 commits)

mcleary
Contributor

@mcleary mcleary commented Mar 17, 2021

This pull request continues the work started by @jwlawson in #5780 with the objective of adding initial support for AMX instructions in Halide.

The main addition here is the fix for building Halide using LLVM 11. Support for AMX instructions requires LLVM 12 or newer, so when building with LLVM 11 the unsupported instructions are not included.

A new LLVM module (x86_amx.ll) was created to contain all the intrinsics required to support tile operations; this module is only included when LLVM >= 12 is present.

@steven-johnson
Contributor

(Synced to head to fix some irrelevant LLVM build issues)

@mcleary
Contributor Author

mcleary commented Mar 17, 2021

(Synced to head to fix some irrelevant LLVM build issues)

That is a good excuse for me to set up LLVM 13 locally. I will try to see why it is not building with it.

steven-johnson and others added 3 commits March 17, 2021 12:56
Recent changes in LLVM trunk made the previous calling convention deprecated (and thus compiling with warning/error)
@steven-johnson
Contributor

The OSX failure is unrelated (will be fixed by #5841), should be good to land

@steven-johnson
Contributor

You should sync this to master to force the bots to retry.

@mcleary
Contributor Author

mcleary commented Mar 31, 2021

You should sync this to master to force the bots to retry.

I'm not sure if the buildbot is still running since there is a "cancelled" message there.

@steven-johnson
Contributor

Please try syncing to master once again; hopefully the buildbots will finally be clean.

@mcleary
Contributor Author

mcleary commented Apr 7, 2021

Please try syncing to master once again; hopefully the buildbots will finally be clean.

Thanks, I will do that; hopefully it will all be green now.

@steven-johnson
Contributor

Failures are the unrelated cuda-hang failure that we still haven't diagnosed; ok to land

@@ -190,7 +200,7 @@ const x86Intrinsic intrinsic_defs[] = {

{"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids},
{"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids},
{"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},
{"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},
Member

nit: irrelevant whitespace?

Contributor

Addressed in b1e1452

Comment on lines +118 to +122
if (Halide_LLVM_VERSION VERSION_GREATER_EQUAL 12.0)
# AMX instructions require LLVM 12 or newer
list(APPEND RUNTIME_LL x86_amx)
endif ()

Member

Does including this fail at build time or at runtime only?

Contributor

This fails at build time with the following message:

[1/7] Generating initmod.x86_amx.bc
FAILED: src/runtime/initmod.x86_amx.bc /home/frederik/projects/halide/build-11/src/runtime/initmod.x86_amx.bc 
cd /home/frederik/projects/halide/build-11/src/runtime && /usr/lib/llvm-11/bin/llvm-as /home/frederik/projects/halide/src/runtime/x86_amx.ll -o initmod.x86_amx.bc
/usr/lib/llvm-11/bin/llvm-as: /home/frederik/projects/halide/src/runtime/x86_amx.ll:3:18: error: expected type
  %2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) nounwind readonly
                 ^

@frengels
Contributor

frengels commented Sep 1, 2021

Just letting you know we haven't lost track of this and the TensorCore PRs. We had some different priorities and annual leave. I look forward to getting this merged soon.

@frengels
Contributor

I don't think the test failures are related to anything in this PR

Contributor

@steven-johnson steven-johnson left a comment

LGTM so far -- aside from style nits, I think it would be good to split the new test into correctness and performance tests, as Halide does for virtually all other features.

@@ -0,0 +1,414 @@
#include "ExtractTileOperations.h"

#include "IRMatch.h" // expr_match
Contributor

Nit: We don't usually add comments explaining why each header is included.

Member

Speaking of which, it might be a good idea to run IWYU on our codebase...

Contributor

SGTM


enum class AMXOpType {
Int8,
Bf16,
Contributor

Nit: I assume this is bfloat16? If so, spelling it out (eg Bfloat16) would be preferable.

case AMXOpType::Bf16:
return Float(32, 256);
default:
return Type();
Contributor

I assume this is a should-never-happen case, so doing something like internal_error << "Unexpected"; would be appropriate.
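
The suggestion above can be sketched without Halide itself. In this standalone example, `fail_internal` is a hypothetical stand-in for Halide's `internal_error` stream and `ElemType` is a simplified placeholder for `Halide::Type`. The enumerator is spelled `Bfloat16` as proposed in the earlier nit, and the integer result type for the `Int8` case is an assumption, since that branch is not shown in the reviewed snippet.

```cpp
#include <stdexcept>
#include <string>

enum class AMXOpType { Int8, Bfloat16 };

// Simplified placeholder for Halide::Type (hypothetical, for illustration only).
struct ElemType {
    char code;  // 'i' for signed integer, 'f' for float
    int bits;
    int lanes;
};

// Hypothetical stand-in for Halide's internal_error: fail loudly instead of
// returning a default-constructed value.
[[noreturn]] inline void fail_internal(const std::string &msg) {
    throw std::logic_error("Internal error: " + msg);
}

inline ElemType amx_op_type_result_type(AMXOpType op_ty) {
    switch (op_ty) {
    case AMXOpType::Int8:
        // Assumption: the int8 case accumulates into 32-bit integers.
        return {'i', 32, 256};
    case AMXOpType::Bfloat16:
        return {'f', 32, 256};
    }
    // Reached only if an out-of-range value was cast to AMXOpType.
    fail_internal("Unexpected AMXOpType");
}
```

Leaving out the `default:` label lets the compiler warn when a new enumerator is added without a matching case, while the trailing call still catches values that should never occur.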

const auto wild_i32 = Variable::make(Int(32), "*");
const auto wild_i32x = Variable::make(Int(32, 0), "*");

Tile<2> is_2d_tile_index(const Expr &e) {
Contributor

Nit: I'd expect a function named "is_whatever" to return bool, but this returns a struct. Something like get_2d_tile_index would be better.
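
To illustrate the naming point, here is a minimal standalone sketch. `TileInfo` and both functions are hypothetical and do not reproduce the PR's `Tile<N>` struct or its pattern matching on `Expr`; the point is only that a `get_` prefix signals that a value comes back, while an `is_` prefix reads as a predicate.

```cpp
#include <optional>

// Hypothetical, simplified description of a 2D tile access: a base offset,
// a row stride, and the tile extents.
struct TileInfo {
    int base;
    int row_stride;
    int rows;
    int cols;
};

// Returns the tile description when `lanes` indices can be arranged as rows
// of `cols` elements, and std::nullopt otherwise.
inline std::optional<TileInfo> get_2d_tile_index(int base, int row_stride, int lanes, int cols) {
    if (cols <= 0 || lanes <= 0 || lanes % cols != 0) {
        return std::nullopt;
    }
    return TileInfo{base, row_stride, lanes / cols, cols};
}

// A function with an is_ prefix returns bool, as the name suggests.
inline bool is_2d_tile_index(int base, int row_stride, int lanes, int cols) {
    return get_2d_tile_index(base, row_stride, lanes, cols).has_value();
}
```

The sketch uses `std::optional` (C++17) to carry the "no match" case; a struct with a boolean success field would express the same contract.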

return {};
}

Tile<3> is_3d_tile_index(const Expr &e) {
Contributor

Same nit here.

// 4 bytes for i32, f32
auto colbytes = tile_y * 4;
auto matmul =
Call::make(res_type, "tile_matmul", {tile_x, colbytes, tile_r, out, lhs, rhs}, Call::Intrinsic);
Contributor

No need to split this into two lines
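
As a side note on the `colbytes` computation in the snippet under review: the result tile holds 32-bit elements, so the byte width of one row is the column count times 4. The helpers below are a hypothetical standalone sketch, not code from the PR. The limits of 16 rows and 64 bytes per row are the architectural maximums of an AMX tile register.

```cpp
// Bytes per element of the accumulator tile (i32 or f32).
constexpr int kAccumulatorElemBytes = 4;
// Architectural AMX tile limits: at most 16 rows of at most 64 bytes each.
constexpr int kMaxTileRows = 16;
constexpr int kMaxTileRowBytes = 64;

// Byte width of one row of a tile with `cols` 32-bit elements.
constexpr int tile_colbytes(int cols) {
    return cols * kAccumulatorElemBytes;
}

// True when a rows x cols tile of 32-bit elements fits in one tile register.
constexpr bool fits_in_amx_tile(int rows, int cols) {
    return rows > 0 && cols > 0 && rows <= kMaxTileRows &&
           tile_colbytes(cols) <= kMaxTileRowBytes;
}

static_assert(tile_colbytes(16) == 64, "16 columns of 4-byte elements span 64 bytes");
```

A full 16 x 16 tile of 32-bit values has 256 elements, which matches the 256 lanes of the `Float(32, 256)` result type quoted earlier in this review.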

op_type = AMXOpType::Bf16;
}

user_assert(!in_allocate) << "Found two possible tile allocations for AMX allocation";
Contributor

Would it be helpful to append amx_name or tile_name to the error message, for debugging purposes?
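
A sketch of what the extended message could look like follows. Because `user_assert` in the reviewed line already accepts text streamed with `<<`, the names could be streamed the same way; the helper below is hypothetical and uses a plain `std::ostringstream` so that it compiles without Halide. The message layout is an assumption, not the wording that was eventually committed.

```cpp
#include <sstream>
#include <string>

// Builds the diagnostic text. `amx_name` and `tile_name` mirror the variable
// names mentioned in the review comment; the helper itself is hypothetical.
inline std::string duplicate_allocation_message(const std::string &amx_name,
                                                const std::string &tile_name) {
    std::ostringstream msg;
    msg << "Found two possible tile allocations for AMX allocation"
        << " (amx_name: " << amx_name << ", tile_name: " << tile_name << ")";
    return msg.str();
}
```

Including both names lets a user see which allocation triggered the assertion without attaching a debugger.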

}

auto alloc_type = amx_op_type_result_type(op_type);

Contributor

No need for this blank line

}

auto body = mutate(op->body);
return ProducerConsumer::make(amx_name, op->is_producer, body);
Contributor

Nit: std::move(body) ?
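
The effect of the suggested `std::move` can be shown with any reference-counted handle passed by value. The sketch below uses `std::shared_ptr` and a hypothetical `Node` type as stand-ins; it assumes, as the nit implies, that the factory function takes its body argument by value. It is not Halide code.

```cpp
#include <memory>
#include <string>
#include <utility>

// Hypothetical node type standing in for a reference-counted IR statement.
struct Node {
    std::string name;
    std::shared_ptr<Node> body;
};

// Takes `body` by value, so the caller chooses between a copy and a move.
inline std::shared_ptr<Node> make_node(const std::string &name, std::shared_ptr<Node> body) {
    auto n = std::make_shared<Node>();
    n->name = name;
    n->body = std::move(body);
    return n;
}

// Passing the local by copy: the local and the new node both own the body.
inline long wrap_by_copy() {
    auto body = std::make_shared<Node>();
    auto outer = make_node("copy", body);
    return outer->body.use_count();
}

// Passing the local by move: ownership is transferred, no extra reference.
inline long wrap_by_move() {
    auto body = std::make_shared<Node>();
    auto outer = make_node("move", std::move(body));
    return outer->body.use_count();
}
```

The copy costs an extra reference-count increment and decrement; the move avoids both, which is why it is a common nit when a local handle is not used again after the call.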

.vectorize(mmyi);

Func result = mm.in();
//result.print_loop_nest();
Contributor

Don't check in commented-out code (unless there is a comment explaining why, as is done elsewhere in this file)

@frengels
Contributor

When converting to correctness tests, the pattern for the rhs load changes slightly when using Buffer instead of ImageParam, so I need a bit of extra time to make sure it accepts all possible scenarios and generates correct code.

When using `Buffer` instead of `ImageParam` the `Ramp` expression
generated is 1D instead of 2D, therefore we recognize this with a special
case. The lanes are still matched against the dimensions of the LHS
3d tile lanes.
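
One way to see how a 1D ramp and a 2D ramp can describe the same lanes is to enumerate their indices. The sketch below is an illustration of the index arithmetic only: the function names are hypothetical, it does not reproduce the PR's matcher, and whether this is exactly the situation that arises with `Buffer` is an assumption.

```cpp
#include <vector>

// Indices of a 1D ramp: base, base + stride, base + 2 * stride, ...
inline std::vector<int> ramp_1d(int base, int stride, int lanes) {
    std::vector<int> out;
    out.reserve(lanes);
    for (int i = 0; i < lanes; ++i) {
        out.push_back(base + i * stride);
    }
    return out;
}

// Indices of a 2D (nested) ramp: `rows` copies of an inner ramp of `cols`
// lanes, each copy offset by `row_stride` from the previous one.
inline std::vector<int> ramp_2d(int base, int inner_stride, int cols, int row_stride, int rows) {
    std::vector<int> out;
    out.reserve(rows * cols);
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            out.push_back(base + r * row_stride + c * inner_stride);
        }
    }
    return out;
}
```

When the row stride equals the row width (for example `ramp_2d(5, 1, 4, 4, 3)`), the nested ramp visits the same indices as `ramp_1d(5, 1, 12)`, so a matcher that expects two dimensions has to recover the tile shape from the lane count instead.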
@frengels
Contributor

I think the recent commits addressed all comments. Is there anything else that needs to be addressed?

@abadams
Member

abadams commented Oct 21, 2021

lgtm. The pattern matching seems to be pretty ad-hoc and possibly brittle, but that can always be improved later. The checks for LLVM 12 will be removed pretty soon too.

@steven-johnson steven-johnson merged commit 0078880 into halide:master Oct 21, 2021
@steven-johnson
Contributor

This is failing with LLVM 11 for Makefile-based builds. I'll see if I can prep a patch.

8 participants