Skip to content

Commit 78f1391

Browse files
authored
Add sqrt and more rsqrte neon instructions (#1078)
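This adds instructions for sqrt and some of the missing reciprocal square-root estimate instructions. It also updates the code generator to emit extern declarations with the correct number of parameters, which fixes the duplicated `a` parameter name in the previously generated two-parameter declarations.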
1 parent 1619c70 commit 78f1391

File tree

5 files changed

+203
-34
lines changed

5 files changed

+203
-34
lines changed

crates/core_arch/src/aarch64/neon/generated.rs

+112-6
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ pub unsafe fn vabd_f64(a: float64x1_t, b: float64x1_t) -> float64x1_t {
1717
#[allow(improper_ctypes)]
1818
extern "C" {
1919
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.fabd.v1f64")]
20-
fn vabd_f64_(a: float64x1_t, a: float64x1_t) -> float64x1_t;
20+
fn vabd_f64_(a: float64x1_t, b: float64x1_t) -> float64x1_t;
2121
}
2222
vabd_f64_(a, b)
2323
}
@@ -30,7 +30,7 @@ pub unsafe fn vabdq_f64(a: float64x2_t, b: float64x2_t) -> float64x2_t {
3030
#[allow(improper_ctypes)]
3131
extern "C" {
3232
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.fabd.v2f64")]
33-
fn vabdq_f64_(a: float64x2_t, a: float64x2_t) -> float64x2_t;
33+
fn vabdq_f64_(a: float64x2_t, b: float64x2_t) -> float64x2_t;
3434
}
3535
vabdq_f64_(a, b)
3636
}
@@ -1087,7 +1087,7 @@ pub unsafe fn vmax_f64(a: float64x1_t, b: float64x1_t) -> float64x1_t {
10871087
#[allow(improper_ctypes)]
10881088
extern "C" {
10891089
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.fmax.v1f64")]
1090-
fn vmax_f64_(a: float64x1_t, a: float64x1_t) -> float64x1_t;
1090+
fn vmax_f64_(a: float64x1_t, b: float64x1_t) -> float64x1_t;
10911091
}
10921092
vmax_f64_(a, b)
10931093
}
@@ -1100,7 +1100,7 @@ pub unsafe fn vmaxq_f64(a: float64x2_t, b: float64x2_t) -> float64x2_t {
11001100
#[allow(improper_ctypes)]
11011101
extern "C" {
11021102
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.fmax.v2f64")]
1103-
fn vmaxq_f64_(a: float64x2_t, a: float64x2_t) -> float64x2_t;
1103+
fn vmaxq_f64_(a: float64x2_t, b: float64x2_t) -> float64x2_t;
11041104
}
11051105
vmaxq_f64_(a, b)
11061106
}
@@ -1113,7 +1113,7 @@ pub unsafe fn vmin_f64(a: float64x1_t, b: float64x1_t) -> float64x1_t {
11131113
#[allow(improper_ctypes)]
11141114
extern "C" {
11151115
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.fmin.v1f64")]
1116-
fn vmin_f64_(a: float64x1_t, a: float64x1_t) -> float64x1_t;
1116+
fn vmin_f64_(a: float64x1_t, b: float64x1_t) -> float64x1_t;
11171117
}
11181118
vmin_f64_(a, b)
11191119
}
@@ -1126,11 +1126,69 @@ pub unsafe fn vminq_f64(a: float64x2_t, b: float64x2_t) -> float64x2_t {
11261126
#[allow(improper_ctypes)]
11271127
extern "C" {
11281128
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.fmin.v2f64")]
1129-
fn vminq_f64_(a: float64x2_t, a: float64x2_t) -> float64x2_t;
1129+
fn vminq_f64_(a: float64x2_t, b: float64x2_t) -> float64x2_t;
11301130
}
11311131
vminq_f64_(a, b)
11321132
}
11331133

1134+
/// Calculates the square root of each lane.
1135+
#[inline]
1136+
#[target_feature(enable = "neon")]
1137+
#[cfg_attr(test, assert_instr(fsqrt))]
1138+
pub unsafe fn vsqrt_f32(a: float32x2_t) -> float32x2_t {
1139+
simd_fsqrt(a)
1140+
}
1141+
1142+
/// Calculates the square root of each lane.
1143+
#[inline]
1144+
#[target_feature(enable = "neon")]
1145+
#[cfg_attr(test, assert_instr(fsqrt))]
1146+
pub unsafe fn vsqrtq_f32(a: float32x4_t) -> float32x4_t {
1147+
simd_fsqrt(a)
1148+
}
1149+
1150+
/// Calculates the square root of each lane.
1151+
#[inline]
1152+
#[target_feature(enable = "neon")]
1153+
#[cfg_attr(test, assert_instr(fsqrt))]
1154+
pub unsafe fn vsqrt_f64(a: float64x1_t) -> float64x1_t {
1155+
simd_fsqrt(a)
1156+
}
1157+
1158+
/// Calculates the square root of each lane.
1159+
#[inline]
1160+
#[target_feature(enable = "neon")]
1161+
#[cfg_attr(test, assert_instr(fsqrt))]
1162+
pub unsafe fn vsqrtq_f64(a: float64x2_t) -> float64x2_t {
1163+
simd_fsqrt(a)
1164+
}
1165+
1166+
/// Reciprocal square-root estimate.
1167+
#[inline]
1168+
#[target_feature(enable = "neon")]
1169+
#[cfg_attr(test, assert_instr(frsqrte))]
1170+
pub unsafe fn vrsqrte_f64(a: float64x1_t) -> float64x1_t {
1171+
#[allow(improper_ctypes)]
1172+
extern "C" {
1173+
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.frsqrte.v1f64")]
1174+
fn vrsqrte_f64_(a: float64x1_t) -> float64x1_t;
1175+
}
1176+
vrsqrte_f64_(a)
1177+
}
1178+
1179+
/// Reciprocal square-root estimate.
1180+
#[inline]
1181+
#[target_feature(enable = "neon")]
1182+
#[cfg_attr(test, assert_instr(frsqrte))]
1183+
pub unsafe fn vrsqrteq_f64(a: float64x2_t) -> float64x2_t {
1184+
#[allow(improper_ctypes)]
1185+
extern "C" {
1186+
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.frsqrte.v2f64")]
1187+
fn vrsqrteq_f64_(a: float64x2_t) -> float64x2_t;
1188+
}
1189+
vrsqrteq_f64_(a)
1190+
}
1191+
11341192
#[cfg(test)]
11351193
mod test {
11361194
use super::*;
@@ -2233,4 +2291,52 @@ mod test {
22332291
let r: f64x2 = transmute(vminq_f64(transmute(a), transmute(b)));
22342292
assert_eq!(r, e);
22352293
}
2294+
2295+
#[simd_test(enable = "neon")]
2296+
unsafe fn test_vsqrt_f32() {
2297+
let a: f32x2 = f32x2::new(4.0, 9.0);
2298+
let e: f32x2 = f32x2::new(2.0, 3.0);
2299+
let r: f32x2 = transmute(vsqrt_f32(transmute(a)));
2300+
assert_eq!(r, e);
2301+
}
2302+
2303+
#[simd_test(enable = "neon")]
2304+
unsafe fn test_vsqrtq_f32() {
2305+
let a: f32x4 = f32x4::new(4.0, 9.0, 16.0, 25.0);
2306+
let e: f32x4 = f32x4::new(2.0, 3.0, 4.0, 5.0);
2307+
let r: f32x4 = transmute(vsqrtq_f32(transmute(a)));
2308+
assert_eq!(r, e);
2309+
}
2310+
2311+
#[simd_test(enable = "neon")]
2312+
unsafe fn test_vsqrt_f64() {
2313+
let a: f64 = 4.0;
2314+
let e: f64 = 2.0;
2315+
let r: f64 = transmute(vsqrt_f64(transmute(a)));
2316+
assert_eq!(r, e);
2317+
}
2318+
2319+
#[simd_test(enable = "neon")]
2320+
unsafe fn test_vsqrtq_f64() {
2321+
let a: f64x2 = f64x2::new(4.0, 9.0);
2322+
let e: f64x2 = f64x2::new(2.0, 3.0);
2323+
let r: f64x2 = transmute(vsqrtq_f64(transmute(a)));
2324+
assert_eq!(r, e);
2325+
}
2326+
2327+
#[simd_test(enable = "neon")]
2328+
unsafe fn test_vrsqrte_f64() {
2329+
let a: f64 = 1.0;
2330+
let e: f64 = 0.998046875;
2331+
let r: f64 = transmute(vrsqrte_f64(transmute(a)));
2332+
assert_eq!(r, e);
2333+
}
2334+
2335+
#[simd_test(enable = "neon")]
2336+
unsafe fn test_vrsqrteq_f64() {
2337+
let a: f64x2 = f64x2::new(1.0, 2.0);
2338+
let e: f64x2 = f64x2::new(0.998046875, 0.705078125);
2339+
let r: f64x2 = transmute(vrsqrteq_f64(transmute(a)));
2340+
assert_eq!(r, e);
2341+
}
22362342
}

crates/core_arch/src/arm/neon/generated.rs

+48
Original file line numberDiff line numberDiff line change
@@ -3349,6 +3349,38 @@ pub unsafe fn vminq_f32(a: float32x4_t, b: float32x4_t) -> float32x4_t {
33493349
vminq_f32_(a, b)
33503350
}
33513351

3352+
/// Reciprocal square-root estimate.
3353+
#[inline]
3354+
#[target_feature(enable = "neon")]
3355+
#[cfg_attr(target_arch = "arm", target_feature(enable = "v7"))]
3356+
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vrsqrte))]
3357+
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(frsqrte))]
3358+
pub unsafe fn vrsqrte_f32(a: float32x2_t) -> float32x2_t {
3359+
#[allow(improper_ctypes)]
3360+
extern "C" {
3361+
#[cfg_attr(target_arch = "arm", link_name = "llvm.arm.neon.vrsqrte.v2f32")]
3362+
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.frsqrte.v2f32")]
3363+
fn vrsqrte_f32_(a: float32x2_t) -> float32x2_t;
3364+
}
3365+
vrsqrte_f32_(a)
3366+
}
3367+
3368+
/// Reciprocal square-root estimate.
3369+
#[inline]
3370+
#[target_feature(enable = "neon")]
3371+
#[cfg_attr(target_arch = "arm", target_feature(enable = "v7"))]
3372+
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vrsqrte))]
3373+
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(frsqrte))]
3374+
pub unsafe fn vrsqrteq_f32(a: float32x4_t) -> float32x4_t {
3375+
#[allow(improper_ctypes)]
3376+
extern "C" {
3377+
#[cfg_attr(target_arch = "arm", link_name = "llvm.arm.neon.vrsqrte.v4f32")]
3378+
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.frsqrte.v4f32")]
3379+
fn vrsqrteq_f32_(a: float32x4_t) -> float32x4_t;
3380+
}
3381+
vrsqrteq_f32_(a)
3382+
}
3383+
33523384
#[cfg(test)]
33533385
#[allow(overflowing_literals)]
33543386
mod test {
@@ -5964,4 +5996,20 @@ mod test {
59645996
let r: f32x4 = transmute(vminq_f32(transmute(a), transmute(b)));
59655997
assert_eq!(r, e);
59665998
}
5999+
6000+
#[simd_test(enable = "neon")]
6001+
unsafe fn test_vrsqrte_f32() {
6002+
let a: f32x2 = f32x2::new(1.0, 2.0);
6003+
let e: f32x2 = f32x2::new(0.998046875, 0.705078125);
6004+
let r: f32x2 = transmute(vrsqrte_f32(transmute(a)));
6005+
assert_eq!(r, e);
6006+
}
6007+
6008+
#[simd_test(enable = "neon")]
6009+
unsafe fn test_vrsqrteq_f32() {
6010+
let a: f32x4 = f32x4::new(1.0, 2.0, 3.0, 4.0);
6011+
let e: f32x4 = f32x4::new(0.998046875, 0.705078125, 0.576171875, 0.4990234375);
6012+
let r: f32x4 = transmute(vrsqrteq_f32(transmute(a)));
6013+
assert_eq!(r, e);
6014+
}
59676015
}

crates/core_arch/src/arm/neon/mod.rs

-21
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,6 @@ extern "C" {
136136
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.abs.v4i32")]
137137
fn vabsq_s32_(a: int32x4_t) -> int32x4_t;
138138

139-
#[cfg_attr(target_arch = "arm", link_name = "llvm.arm.neon.vrsqrte.v2f32")]
140-
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.frsqrte.v2f32")]
141-
fn frsqrte_v2f32(a: float32x2_t) -> float32x2_t;
142-
143139
//uint32x2_t vqmovn_u64 (uint64x2_t a)
144140
#[cfg_attr(target_arch = "arm", link_name = "llvm.arm.neon.vqmovnu.v2i32")]
145141
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.uqxtn.v2i32")]
@@ -2466,15 +2462,6 @@ pub unsafe fn vmovl_u32(a: uint32x2_t) -> uint64x2_t {
24662462
simd_cast(a)
24672463
}
24682464

2469-
/// Reciprocal square-root estimate.
2470-
#[inline]
2471-
#[target_feature(enable = "neon")]
2472-
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(frsqrte))]
2473-
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vrsqrte))]
2474-
pub unsafe fn vrsqrte_f32(a: float32x2_t) -> float32x2_t {
2475-
frsqrte_v2f32(a)
2476-
}
2477-
24782465
/// Vector bitwise not.
24792466
#[inline]
24802467
#[target_feature(enable = "neon")]
@@ -7906,14 +7893,6 @@ mod tests {
79067893
assert_eq!(r, e);
79077894
}
79087895

7909-
#[simd_test(enable = "neon")]
7910-
unsafe fn test_vrsqrt_f32() {
7911-
let a = f32x2::new(1.0, 2.0);
7912-
let e = f32x2::new(0.9980469, 0.7050781);
7913-
let r: f32x2 = transmute(vrsqrte_f32(transmute(a)));
7914-
assert_eq!(r, e);
7915-
}
7916-
79177896
#[simd_test(enable = "neon")]
79187897
unsafe fn test_vpmin_s8() {
79197898
let a = i8x8::new(1, -2, 3, -4, 5, 6, 7, 8);

crates/stdarch-gen/neon.spec

+22
Original file line numberDiff line numberDiff line change
@@ -720,3 +720,25 @@ aarch64 = fmin
720720
link-arm = vmins._EXT_
721721
link-aarch64 = fmin._EXT_
722722
generate float*_t
723+
724+
/// Calculates the square root of each lane.
725+
name = vsqrt
726+
fn = simd_fsqrt
727+
a = 4.0, 9.0, 16.0, 25.0
728+
validate 2.0, 3.0, 4.0, 5.0
729+
730+
aarch64 = fsqrt
731+
generate float*_t, float64x*_t
732+
733+
/// Reciprocal square-root estimate.
734+
name = vrsqrte
735+
a = 1.0, 2.0, 3.0, 4.0
736+
validate 0.998046875, 0.705078125, 0.576171875, 0.4990234375
737+
738+
aarch64 = frsqrte
739+
link-aarch64 = frsqrte._EXT_
740+
generate float64x*_t
741+
742+
arm = vrsqrte
743+
link-arm = vrsqrte._EXT_
744+
generate float*_t

crates/stdarch-gen/src/main.rs

+21-7
Original file line numberDiff line numberDiff line change
@@ -345,13 +345,20 @@ fn gen_aarch64(
345345
r#"#[allow(improper_ctypes)]
346346
extern "C" {{
347347
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.{}")]
348-
fn {}(a: {}, a: {}) -> {};
348+
fn {}({}) -> {};
349349
}}
350350
"#,
351351
link_aarch64.replace("_EXT_", ext),
352352
current_fn,
353-
in_t,
354-
in_t,
353+
match para_num {
354+
1 => {
355+
format!("a: {}", in_t)
356+
}
357+
2 => {
358+
format!("a: {}, b: {}", in_t, in_t)
359+
}
360+
_ => unimplemented!("unknown para_num"),
361+
},
355362
out_t
356363
)
357364
} else {
@@ -527,7 +534,7 @@ fn gen_arm(
527534
}
528535
String::new()
529536
} else {
530-
if link_aarch64.is_none() || link_arm.is_none() {
537+
if link_aarch64.is_none() && link_arm.is_none() {
531538
panic!(
532539
"[{}] Either fn or link-arm and link-aarch have to be specified.",
533540
name
@@ -544,14 +551,21 @@ fn gen_arm(
544551
extern "C" {{
545552
#[cfg_attr(target_arch = "arm", link_name = "llvm.arm.neon.{}")]
546553
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.{}")]
547-
fn {}(a: {}, b: {}) -> {};
554+
fn {}({}) -> {};
548555
}}
549556
"#,
550557
link_arm.replace("_EXT_", ext),
551558
link_aarch64.replace("_EXT_", ext),
552559
current_fn,
553-
in_t,
554-
in_t,
560+
match para_num {
561+
1 => {
562+
format!("a: {}", in_t)
563+
}
564+
2 => {
565+
format!("a: {}, b: {}", in_t, in_t)
566+
}
567+
_ => unimplemented!("unknown para_num"),
568+
},
555569
out_t
556570
)
557571
} else {

0 commit comments

Comments
 (0)