Skip to content

Commit f75b8b7

Browse files
authored
Arm Fused Multiply-Add fixes (rust-lang#1219)
1 parent 30b3eb3 commit f75b8b7

File tree

4 files changed

+60
-19
lines changed

4 files changed

+60
-19
lines changed

crates/core_arch/src/arm_shared/neon/generated.rs

+12-12
Original file line numberDiff line numberDiff line change
@@ -8721,7 +8721,7 @@ pub unsafe fn vmull_laneq_u32<const LANE: i32>(a: uint32x2_t, b: uint32x4_t) ->
87218721
/// Floating-point fused Multiply-Add to accumulator(vector)
87228722
#[inline]
87238723
#[target_feature(enable = "neon")]
8724-
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
8724+
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
87258725
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
87268726
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
87278727
pub unsafe fn vfma_f32(a: float32x2_t, b: float32x2_t, c: float32x2_t) -> float32x2_t {
@@ -8737,7 +8737,7 @@ vfma_f32_(b, c, a)
87378737
/// Floating-point fused Multiply-Add to accumulator(vector)
87388738
#[inline]
87398739
#[target_feature(enable = "neon")]
8740-
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
8740+
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
87418741
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
87428742
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
87438743
pub unsafe fn vfmaq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
@@ -8753,27 +8753,27 @@ vfmaq_f32_(b, c, a)
87538753
/// Floating-point fused Multiply-Add to accumulator(vector)
87548754
#[inline]
87558755
#[target_feature(enable = "neon")]
8756-
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
8756+
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
87578757
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
87588758
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
87598759
pub unsafe fn vfma_n_f32(a: float32x2_t, b: float32x2_t, c: f32) -> float32x2_t {
8760-
vfma_f32(a, b, vdup_n_f32(c))
8760+
vfma_f32(a, b, vdup_n_f32_vfp4(c))
87618761
}
87628762

87638763
/// Floating-point fused Multiply-Add to accumulator(vector)
87648764
#[inline]
87658765
#[target_feature(enable = "neon")]
8766-
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
8766+
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
87678767
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
87688768
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
87698769
pub unsafe fn vfmaq_n_f32(a: float32x4_t, b: float32x4_t, c: f32) -> float32x4_t {
8770-
vfmaq_f32(a, b, vdupq_n_f32(c))
8770+
vfmaq_f32(a, b, vdupq_n_f32_vfp4(c))
87718771
}
87728772

87738773
/// Floating-point fused multiply-subtract from accumulator
87748774
#[inline]
87758775
#[target_feature(enable = "neon")]
8776-
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
8776+
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
87778777
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
87788778
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
87798779
pub unsafe fn vfms_f32(a: float32x2_t, b: float32x2_t, c: float32x2_t) -> float32x2_t {
@@ -8784,7 +8784,7 @@ pub unsafe fn vfms_f32(a: float32x2_t, b: float32x2_t, c: float32x2_t) -> float3
87848784
/// Floating-point fused multiply-subtract from accumulator
87858785
#[inline]
87868786
#[target_feature(enable = "neon")]
8787-
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
8787+
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
87888788
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
87898789
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
87908790
pub unsafe fn vfmsq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
@@ -8795,21 +8795,21 @@ pub unsafe fn vfmsq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float
87958795
/// Floating-point fused Multiply-subtract to accumulator(vector)
87968796
#[inline]
87978797
#[target_feature(enable = "neon")]
8798-
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
8798+
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
87998799
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
88008800
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
88018801
pub unsafe fn vfms_n_f32(a: float32x2_t, b: float32x2_t, c: f32) -> float32x2_t {
8802-
vfms_f32(a, b, vdup_n_f32(c))
8802+
vfms_f32(a, b, vdup_n_f32_vfp4(c))
88038803
}
88048804

88058805
/// Floating-point fused Multiply-subtract to accumulator(vector)
88068806
#[inline]
88078807
#[target_feature(enable = "neon")]
8808-
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
8808+
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
88098809
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
88108810
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
88118811
pub unsafe fn vfmsq_n_f32(a: float32x4_t, b: float32x4_t, c: f32) -> float32x4_t {
8812-
vfmsq_f32(a, b, vdupq_n_f32(c))
8812+
vfmsq_f32(a, b, vdupq_n_f32_vfp4(c))
88138813
}
88148814

88158815
/// Subtract

crates/core_arch/src/arm_shared/neon/mod.rs

+26
Original file line numberDiff line numberDiff line change
@@ -3786,6 +3786,19 @@ pub unsafe fn vdupq_n_f32(value: f32) -> float32x4_t {
37863786
float32x4_t(value, value, value, value)
37873787
}
37883788

3789+
/// Duplicate vector element to vector or scalar
3790+
///
3791+
/// Private vfp4 version used by FMA intriniscs because LLVM does
3792+
/// not inline the non-vfp4 version in vfp4 functions.
3793+
#[inline]
3794+
#[target_feature(enable = "neon")]
3795+
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
3796+
#[cfg_attr(all(test, target_arch = "arm"), assert_instr("vdup.32"))]
3797+
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(dup))]
3798+
unsafe fn vdupq_n_f32_vfp4(value: f32) -> float32x4_t {
3799+
float32x4_t(value, value, value, value)
3800+
}
3801+
37893802
/// Duplicate vector element to vector or scalar
37903803
#[inline]
37913804
#[target_feature(enable = "neon")]
@@ -3896,6 +3909,19 @@ pub unsafe fn vdup_n_f32(value: f32) -> float32x2_t {
38963909
float32x2_t(value, value)
38973910
}
38983911

3912+
/// Duplicate vector element to vector or scalar
3913+
///
3914+
/// Private vfp4 version used by FMA intriniscs because LLVM does
3915+
/// not inline the non-vfp4 version in vfp4 functions.
3916+
#[inline]
3917+
#[target_feature(enable = "neon")]
3918+
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
3919+
#[cfg_attr(all(test, target_arch = "arm"), assert_instr("vdup.32"))]
3920+
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(dup))]
3921+
unsafe fn vdup_n_f32_vfp4(value: f32) -> float32x2_t {
3922+
float32x2_t(value, value)
3923+
}
3924+
38993925
/// Duplicate vector element to vector or scalar
39003926
#[inline]
39013927
#[target_feature(enable = "neon")]

