Skip to content

Commit

Permalink
[ARM] support new udot/sdot patterns (halide#7800)
Browse files Browse the repository at this point in the history
  • Loading branch information
rootjalex authored and ardier committed Mar 3, 2024
1 parent 4e0502a commit 59f2272
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 4 deletions.
85 changes: 85 additions & 0 deletions src/CodeGen_ARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class CodeGen_ARM : public CodeGen_Posix {
/** Nodes for which we want to emit specific neon intrinsics */
// @{
void visit(const Cast *) override;
void visit(const Add *) override;
void visit(const Sub *) override;
void visit(const Min *) override;
void visit(const Max *) override;
Expand Down Expand Up @@ -910,6 +911,90 @@ void CodeGen_ARM::visit(const Cast *op) {
CodeGen_Posix::visit(op);
}

void CodeGen_ARM::visit(const Add *op) {
if (neon_intrinsics_disabled() ||
!op->type.is_vector() ||
!target.has_feature(Target::ARMDotProd) ||
!op->type.is_int_or_uint() ||
op->type.bits() != 32) {
CodeGen_Posix::visit(op);
return;
}

struct Pattern {
Expr pattern;
const char *intrin;
Type coeff_type = UInt(8);
};

// Initial values.
Expr init_i32 = Variable::make(Int(32, 0), "init");
Expr init_u32 = Variable::make(UInt(32, 0), "init");
// Values
Expr a_i8 = Variable::make(Int(8, 0), "a"), b_i8 = Variable::make(Int(8, 0), "b");
Expr c_i8 = Variable::make(Int(8, 0), "c"), d_i8 = Variable::make(Int(8, 0), "d");
Expr a_u8 = Variable::make(UInt(8, 0), "a"), b_u8 = Variable::make(UInt(8, 0), "b");
Expr c_u8 = Variable::make(UInt(8, 0), "c"), d_u8 = Variable::make(UInt(8, 0), "d");
// Coefficients
Expr ac_i8 = Variable::make(Int(8, 0), "ac"), bc_i8 = Variable::make(Int(8, 0), "bc");
Expr cc_i8 = Variable::make(Int(8, 0), "cc"), dc_i8 = Variable::make(Int(8, 0), "dc");
Expr ac_u8 = Variable::make(UInt(8, 0), "ac"), bc_u8 = Variable::make(UInt(8, 0), "bc");
Expr cc_u8 = Variable::make(UInt(8, 0), "cc"), dc_u8 = Variable::make(UInt(8, 0), "dc");

// clang-format off
static const Pattern patterns[] = {
// If we had better normalization, we could drastically reduce the number of patterns here.
// Signed variants.
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product"},
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8)), "dot_product", Int(8)},
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
{init_i32 + widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
// Signed variants (associative).
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product"},
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8))), "dot_product", Int(8)},
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
{init_i32 + (widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
// Unsigned variants.
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product"},
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8)), "dot_product", UInt(8)},
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
{init_u32 + widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
// Unsigned variants (associative).
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product"},
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8))), "dot_product", UInt(8)},
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
{init_u32 + (widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
};
// clang-format on

std::map<std::string, Expr> matches;
for (const Pattern &p : patterns) {
if (expr_match(p.pattern, op, matches)) {
Expr init = matches["init"];
Expr values = Shuffle::make_interleave({matches["a"], matches["b"], matches["c"], matches["d"]});
// Coefficients can be 1 if not in the pattern.
Expr one = make_one(p.coeff_type.with_lanes(op->type.lanes()));
// This hideous code pattern implements fetching a
// default value if the map doesn't contain a key.
Expr _ac = matches.try_emplace("ac", one).first->second;
Expr _bc = matches.try_emplace("bc", one).first->second;
Expr _cc = matches.try_emplace("cc", one).first->second;
Expr _dc = matches.try_emplace("dc", one).first->second;
Expr coeffs = Shuffle::make_interleave({_ac, _bc, _cc, _dc});
value = call_overloaded_intrin(op->type, p.intrin, {init, values, coeffs});
if (value) {
return;
}
}
}

CodeGen_Posix::visit(op);
}

