
Commit 1f09bae

Implement min/max neon intrinsics
1 parent ef37036 commit 1f09bae

5 files changed (+259, -19 lines)


build_system/tests.rs (+1)

@@ -99,6 +99,7 @@ const BASE_SYSROOT_SUITE: &[TestCase] = &[
     TestCase::build_bin_and_run("aot.mod_bench", "example/mod_bench.rs", &[]),
     TestCase::build_bin_and_run("aot.issue-72793", "example/issue-72793.rs", &[]),
     TestCase::build_bin("aot.issue-59326", "example/issue-59326.rs"),
+    TestCase::build_bin_and_run("aot.neon", "example/neon.rs", &[]),
 ];

 pub(crate) static RAND_REPO: GitRepo = GitRepo::github(

config.txt (+1)

@@ -42,6 +42,7 @@ aot.float-minmax-pass
 aot.mod_bench
 aot.issue-72793
 aot.issue-59326
+aot.neon

 testsuite.extended_sysroot
 test.rust-random/rand

example/neon.rs (new file, +155 lines)

// Most of these tests are copied from https://github.com/japaric/stdsimd/blob/0f4413d01c4f0c3ffbc5a69e9a37fbc7235b31a9/coresimd/arm/neon.rs

#![feature(portable_simd)]
use std::arch::aarch64::*;
use std::mem::transmute;
use std::simd::*;

#[cfg(target_arch = "aarch64")]
unsafe fn test_vpmin_s8() {
    let a = i8x8::from([1, -2, 3, -4, 5, 6, 7, 8]);
    let b = i8x8::from([0, 3, 2, 5, 4, 7, 6, 9]);
    let e = i8x8::from([-2, -4, 5, 7, 0, 2, 4, 6]);
    let r: i8x8 = transmute(vpmin_s8(transmute(a), transmute(b)));
    assert_eq!(r, e);
}

#[cfg(target_arch = "aarch64")]
unsafe fn test_vpmin_s16() {
    let a = i16x4::from([1, 2, 3, -4]);
    let b = i16x4::from([0, 3, 2, 5]);
    let e = i16x4::from([1, -4, 0, 2]);
    let r: i16x4 = transmute(vpmin_s16(transmute(a), transmute(b)));
    assert_eq!(r, e);
}

#[cfg(target_arch = "aarch64")]
unsafe fn test_vpmin_s32() {
    let a = i32x2::from([1, -2]);
    let b = i32x2::from([0, 3]);
    let e = i32x2::from([-2, 0]);
    let r: i32x2 = transmute(vpmin_s32(transmute(a), transmute(b)));
    assert_eq!(r, e);
}

#[cfg(target_arch = "aarch64")]
unsafe fn test_vpmin_u8() {
    let a = u8x8::from([1, 2, 3, 4, 5, 6, 7, 8]);
    let b = u8x8::from([0, 3, 2, 5, 4, 7, 6, 9]);
    let e = u8x8::from([1, 3, 5, 7, 0, 2, 4, 6]);
    let r: u8x8 = transmute(vpmin_u8(transmute(a), transmute(b)));
    assert_eq!(r, e);
}

#[cfg(target_arch = "aarch64")]
unsafe fn test_vpmin_u16() {
    let a = u16x4::from([1, 2, 3, 4]);
    let b = u16x4::from([0, 3, 2, 5]);
    let e = u16x4::from([1, 3, 0, 2]);
    let r: u16x4 = transmute(vpmin_u16(transmute(a), transmute(b)));
    assert_eq!(r, e);
}

#[cfg(target_arch = "aarch64")]
unsafe fn test_vpmin_u32() {
    let a = u32x2::from([1, 2]);
    let b = u32x2::from([0, 3]);
    let e = u32x2::from([1, 0]);
    let r: u32x2 = transmute(vpmin_u32(transmute(a), transmute(b)));
    assert_eq!(r, e);
}

#[cfg(target_arch = "aarch64")]
unsafe fn test_vpmin_f32() {
    let a = f32x2::from([1., -2.]);
    let b = f32x2::from([0., 3.]);
    let e = f32x2::from([-2., 0.]);
    let r: f32x2 = transmute(vpmin_f32(transmute(a), transmute(b)));
    assert_eq!(r, e);
}

#[cfg(target_arch = "aarch64")]
unsafe fn test_vpmax_s8() {
    let a = i8x8::from([1, -2, 3, -4, 5, 6, 7, 8]);
    let b = i8x8::from([0, 3, 2, 5, 4, 7, 6, 9]);
    let e = i8x8::from([1, 3, 6, 8, 3, 5, 7, 9]);
    let r: i8x8 = transmute(vpmax_s8(transmute(a), transmute(b)));
    assert_eq!(r, e);
}

#[cfg(target_arch = "aarch64")]
unsafe fn test_vpmax_s16() {
    let a = i16x4::from([1, 2, 3, -4]);
    let b = i16x4::from([0, 3, 2, 5]);
    let e = i16x4::from([2, 3, 3, 5]);
    let r: i16x4 = transmute(vpmax_s16(transmute(a), transmute(b)));
    assert_eq!(r, e);
}

#[cfg(target_arch = "aarch64")]
unsafe fn test_vpmax_s32() {
    let a = i32x2::from([1, -2]);
    let b = i32x2::from([0, 3]);
    let e = i32x2::from([1, 3]);
    let r: i32x2 = transmute(vpmax_s32(transmute(a), transmute(b)));
    assert_eq!(r, e);
}

#[cfg(target_arch = "aarch64")]
unsafe fn test_vpmax_u8() {
    let a = u8x8::from([1, 2, 3, 4, 5, 6, 7, 8]);
    let b = u8x8::from([0, 3, 2, 5, 4, 7, 6, 9]);
    let e = u8x8::from([2, 4, 6, 8, 3, 5, 7, 9]);
    let r: u8x8 = transmute(vpmax_u8(transmute(a), transmute(b)));
    assert_eq!(r, e);
}

#[cfg(target_arch = "aarch64")]
unsafe fn test_vpmax_u16() {
    let a = u16x4::from([1, 2, 3, 4]);
    let b = u16x4::from([0, 3, 2, 5]);
    let e = u16x4::from([2, 4, 3, 5]);
    let r: u16x4 = transmute(vpmax_u16(transmute(a), transmute(b)));
    assert_eq!(r, e);
}

#[cfg(target_arch = "aarch64")]
unsafe fn test_vpmax_u32() {
    let a = u32x2::from([1, 2]);
    let b = u32x2::from([0, 3]);
    let e = u32x2::from([2, 3]);
    let r: u32x2 = transmute(vpmax_u32(transmute(a), transmute(b)));
    assert_eq!(r, e);
}

#[cfg(target_arch = "aarch64")]
unsafe fn test_vpmax_f32() {
    let a = f32x2::from([1., -2.]);
    let b = f32x2::from([0., 3.]);
    let e = f32x2::from([1., 3.]);
    let r: f32x2 = transmute(vpmax_f32(transmute(a), transmute(b)));
    assert_eq!(r, e);
}

#[cfg(target_arch = "aarch64")]
fn main() {
    unsafe {
        test_vpmin_s8();
        test_vpmin_s16();
        test_vpmin_s32();
        test_vpmin_u8();
        test_vpmin_u16();
        test_vpmin_u32();
        test_vpmin_f32();
        test_vpmax_s8();
        test_vpmax_s16();
        test_vpmax_s32();
        test_vpmax_u8();
        test_vpmax_u16();
        test_vpmax_u32();
        test_vpmax_f32();
    }
}

#[cfg(target_arch = "x86_64")]
fn main() {}
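
For reference, the vpmin/vpmax intrinsics exercised above are pairwise reductions: the low half of the result holds the min/max of adjacent element pairs from the first operand, and the high half holds those from the second operand. Below is a stand-alone scalar sketch of that lane mapping on plain arrays; the helper name is made up for illustration and is not part of this commit.

// Scalar model of NEON pairwise minimum (vpmin): the first half of the
// result comes from adjacent pairs of `a`, the second half from `b`.
fn pairwise_min_i8x8(a: [i8; 8], b: [i8; 8]) -> [i8; 8] {
    let mut out = [0i8; 8];
    for i in 0..4 {
        out[i] = a[2 * i].min(a[2 * i + 1]);
        out[4 + i] = b[2 * i].min(b[2 * i + 1]);
    }
    out
}

fn main() {
    let a = [1, -2, 3, -4, 5, 6, 7, 8];
    let b = [0, 3, 2, 5, 4, 7, 6, 9];
    // Same expected value as test_vpmin_s8 above.
    assert_eq!(pairwise_min_i8x8(a, b), [-2, -4, 5, 7, 0, 2, 4, 6]);
}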

src/intrinsics/llvm_aarch64.rs (+72, -19)

@@ -156,6 +156,78 @@ pub(crate) fn codegen_aarch64_llvm_intrinsic_call<'tcx>(
             });
         }

+        _ if intrinsic.starts_with("llvm.aarch64.neon.umaxp.v") => {
+            intrinsic_args!(fx, args => (x, y); intrinsic);
+
+            simd_horizontal_pair_for_each_lane(
+                fx,
+                x,
+                y,
+                ret,
+                &|fx, _lane_ty, _res_lane_ty, x_lane, y_lane| fx.bcx.ins().umax(x_lane, y_lane),
+            );
+        }
+
+        _ if intrinsic.starts_with("llvm.aarch64.neon.smaxp.v") => {
+            intrinsic_args!(fx, args => (x, y); intrinsic);
+
+            simd_horizontal_pair_for_each_lane(
+                fx,
+                x,
+                y,
+                ret,
+                &|fx, _lane_ty, _res_lane_ty, x_lane, y_lane| fx.bcx.ins().smax(x_lane, y_lane),
+            );
+        }
+
+        _ if intrinsic.starts_with("llvm.aarch64.neon.uminp.v") => {
+            intrinsic_args!(fx, args => (x, y); intrinsic);
+
+            simd_horizontal_pair_for_each_lane(
+                fx,
+                x,
+                y,
+                ret,
+                &|fx, _lane_ty, _res_lane_ty, x_lane, y_lane| fx.bcx.ins().umin(x_lane, y_lane),
+            );
+        }
+
+        _ if intrinsic.starts_with("llvm.aarch64.neon.sminp.v") => {
+            intrinsic_args!(fx, args => (x, y); intrinsic);
+
+            simd_horizontal_pair_for_each_lane(
+                fx,
+                x,
+                y,
+                ret,
+                &|fx, _lane_ty, _res_lane_ty, x_lane, y_lane| fx.bcx.ins().smin(x_lane, y_lane),
+            );
+        }
+
+        _ if intrinsic.starts_with("llvm.aarch64.neon.fminp.v") => {
+            intrinsic_args!(fx, args => (x, y); intrinsic);
+
+            simd_horizontal_pair_for_each_lane(
+                fx,
+                x,
+                y,
+                ret,
+                &|fx, _lane_ty, _res_lane_ty, x_lane, y_lane| fx.bcx.ins().fmin(x_lane, y_lane),
+            );
+        }
+
+        _ if intrinsic.starts_with("llvm.aarch64.neon.fmaxp.v") => {
+            intrinsic_args!(fx, args => (x, y); intrinsic);
+
+            simd_horizontal_pair_for_each_lane(
+                fx,
+                x,
+                y,
+                ret,
+                &|fx, _lane_ty, _res_lane_ty, x_lane, y_lane| fx.bcx.ins().fmax(x_lane, y_lane),
+            );
+        }
+
         // FIXME generalize vector types
         "llvm.aarch64.neon.tbl1.v16i8" => {
             intrinsic_args!(fx, args => (t, idx); intrinsic);

@@ -172,25 +244,6 @@ pub(crate) fn codegen_aarch64_llvm_intrinsic_call<'tcx>(
             }
         }

-        // FIXME generalize vector types
-        "llvm.aarch64.neon.umaxp.v16i8" => {
-            intrinsic_args!(fx, args => (a, b); intrinsic);
-
-            // FIXME add helper for horizontal pairwise operations
-            for i in 0..8 {
-                let lane1 = a.value_lane(fx, i * 2).load_scalar(fx);
-                let lane2 = a.value_lane(fx, i * 2 + 1).load_scalar(fx);
-                let res = fx.bcx.ins().umax(lane1, lane2);
-                ret.place_lane(fx, i).to_ptr().store(fx, res, MemFlags::trusted());
-            }
-            for i in 0..8 {
-                let lane1 = b.value_lane(fx, i * 2).load_scalar(fx);
-                let lane2 = b.value_lane(fx, i * 2 + 1).load_scalar(fx);
-                let res = fx.bcx.ins().umax(lane1, lane2);
-                ret.place_lane(fx, 8 + i).to_ptr().store(fx, res, MemFlags::trusted());
-            }
-        }
-
         /*
         _ if intrinsic.starts_with("llvm.aarch64.neon.sshl.v")
             || intrinsic.starts_with("llvm.aarch64.neon.sqshl.v")
src/intrinsics/mod.rs (+30)

@@ -132,6 +132,36 @@ fn simd_pair_for_each_lane<'tcx>(
     }
 }

+fn simd_horizontal_pair_for_each_lane<'tcx>(
+    fx: &mut FunctionCx<'_, '_, 'tcx>,
+    x: CValue<'tcx>,
+    y: CValue<'tcx>,
+    ret: CPlace<'tcx>,
+    f: &dyn Fn(&mut FunctionCx<'_, '_, 'tcx>, Ty<'tcx>, Ty<'tcx>, Value, Value) -> Value,
+) {
+    assert_eq!(x.layout(), y.layout());
+    let layout = x.layout();
+
+    let (lane_count, lane_ty) = layout.ty.simd_size_and_type(fx.tcx);
+    let lane_layout = fx.layout_of(lane_ty);
+    let (ret_lane_count, ret_lane_ty) = ret.layout().ty.simd_size_and_type(fx.tcx);
+    let ret_lane_layout = fx.layout_of(ret_lane_ty);
+    assert_eq!(lane_count, ret_lane_count);
+
+    for lane_idx in 0..lane_count {
+        let src = if lane_idx < (lane_count / 2) { x } else { y };
+        let src_idx = lane_idx % (lane_count / 2);
+
+        let lhs_lane = src.value_lane(fx, src_idx * 2).load_scalar(fx);
+        let rhs_lane = src.value_lane(fx, src_idx * 2 + 1).load_scalar(fx);
+
+        let res_lane = f(fx, lane_layout.ty, ret_lane_layout.ty, lhs_lane, rhs_lane);
+        let res_lane = CValue::by_val(res_lane, ret_lane_layout);
+
+        ret.place_lane(fx, lane_idx).write_cvalue(fx, res_lane);
+    }
+}
+
 fn simd_trio_for_each_lane<'tcx>(
     fx: &mut FunctionCx<'_, '_, 'tcx>,
     x: CValue<'tcx>,
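
The new helper writes result lane i from x while i < lane_count / 2 and from y afterwards, reading the adjacent pair (2*j, 2*j + 1) with j = i % (lane_count / 2). The following stand-alone model reproduces that index arithmetic on plain slices; it is illustrative only and assumes equal lane counts for inputs and output, as the assert in the helper enforces.

// Model of simd_horizontal_pair_for_each_lane's lane mapping on plain slices:
// result lane i reads the pair (2*j, 2*j + 1) from x for the first half of
// the lanes and from y for the second half, where j = i % (n / 2).
fn horizontal_pair<T: Copy>(x: &[T], y: &[T], f: impl Fn(T, T) -> T) -> Vec<T> {
    assert_eq!(x.len(), y.len());
    let n = x.len();
    (0..n)
        .map(|lane_idx| {
            let src = if lane_idx < n / 2 { x } else { y };
            let src_idx = lane_idx % (n / 2);
            f(src[2 * src_idx], src[2 * src_idx + 1])
        })
        .collect()
}

fn main() {
    // Mirrors test_vpmax_u16 above: pairwise max of [1, 2, 3, 4] and [0, 3, 2, 5].
    let r = horizontal_pair(&[1u16, 2, 3, 4], &[0, 3, 2, 5], u16::max);
    assert_eq!(r, vec![2u16, 4, 3, 5]);
}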