crates/stdarch-gen/neon.spec

+6-6
Original file line numberDiff line numberDiff line change
@@ -2733,15 +2733,15 @@ generate float64x1_t
27332733
aarch64 = fmla
27342734
generate float64x2_t
27352735

2736-
target = fp-armv8
2736+
target = vfp4
27372737
arm = vfma
27382738
link-arm = llvm.fma._EXT_
27392739
generate float*_t
27402740

27412741
/// Floating-point fused Multiply-Add to accumulator(vector)
27422742
name = vfma
27432743
n-suffix
2744-
multi_fn = vfma-self-noext, a, b, {vdup-nself-noext, c}
2744+
multi_fn = vfma-self-noext, a, b, {vdup-nselfvfp4-noext, c}
27452745
a = 2.0, 3.0, 4.0, 5.0
27462746
b = 6.0, 4.0, 7.0, 8.0
27472747
c = 8.0
@@ -2752,7 +2752,7 @@ generate float64x1_t:float64x1_t:f64:float64x1_t
27522752
aarch64 = fmla
27532753
generate float64x2_t:float64x2_t:f64:float64x2_t
27542754

2755-
target = fp-armv8
2755+
target = vfp4
27562756
arm = vfma
27572757
generate float32x2_t:float32x2_t:f32:float32x2_t, float32x4_t:float32x4_t:f32:float32x4_t
27582758

@@ -2811,14 +2811,14 @@ generate float64x1_t
28112811
aarch64 = fmls
28122812
generate float64x2_t
28132813

2814-
target = fp-armv8
2814+
target = vfp4
28152815
arm = vfms
28162816
generate float*_t
28172817

