Skip to content

Commit

Permalink
Zen4 support (halide#7840)
Browse files Browse the repository at this point in the history
* Enable emission of float16/32 casts on x86

Fixes halide#7836
Fixes halide#4166

* Add support for zen4

* Add avx512_Zen4 target flag

It's a superset of Cannon Lake and a subset of Sapphire Rapids.

* Fix runtime detection, sapphire rapids CPUID bits

* Fix comment

* Don't catch bfloat casts

* Fix Zen4 model number

* Use llvm BFloat type for bfloat intrinsics

* Give up on native bfloat16 conversion for now

* Don't use llvm's bfloat type at all

* Add missing enum

* Fix constant in comment

* clang-format
  • Loading branch information
abadams authored and ardier committed Mar 3, 2024
1 parent 0a4e53f commit 9ac49d3
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 40 deletions.
1 change: 1 addition & 0 deletions python_bindings/src/halide/halide_/PyEnums.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ void define_enums(py::module &m) {
.value("AVX512_KNL", Target::Feature::AVX512_KNL)
.value("AVX512_Skylake", Target::Feature::AVX512_Skylake)
.value("AVX512_Cannonlake", Target::Feature::AVX512_Cannonlake)
.value("AVX512_Zen4", Target::Feature::AVX512_Zen4)
.value("AVX512_SapphireRapids", Target::Feature::AVX512_SapphireRapids)
.value("TraceLoads", Target::Feature::TraceLoads)
.value("TraceStores", Target::Feature::TraceStores)
Expand Down
37 changes: 27 additions & 10 deletions src/CodeGen_X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ namespace {
// oldest feature flag that supports an instruction.
Target complete_x86_target(Target t) {
if (t.has_feature(Target::AVX512_SapphireRapids)) {
t.set_feature(Target::AVX512_Zen4);
}
if (t.has_feature(Target::AVX512_Zen4)) {
t.set_feature(Target::AVX512_Cannonlake);
}
if (t.has_feature(Target::AVX512_Cannonlake)) {
Expand Down Expand Up @@ -208,12 +211,19 @@ const x86Intrinsic intrinsic_defs[] = {
{"llvm.x86.sse2.pmulhu.w", UInt(16, 8), "pmulh", {UInt(16, 8), UInt(16, 8)}},
{"llvm.x86.ssse3.pmul.hr.sw.128", Int(16, 8), "pmulhrs", {Int(16, 8), Int(16, 8)}, Target::SSE41},

// As of LLVM main on September 5, 2023, LLVM only has partial handling of
// bfloat16. The rules below match fine for simple examples, but the bfloat
// conversion gets folded through any nearby shuffles and causes
// unimplemented errors in LLVM's x86 instruction selection for the shuffle
// node. They are disabled for now. See https://github.com/halide/Halide/issues/7219
/*
// Convert FP32 to BF16
{"vcvtne2ps2bf16x32", BFloat(16, 32), "f32_to_bf16", {Float(32, 32)}, Target::AVX512_SapphireRapids},
{"llvm.x86.avx512bf16.cvtneps2bf16.512", BFloat(16, 16), "f32_to_bf16", {Float(32, 16)}, Target::AVX512_SapphireRapids},
{"llvm.x86.avx512bf16.cvtneps2bf16.256", BFloat(16, 8), "f32_to_bf16", {Float(32, 8)}, Target::AVX512_SapphireRapids},
{"vcvtne2ps2bf16x32", BFloat(16, 32), "f32_to_bf16", {Float(32, 32)}, Target::AVX512_Zen4},
{"llvm.x86.avx512bf16.cvtneps2bf16.512", BFloat(16, 16), "f32_to_bf16", {Float(32, 16)}, Target::AVX512_Zen4},
{"llvm.x86.avx512bf16.cvtneps2bf16.256", BFloat(16, 8), "f32_to_bf16", {Float(32, 8)}, Target::AVX512_Zen4},
// LLVM does not provide an unmasked 128bit cvtneps2bf16 intrinsic, so provide a wrapper around the masked version.
{"vcvtneps2bf16x4", BFloat(16, 4), "f32_to_bf16", {Float(32, 4)}, Target::AVX512_SapphireRapids},
{"vcvtneps2bf16x4", BFloat(16, 4), "f32_to_bf16", {Float(32, 4)}, Target::AVX512_Zen4},
*/

// 2-way dot products
{"llvm.x86.avx2.pmadd.ub.sw", Int(16, 16), "saturating_dot_product", {UInt(8, 32), Int(8, 32)}, Target::AVX2},
Expand All @@ -240,23 +250,23 @@ const x86Intrinsic intrinsic_defs[] = {

// 4-way dot product vector reduction
// The LLVM intrinsics combine the bf16 pairs into i32, so provide a wrapper to correctly call the intrinsic.
{"dpbf16psx16", Float(32, 16), "dot_product", {Float(32, 16), BFloat(16, 32), BFloat(16, 32)}, Target::AVX512_SapphireRapids},
{"dpbf16psx16", Float(32, 16), "dot_product", {Float(32, 16), BFloat(16, 32), BFloat(16, 32)}, Target::AVX512_Zen4},
{"dpbf16psx8", Float(32, 8), "dot_product", {Float(32, 8), BFloat(16, 16), BFloat(16, 16)}, Target::AVX512_SapphireRapids},
{"dpbf16psx4", Float(32, 4), "dot_product", {Float(32, 4), BFloat(16, 8), BFloat(16, 8)}, Target::AVX512_SapphireRapids},

{"dpbusdx16", Int(32, 16), "dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_SapphireRapids},
{"dpbusdx16", Int(32, 16), "dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_Zen4},
{"dpbusdx8", Int(32, 8), "dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_SapphireRapids},
{"dpbusdx4", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_SapphireRapids},

{"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids},
{"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_Zen4},
{"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids},
{"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},

{"dpbusdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_SapphireRapids},
{"dpbusdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_Zen4},
{"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_SapphireRapids},
{"dpbusdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_SapphireRapids},

{"dpwssdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids},
{"dpwssdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_Zen4},
{"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids},
{"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},

Expand Down Expand Up @@ -885,6 +895,8 @@ string CodeGen_X86::mcpu_target() const {
// The CPU choice here *WILL* affect -mattrs!
if (target.has_feature(Target::AVX512_SapphireRapids)) {
return "sapphirerapids";
} else if (target.has_feature(Target::AVX512_Zen4)) {
return "znver4";
} else if (target.has_feature(Target::AVX512_Cannonlake)) {
return "cannonlake";
} else if (target.has_feature(Target::AVX512_Skylake)) {
Expand Down Expand Up @@ -931,6 +943,8 @@ string CodeGen_X86::mcpu_tune() const {
return "znver2";
case Target::Processor::ZnVer3:
return "znver3";
case Target::Processor::ZnVer4:
return "znver4";

case Target::Processor::ProcessorGeneric:
break;
Expand Down Expand Up @@ -972,8 +986,11 @@ string CodeGen_X86::mattrs() const {
if (target.has_feature(Target::AVX512_Cannonlake)) {
features += ",+avx512ifma,+avx512vbmi";
}
if (target.has_feature(Target::AVX512_Zen4)) {
features += ",+avx512bf16,+avx512vnni,+avx512bitalg,+avx512vbmi2";
}
if (target.has_feature(Target::AVX512_SapphireRapids)) {
features += ",+avx512bf16,+avx512vnni,+amx-int8,+amx-bf16";
features += ",+avxvnni,+amx-int8,+amx-bf16";
}
}
return features;
Expand Down
29 changes: 24 additions & 5 deletions src/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,10 @@ Target::Processor get_amd_processor(unsigned family, unsigned model, bool have_s
}
break;
case 0x19: // AMD Family 19h
if (model <= 0x0f || model == 0x21) {
if ((model & 0xf0) == 0 || model == 0x21) {
return Target::Processor::ZnVer3; // 00h-0Fh, 21h: Zen3
} else if (model == 0x61) {
return Target::Processor::ZnVer4; // 61h: Zen4
}
break;
default:
Expand Down Expand Up @@ -215,8 +217,22 @@ Target calculate_host_target() {

if (vendor_signature == VendorSignatures::AuthenticAMD) {
processor = get_amd_processor(family, model, have_sse3);

if (processor == Target::Processor::ZnVer4) {
Target t{os, arch, bits, processor, initial_features, vector_bits};
t.set_features({Target::SSE41, Target::AVX,
Target::F16C, Target::FMA,
Target::AVX2, Target::AVX512,
Target::AVX512_Skylake, Target::AVX512_Cannonlake,
Target::AVX512_Zen4});
return t;
}
}

// Processors not specifically detected by model number above use the cpuid
// feature bits to determine what flags are supported. For future models,
// detect them explicitly above rather than extending the code below.

if (have_sse41) {
initial_features.push_back(Target::SSE41);
}
Expand Down Expand Up @@ -265,12 +281,12 @@ Target calculate_host_target() {
if ((info2[1] & avx512_cannonlake) == avx512_cannonlake) {
initial_features.push_back(Target::AVX512_Cannonlake);

const uint32_t avx512vnni = 1U << 11; // vnni result in ecx
const uint32_t avx512bf16 = 1U << 5; // bf16 result in eax, with cpuid(eax=7, ecx=1)
const uint32_t avxvnni = 1U << 4; // avxvnni (note, not avx512vnni) result in eax
const uint32_t avx512bf16 = 1U << 5; // bf16 result in eax, with cpuid(eax=7, ecx=1)
int info3[4];
cpuid(info3, 7, 1);
// TODO: port to family/model -based detection.
if ((info2[2] & avx512vnni) == avx512vnni &&
if ((info3[0] & avxvnni) == avxvnni &&
(info3[0] & avx512bf16) == avx512bf16) {
initial_features.push_back(Target::AVX512_SapphireRapids);
}
Expand Down Expand Up @@ -441,6 +457,7 @@ const std::map<std::string, Target::Processor> processor_name_map = {
{"tune_znver1", Target::Processor::ZnVer1},
{"tune_znver2", Target::Processor::ZnVer2},
{"tune_znver3", Target::Processor::ZnVer3},
{"tune_znver4", Target::Processor::ZnVer4},
};

bool lookup_processor(const std::string &tok, Target::Processor &result) {
Expand Down Expand Up @@ -502,6 +519,7 @@ const std::map<std::string, Target::Feature> feature_name_map = {
{"avx512_skylake", Target::AVX512_Skylake},
{"avx512_cannonlake", Target::AVX512_Cannonlake},
{"avx512_sapphirerapids", Target::AVX512_SapphireRapids},
{"avx512_zen4", Target::AVX512_Zen4},
{"trace_loads", Target::TraceLoads},
{"trace_stores", Target::TraceStores},
{"trace_realizations", Target::TraceRealizations},
Expand Down Expand Up @@ -1258,7 +1276,7 @@ bool Target::get_runtime_compatible_target(const Target &other, Target &result)
// clang-format on

// clang-format off
const std::array<Feature, 14> intersection_features = {{
const std::array<Feature, 15> intersection_features = {{
ARMv7s,
ARMv81a,
AVX,
Expand All @@ -1268,6 +1286,7 @@ bool Target::get_runtime_compatible_target(const Target &other, Target &result)
AVX512_KNL,
AVX512_SapphireRapids,
AVX512_Skylake,
AVX512_Zen4,
F16C,
FMA,
FMA4,
Expand Down
2 changes: 2 additions & 0 deletions src/Target.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ struct Target {
ZnVer1, /// Tune for AMD Zen CPU (AMD Family 17h, launched 2017).
ZnVer2, /// Tune for AMD Zen 2 CPU (AMD Family 17h, launched 2019).
ZnVer3, /// Tune for AMD Zen 3 CPU (AMD Family 19h, launched 2020).
ZnVer4, /// Tune for AMD Zen 4 CPU (AMD Family 19h, launched 2022).
} processor_tune = ProcessorGeneric;

/** Optional features a target can have.
Expand Down Expand Up @@ -130,6 +131,7 @@ struct Target {
AVX512_Skylake = halide_target_feature_avx512_skylake,
AVX512_Cannonlake = halide_target_feature_avx512_cannonlake,
AVX512_SapphireRapids = halide_target_feature_avx512_sapphirerapids,
AVX512_Zen4 = halide_target_feature_avx512_zen4,
TraceLoads = halide_target_feature_trace_loads,
TraceStores = halide_target_feature_trace_stores,
TraceRealizations = halide_target_feature_trace_realizations,
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/HalideRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -1360,7 +1360,8 @@ typedef enum halide_target_feature_t {
halide_target_feature_avx512_knl, ///< Enable the AVX512 features supported by Knight's Landing chips, such as the Xeon Phi x200. This includes the base AVX512 set, and also AVX512-CD and AVX512-ER.
halide_target_feature_avx512_skylake, ///< Enable the AVX512 features supported by Skylake Xeon server processors. This adds AVX512-VL, AVX512-BW, and AVX512-DQ to the base set. The main difference from the base AVX512 set is better support for small integer ops. Note that this does not include the Knight's Landing features. Note also that these features are not available on Skylake desktop and mobile processors.
halide_target_feature_avx512_cannonlake, ///< Enable the AVX512 features expected to be supported by future Cannonlake processors. This includes all of the Skylake features, plus AVX512-IFMA and AVX512-VBMI.
halide_target_feature_avx512_sapphirerapids, ///< Enable the AVX512 features supported by Sapphire Rapids processors. This include all of the Cannonlake features, plus AVX512-VNNI and AVX512-BF16.
halide_target_feature_avx512_zen4, ///< Enable the AVX512 features supported by Zen4 processors. This includes all of the Cannonlake features, plus AVX512-VNNI, AVX512-BF16, and more.
halide_target_feature_avx512_sapphirerapids, ///< Enable the AVX512 features supported by Sapphire Rapids processors. This includes all of the Zen4 features, plus AVX-VNNI and AMX instructions.
halide_target_feature_trace_loads, ///< Trace all loads done by the pipeline. Equivalent to calling Func::trace_loads on every non-inlined Func.
halide_target_feature_trace_stores, ///< Trace all stores done by the pipeline. Equivalent to calling Func::trace_stores on every non-inlined Func.
halide_target_feature_trace_realizations, ///< Trace all realizations done by the pipeline. Equivalent to calling Func::trace_realizations on every non-inlined Func.
Expand Down
41 changes: 38 additions & 3 deletions src/runtime/x86_cpu_features.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,44 @@ WEAK CpuFeatures halide_get_cpu_features() {
features.set_known(halide_target_feature_avx512_cannonlake);
features.set_known(halide_target_feature_avx512_sapphirerapids);

// Detect CPU features by specific microarchitecture.
int32_t vendor[4];
cpuid(vendor, 0);
int32_t info[4];
cpuid(info, 1);

uint32_t family = (info[0] >> 8) & 0xF; // Bits 8..11
uint32_t model = (info[0] >> 4) & 0xF; // Bits 4..7
if (family == 0x6 || family == 0xF) {
if (family == 0xF) {
// Examine extended family ID if family ID is 0xF.
family += (info[0] >> 20) & 0xFf; // Bits 20..27
}
// Examine extended model ID if family ID is 0x6 or 0xF.
model += ((info[0] >> 16) & 0xF) << 4; // Bits 16..19
}

if (vendor[1] == 0x68747541 && vendor[3] == 0x69746e65 && vendor[2] == 0x444d4163) {
// AMD
if (family == 0x19 && model == 0x61) {
// Zen4
features.set_available(halide_target_feature_sse41);
features.set_available(halide_target_feature_avx);
features.set_available(halide_target_feature_f16c);
features.set_available(halide_target_feature_fma);
features.set_available(halide_target_feature_avx2);
features.set_available(halide_target_feature_avx512);
features.set_available(halide_target_feature_avx512_skylake);
features.set_available(halide_target_feature_avx512_cannonlake);
features.set_available(halide_target_feature_avx512_zen4);
return features;
}
}

// Legacy code to detect CPU by feature bits instead. Handle new
// microarchitectures above rather than making the code below more
// complicated.

const bool have_sse41 = (info[2] & (1 << 19)) != 0;
const bool have_avx = (info[2] & (1 << 28)) != 0;
const bool have_f16c = (info[2] & (1 << 29)) != 0;
Expand Down Expand Up @@ -70,8 +105,8 @@ WEAK CpuFeatures halide_get_cpu_features() {
constexpr uint32_t avx512bw = 1U << 30;
constexpr uint32_t avx512vl = 1U << 31;
constexpr uint32_t avx512ifma = 1U << 21;
constexpr uint32_t avx512vnni = 1U << 11; // vnni result in ecx
constexpr uint32_t avx512bf16 = 1U << 5; // bf16 result in eax, cpuid(eax=7, ecx=1)
constexpr uint32_t avxvnni = 1U << 4;
constexpr uint32_t avx512bf16 = 1U << 5; // bf16 result in eax, cpuid(eax=7, ecx=1)
constexpr uint32_t avx512 = avx512f | avx512cd;
constexpr uint32_t avx512_knl = avx512 | avx512pf | avx512er;
constexpr uint32_t avx512_skylake = avx512 | avx512vl | avx512bw | avx512dq;
Expand All @@ -92,7 +127,7 @@ WEAK CpuFeatures halide_get_cpu_features() {

int32_t info3[4];
cpuid(info3, 7, 1);
if ((info2[2] & avx512vnni) == avx512vnni &&
if ((info3[0] & avxvnni) == avxvnni &&
(info3[0] & avx512bf16) == avx512bf16) {
features.set_available(halide_target_feature_avx512_sapphirerapids);
}
Expand Down
Loading

0 comments on commit 9ac49d3

Please sign in to comment.