void CodeGen_ARM::visit(const Sub *op) {
if (neon_intrinsics_disabled()) {
CodeGen_Posix::visit(op);
Expand Down
11 changes: 9 additions & 2 deletions src/FindIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ class FindIntrinsics : public IRMutator {
IRMatcher::Wild<0> x;
IRMatcher::Wild<1> y;
IRMatcher::Wild<2> z;
IRMatcher::Wild<3> w;
IRMatcher::WildConst<0> c0;
IRMatcher::WildConst<1> c1;

Expand Down Expand Up @@ -255,7 +256,7 @@ class FindIntrinsics : public IRMutator {
result = widen_right_add(b, narrow_a);
}
internal_assert(result.type() == op->type);
return result;
return mutate(result);
} else if (narrow_b.defined()) {
Expr result;
if (a.type().code() != narrow_b.type().code()) {
Expand Down Expand Up @@ -420,7 +421,7 @@ class FindIntrinsics : public IRMutator {
result = widen_right_mul(b, narrow_a);
}
internal_assert(result.type() == op->type);
return result;
return mutate(result);
} else if (narrow_b.defined()) {
Expr result;
if (a.type().code() != narrow_b.type().code()) {
Expand Down Expand Up @@ -772,6 +773,12 @@ class FindIntrinsics : public IRMutator {
x + cast(op->type, widening_sub(z, y)),
is_x_same_uint) ||

// (x + y + widen(z)) + widen(w) = x + y + widening_add(z, w)
rewrite(widen_right_add(x + widen_right_add(y, z), w),
x + (y + widening_add(z, w)),
// We only care about integers, this should be trivially true.
is_x_same_int_or_uint) ||

// Saturating patterns.
rewrite(saturating_cast(op->type, widening_add(x, y)),
saturating_add(x, y),
Expand Down
17 changes: 15 additions & 2 deletions test/correctness/simd_op_check_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class SimdOpCheckARM : public SimdOpCheckTest {
Expr f64_1 = in_f64(x), f64_2 = in_f64(x + 16), f64_3 = in_f64(x + 32);
Expr f32_1 = in_f32(x), f32_2 = in_f32(x + 16), f32_3 = in_f32(x + 32);
Expr f16_1 = in_f16(x), f16_2 = in_f16(x + 16), f16_3 = in_f16(x + 32);
Expr i8_1 = in_i8(x), i8_2 = in_i8(x + 16), i8_3 = in_i8(x + 32);
Expr u8_1 = in_u8(x), u8_2 = in_u8(x + 16), u8_3 = in_u8(x + 32);
Expr i8_1 = in_i8(x), i8_2 = in_i8(x + 16), i8_3 = in_i8(x + 32), i8_4 = in_i8(x + 48);
Expr u8_1 = in_u8(x), u8_2 = in_u8(x + 16), u8_3 = in_u8(x + 32), u8_4 = in_u8(x + 48);
Expr i16_1 = in_i16(x), i16_2 = in_i16(x + 16), i16_3 = in_i16(x + 32);
Expr u16_1 = in_u16(x), u16_2 = in_u16(x + 16), u16_3 = in_u16(x + 32);
Expr i32_1 = in_i32(x), i32_2 = in_i32(x + 16), i32_3 = in_i32(x + 32);
Expand Down Expand Up @@ -587,6 +587,19 @@ class SimdOpCheckARM : public SimdOpCheckTest {
check(arm32 ? "vpaddl.s8" : "sdot", 8, sum_(i32(in_i8(f * x + r))));
check(arm32 ? "vpaddl.u8" : "udot", 8, sum_(i32(in_u8(f * x + r))));
check(arm32 ? "vpaddl.u8" : "udot", 8, sum_(u32(in_u8(f * x + r))));
if (!arm32) {
check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4) * 12);
check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4));
check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) + i32(i8_4) * 12);
check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) + i32(i8_3) * 9 + i32(i8_4) * 12);
check("sdot", 8, i32_1 + i32(i8_1) + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4) * 12);

check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4) * 12);
check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4));
check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) + u32(u8_4) * 12);
check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) + u32(u8_3) * 9 + u32(u8_4) * 12);
check("udot", 8, u32_1 + u32(u8_1) + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4) * 12);
}
} else {
check(arm32 ? "vpaddl.s8" : "saddlp", 8, sum_(i32(in_i8(f * x + r))));
check(arm32 ? "vpaddl.u8" : "uaddlp", 8, sum_(i32(in_u8(f * x + r))));
Expand Down

0 comments on commit 59f2272

Please sign in to comment.