28182818
/// Floating-point fused Multiply-subtract to accumulator(vector)
28192819
name = vfms
28202820
n-suffix
2821-
multi_fn = vfms-self-noext, a, b, {vdup-nself-noext, c}
2821+
multi_fn = vfms-self-noext, a, b, {vdup-nselfvfp4-noext, c}
28222822
a = 50.0, 35.0, 60.0, 69.0
28232823
b = 6.0, 4.0, 7.0, 8.0
28242824
c = 8.0
@@ -2829,7 +2829,7 @@ generate float64x1_t:float64x1_t:f64:float64x1_t
28292829
aarch64 = fmls
28302830
generate float64x2_t:float64x2_t:f64:float64x2_t
28312831

2832-
target = fp-armv8
2832+
target = vfp4
28332833
arm = vfms
28342834
generate float32x2_t:float32x2_t:f32:float32x2_t, float32x4_t:float32x4_t:f32:float32x4_t
28352835

crates/stdarch-gen/src/main.rs

+16-1
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,7 @@ enum Suffix {
438438
enum TargetFeature {
439439
Default,
440440
ArmV7,
441+
Vfp4,
441442
FPArmV8,
442443
AES,
443444
}
@@ -980,6 +981,7 @@ fn gen_aarch64(
980981
let current_target = match target {
981982
Default => "neon",
982983
ArmV7 => "v7",
984+
Vfp4 => "vfp4",
983985
FPArmV8 => "fp-armv8,v8",
984986
AES => "neon,aes",
985987
};
@@ -1120,6 +1122,7 @@ fn gen_aarch64(
11201122
out_t,
11211123
fixed,
11221124
None,
1125+
true,
11231126
));
11241127
}
11251128
calls
@@ -1630,12 +1633,14 @@ fn gen_arm(
16301633
let current_target_aarch64 = match target {
16311634
Default => "neon",
16321635
ArmV7 => "neon",
1636+
Vfp4 => "neon",
16331637
FPArmV8 => "neon",
16341638
AES => "neon,aes",
16351639
};
16361640
let current_target_arm = match target {
16371641
Default => "v7",
16381642
ArmV7 => "v7",
1643+
Vfp4 => "vfp4",
16391644
FPArmV8 => "fp-armv8,v8",
16401645
AES => "aes,v8",
16411646
};
@@ -1916,6 +1921,7 @@ fn gen_arm(
19161921
out_t,
19171922
fixed,
19181923
None,
1924+
false,
19191925
));
19201926
}
19211927
calls
@@ -2283,6 +2289,7 @@ fn get_call(
22832289
out_t: &str,
22842290
fixed: &Vec<String>,
22852291
n: Option<i32>,
2292+
aarch64: bool,
22862293
) -> String {
22872294
let params: Vec<_> = in_str.split(',').map(|v| v.trim().to_string()).collect();
22882295
assert!(params.len() > 0);
@@ -2450,7 +2457,8 @@ fn get_call(
24502457
in_t,
24512458
out_t,
24522459
fixed,
2453-
Some(i as i32)
2460+
Some(i as i32),
2461+
aarch64
24542462
)
24552463
);
24562464
call.push_str(&sub_match);
@@ -2499,6 +2507,7 @@ fn get_call(
24992507
out_t,
25002508
fixed,
25012509
n.clone(),
2510+
aarch64,
25022511
);
25032512
if !param_str.is_empty() {
25042513
param_str.push_str(", ");
@@ -2569,6 +2578,11 @@ fn get_call(
25692578
fn_name.push_str(type_to_suffix(in_t[1]));
25702579
} else if fn_format[1] == "nself" {
25712580
fn_name.push_str(type_to_n_suffix(in_t[1]));
2581+
} else if fn_format[1] == "nselfvfp4" {
2582+
fn_name.push_str(type_to_n_suffix(in_t[1]));
2583+
if !aarch64 {
2584+
fn_name.push_str("_vfp4");
2585+
}
25722586
} else if fn_format[1] == "out" {
25732587
fn_name.push_str(type_to_suffix(out_t));
25742588
} else if fn_format[1] == "in0" {
@@ -2854,6 +2868,7 @@ mod test {
28542868
target = match Some(String::from(&line[9..])) {
28552869
Some(input) => match input.as_str() {
28562870
"v7" => ArmV7,
2871+
"vfp4" => Vfp4,
28572872
"fp-armv8" => FPArmV8,
28582873
"aes" => AES,
28592874
_ => Default,

0 commit comments

Comments
 (0)