Skip to content

Commit

Permalink
Added runtime detection
Browse files Browse the repository at this point in the history
Cannot add a `cupid`-based test because the `cupid` crate doesn't support `amx`.
  • Loading branch information
sayantn committed Jun 23, 2024
1 parent 86098df commit cfbc131
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 23 deletions.
15 changes: 15 additions & 0 deletions crates/std_detect/src/detect/arch/x86.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ features! {
/// * `"avx512bf16"`
/// * `"avx512vp2intersect"`
/// * `"avx512fp16"`
/// * `"amx-tile"`
/// * `"amx-int8"`
/// * `"amx-bf16"`
/// * `"amx-fp16"`
/// * `"amx-complex"`
/// * `"f16c"`
/// * `"fma"`
/// * `"bmi1"`
Expand Down Expand Up @@ -172,6 +177,16 @@ features! {
/// AVX-512 P2INTERSECT
@FEATURE: #[stable(feature = "simd_x86", since = "1.27.0")] avx512fp16: "avx512fp16";
/// AVX-512 FP16 (FLOAT16 instructions)
@FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_tile: "amx-tile";
/// AMX (Advanced Matrix Extensions) - Tile load/store
@FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_int8: "amx-int8";
/// AMX-INT8 (Operations on 8-bit integers)
@FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_bf16: "amx-bf16";
/// AMX-BF16 (BFloat16 Operations)
@FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_fp16: "amx-fp16";
/// AMX-FP16 (Float16 Operations)
@FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_complex: "amx-complex";
/// AMX-COMPLEX (Complex number Operations)
@FEATURE: #[stable(feature = "simd_x86", since = "1.27.0")] f16c: "f16c";
/// F16C (Conversions between IEEE-754 `binary16` and `binary32` formats)
@FEATURE: #[stable(feature = "simd_x86", since = "1.27.0")] fma: "fma";
Expand Down
23 changes: 19 additions & 4 deletions crates/std_detect/src/detect/os/x86.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,17 @@ pub(crate) fn detect_features() -> cache::Initializer {
extended_features_ecx,
extended_features_edx,
extended_features_eax_leaf_1,
extended_features_edx_leaf_1,
) = if max_basic_leaf >= 7 {
let CpuidResult { ebx, ecx, edx, .. } = unsafe { __cpuid(0x0000_0007_u32) };
let CpuidResult { eax: eax_1, .. } =
unsafe { __cpuid_count(0x0000_0007_u32, 0x0000_0001_u32) };
(ebx, ecx, edx, eax_1)
let CpuidResult {
eax: eax_1,
edx: edx_1,
..
} = unsafe { __cpuid_count(0x0000_0007_u32, 0x0000_0001_u32) };
(ebx, ecx, edx, eax_1, edx_1)
} else {
(0, 0, 0, 0) // CPUID does not support "Extended Features"
(0, 0, 0, 0, 0) // CPUID does not support "Extended Features"
};

// EAX = 0x8000_0000, ECX = 0: Get Highest Extended Function Supported
Expand Down Expand Up @@ -157,6 +161,7 @@ pub(crate) fn detect_features() -> cache::Initializer {
// * SSE -> `XCR0.SSE[1]`
// * AVX -> `XCR0.AVX[2]`
// * AVX-512 -> `XCR0.AVX-512[7:5]`.
// * AMX -> `XCR0.AMX[18:17]`
//
// by setting the corresponding bits of `XCR0` to `1`.
//
Expand All @@ -167,6 +172,8 @@ pub(crate) fn detect_features() -> cache::Initializer {
let os_avx_support = xcr0 & 6 == 6;
// Test `XCR0.AVX-512[7:5]` with the mask `0b1110_0000 == 224`:
let os_avx512_support = xcr0 & 224 == 224;
// Test `XCR0.AMX[18:17]` with the mask `0b110_0000_0000_0000_0000 == 0x60000`:
let os_amx_support = xcr0 & 0x60000 == 0x60000;

// Only if the OS and the CPU support saving/restoring the AVX
// registers we enable `xsave` support:
Expand Down Expand Up @@ -225,6 +232,14 @@ pub(crate) fn detect_features() -> cache::Initializer {
enable(extended_features_edx, 8, Feature::avx512vp2intersect);
enable(extended_features_edx, 23, Feature::avx512fp16);
enable(extended_features_eax_leaf_1, 5, Feature::avx512bf16);

if os_amx_support {
enable(extended_features_edx, 24, Feature::amx_tile);
enable(extended_features_edx, 25, Feature::amx_int8);
enable(extended_features_edx, 22, Feature::amx_bf16);
enable(extended_features_eax_leaf_1, 21, Feature::amx_fp16);
enable(extended_features_edx_leaf_1, 8, Feature::amx_complex);
}
}
}
}
Expand Down
43 changes: 24 additions & 19 deletions crates/std_detect/tests/x86-specific.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#![allow(internal_features)]
#![feature(stdarch_internal)]
#![feature(stdarch_internal, x86_amx_intrinsics)]

extern crate cupid;
#[macro_use]
Expand All @@ -24,34 +24,34 @@ fn dump() {
println!("f16c: {:?}", is_x86_feature_detected!("f16c"));
println!("avx: {:?}", is_x86_feature_detected!("avx"));
println!("avx2: {:?}", is_x86_feature_detected!("avx2"));
println!("avx512f {:?}", is_x86_feature_detected!("avx512f"));
println!("avx512cd {:?}", is_x86_feature_detected!("avx512cd"));
println!("avx512er {:?}", is_x86_feature_detected!("avx512er"));
println!("avx512pf {:?}", is_x86_feature_detected!("avx512pf"));
println!("avx512bw {:?}", is_x86_feature_detected!("avx512bw"));
println!("avx512dq {:?}", is_x86_feature_detected!("avx512dq"));
println!("avx512vl {:?}", is_x86_feature_detected!("avx512vl"));
println!("avx512_ifma {:?}", is_x86_feature_detected!("avx512ifma"));
println!("avx512f: {:?}", is_x86_feature_detected!("avx512f"));
println!("avx512cd: {:?}", is_x86_feature_detected!("avx512cd"));
println!("avx512er: {:?}", is_x86_feature_detected!("avx512er"));
println!("avx512pf: {:?}", is_x86_feature_detected!("avx512pf"));
println!("avx512bw: {:?}", is_x86_feature_detected!("avx512bw"));
println!("avx512dq: {:?}", is_x86_feature_detected!("avx512dq"));
println!("avx512vl: {:?}", is_x86_feature_detected!("avx512vl"));
println!("avx512_ifma: {:?}", is_x86_feature_detected!("avx512ifma"));
println!("avx512vbmi {:?}", is_x86_feature_detected!("avx512vbmi"));
println!(
"avx512_vpopcntdq {:?}",
"avx512_vpopcntdq: {:?}",
is_x86_feature_detected!("avx512vpopcntdq")
);
println!("avx512vbmi2 {:?}", is_x86_feature_detected!("avx512vbmi2"));
println!("gfni {:?}", is_x86_feature_detected!("gfni"));
println!("vaes {:?}", is_x86_feature_detected!("vaes"));
println!("vpclmulqdq {:?}", is_x86_feature_detected!("vpclmulqdq"));
println!("avx512vnni {:?}", is_x86_feature_detected!("avx512vnni"));
println!("avx512vbmi2: {:?}", is_x86_feature_detected!("avx512vbmi2"));
println!("gfni: {:?}", is_x86_feature_detected!("gfni"));
println!("vaes: {:?}", is_x86_feature_detected!("vaes"));
println!("vpclmulqdq: {:?}", is_x86_feature_detected!("vpclmulqdq"));
println!("avx512vnni: {:?}", is_x86_feature_detected!("avx512vnni"));
println!(
"avx512bitalg {:?}",
"avx512bitalg: {:?}",
is_x86_feature_detected!("avx512bitalg")
);
println!("avx512bf16 {:?}", is_x86_feature_detected!("avx512bf16"));
println!("avx512bf16: {:?}", is_x86_feature_detected!("avx512bf16"));
println!(
"avx512vp2intersect {:?}",
"avx512vp2intersect: {:?}",
is_x86_feature_detected!("avx512vp2intersect")
);
println!("avx512fp16 {:?}", is_x86_feature_detected!("avx512fp16"));
println!("avx512fp16: {:?}", is_x86_feature_detected!("avx512fp16"));
println!("fma: {:?}", is_x86_feature_detected!("fma"));
println!("abm: {:?}", is_x86_feature_detected!("abm"));
println!("bmi: {:?}", is_x86_feature_detected!("bmi1"));
Expand All @@ -68,6 +68,11 @@ fn dump() {
println!("adx: {:?}", is_x86_feature_detected!("adx"));
println!("rtm: {:?}", is_x86_feature_detected!("rtm"));
println!("movbe: {:?}", is_x86_feature_detected!("movbe"));
println!("amx-bf16: {:?}", is_x86_feature_detected!("amx-bf16"));
println!("amx-tile: {:?}", is_x86_feature_detected!("amx-tile"));
println!("amx-int8: {:?}", is_x86_feature_detected!("amx-int8"));
println!("amx-fp16: {:?}", is_x86_feature_detected!("amx-fp16"));
println!("amx-complex: {:?}", is_x86_feature_detected!("amx-complex"));
}

#[cfg(feature = "std_detect_env_override")]
Expand Down

0 comments on commit cfbc131

Please sign in to comment.