Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

riscv64: Implement various SIMD float ops #6657

Merged
merged 4 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,7 @@ fn ignore(testsuite: &str, testname: &str, strategy: &str) -> bool {
"cvt_from_uint",
"issue_3327_bnot_lowering",
"simd_conversions",
"simd_f32x4",
"simd_f32x4_pmin_pmax",
"simd_f32x4_rounding",
"simd_f64x2",
"simd_f64x2_pmin_pmax",
"simd_f64x2_rounding",
"simd_i32x4_trunc_sat_f32x4",
"simd_i32x4_trunc_sat_f64x2",
Expand Down
5 changes: 5 additions & 0 deletions cranelift/codegen/src/isa/riscv64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -1610,6 +1610,11 @@

;; Float Helpers

;; Returns the bitpattern of the Canonical NaN for the given type.
(decl pure canonical_nan_u64 (Type) u64)
(rule (canonical_nan_u64 $F32) 0x7fc00000)
(rule (canonical_nan_u64 $F64) 0x7ff8000000000000)

(decl gen_default_frm () OptionFloatRoundingMode)
(extern constructor gen_default_frm gen_default_frm)

Expand Down
3 changes: 3 additions & 0 deletions cranelift/codegen/src/isa/riscv64/inst/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1744,6 +1744,9 @@ impl Inst {
(VecAluOpRRR::VfsgnjnVV, vs2, vs1) if vs2 == vs1 => {
format!("vfneg.v {vd_s},{vs2_s}{mask} {vstate}")
}
(VecAluOpRRR::VfsgnjxVV, vs2, vs1) if vs2 == vs1 => {
format!("vfabs.v {vd_s},{vs2_s}{mask} {vstate}")
}
(VecAluOpRRR::VmnandMM, vs2, vs1) if vs2 == vs1 => {
format!("vmnot.m {vd_s},{vs2_s}{mask} {vstate}")
}
Expand Down
9 changes: 9 additions & 0 deletions cranelift/codegen/src/isa/riscv64/inst/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,13 @@ impl VecAluOpRRR {
| VecAluOpRRR::VsadduVV
| VecAluOpRRR::VsadduVX => 0b100000,
VecAluOpRRR::VfrdivVF | VecAluOpRRR::VsaddVV | VecAluOpRRR::VsaddVX => 0b100001,
VecAluOpRRR::VfminVV => 0b000100,
VecAluOpRRR::VfmaxVV => 0b000110,
VecAluOpRRR::VssubuVV | VecAluOpRRR::VssubuVX => 0b100010,
VecAluOpRRR::VssubVV | VecAluOpRRR::VssubVX => 0b100011,
VecAluOpRRR::VfsgnjVV | VecAluOpRRR::VfsgnjVF => 0b001000,
VecAluOpRRR::VfsgnjnVV => 0b001001,
VecAluOpRRR::VfsgnjxVV => 0b001010,
VecAluOpRRR::VrgatherVV | VecAluOpRRR::VrgatherVX => 0b001100,
VecAluOpRRR::VwadduVV | VecAluOpRRR::VwadduVX => 0b110000,
VecAluOpRRR::VwaddVV | VecAluOpRRR::VwaddVX => 0b110001,
Expand Down Expand Up @@ -473,7 +477,11 @@ impl VecAluOpRRR {
| VecAluOpRRR::VfsubVV
| VecAluOpRRR::VfmulVV
| VecAluOpRRR::VfdivVV
| VecAluOpRRR::VfmaxVV
| VecAluOpRRR::VfminVV
| VecAluOpRRR::VfsgnjVV
| VecAluOpRRR::VfsgnjnVV
| VecAluOpRRR::VfsgnjxVV
| VecAluOpRRR::VmfeqVV
| VecAluOpRRR::VmfneVV
| VecAluOpRRR::VmfltVV
Expand All @@ -485,6 +493,7 @@ impl VecAluOpRRR {
| VecAluOpRRR::VfdivVF
| VecAluOpRRR::VfrdivVF
| VecAluOpRRR::VfmergeVFM
| VecAluOpRRR::VfsgnjVF
| VecAluOpRRR::VmfeqVF
| VecAluOpRRR::VmfneVF
| VecAluOpRRR::VmfltVF
Expand Down
38 changes: 38 additions & 0 deletions cranelift/codegen/src/isa/riscv64/inst_vector.isle
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@
(VfsubVV)
(VfmulVV)
(VfdivVV)
(VfminVV)
(VfmaxVV)
(VfsgnjVV)
(VfsgnjnVV)
(VfsgnjxVV)
(VmergeVVM)
(VredmaxuVS)
(VredminuVS)
Expand Down Expand Up @@ -180,6 +184,7 @@
(VfrsubVF)
(VfmulVF)
(VfdivVF)
(VfsgnjVF)
(VfrdivVF)
(VmergeVXM)
(VfmergeVFM)
Expand Down Expand Up @@ -836,6 +841,27 @@
(rule (rv_vfrdiv_vf vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VfrdivVF) vs2 vs1 mask vstate))

;; Helper for emitting the `vfmin.vv` instruction.
(decl rv_vfmin_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vfmin_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VfminVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vfmax.vv` instruction.
(decl rv_vfmax_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vfmax_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VfmaxVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vfsgnj.vv` ("Floating Point Sign Injection") instruction.
;; The output of this instruction is `vs2` with the sign bit from `vs1`
(decl rv_vfsgnj_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vfsgnj_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VfsgnjVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vfsgnj.vf` ("Floating Point Sign Injection") instruction.
(decl rv_vfsgnj_vf (VReg FReg VecOpMasking VState) VReg)
(rule (rv_vfsgnj_vf vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VfsgnjVF) vs2 vs1 mask vstate))

;; Helper for emitting the `vfsgnjn.vv` ("Floating Point Sign Injection Negated") instruction.
;; The output of this instruction is `vs2` with the negated sign bit from `vs1`
(decl rv_vfsgnjn_vv (VReg VReg VecOpMasking VState) VReg)
Expand All @@ -847,6 +873,18 @@
(decl rv_vfneg_v (VReg VecOpMasking VState) VReg)
(rule (rv_vfneg_v vs mask vstate) (rv_vfsgnjn_vv vs vs mask vstate))

;; Helper for emitting the `vfsgnjx.vv` ("Floating Point Sign Injection Exclusive") instruction.
;; The output of this instruction is `vs2` with the XOR of the sign bits from `vs2` and `vs1`.
;; When `vs2 == vs1` this implements `fabs`
(decl rv_vfsgnjx_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vfsgnjx_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VfsgnjxVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vfabs.v` instruction.
;; This instruction is a mnemonic for `vfsgnjx.vv vd, vs, vs`
(decl rv_vfabs_v (VReg VecOpMasking VState) VReg)
(rule (rv_vfabs_v vs mask vstate) (rv_vfsgnjx_vv vs vs mask vstate))

;; Helper for emitting the `vfsqrt.v` instruction.
;; This instruction splats the F regsiter into all elements of the destination vector.
(decl rv_vfsqrt_v (VReg VecOpMasking VState) VReg)
Expand Down
69 changes: 57 additions & 12 deletions cranelift/codegen/src/isa/riscv64/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -981,9 +981,12 @@


;;;; Rules for `fabs` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type ty (fabs x)))
(rule 0 (lower (has_type (ty_scalar_float ty) (fabs x)))
(rv_fabs ty x))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (fabs x)))
(rv_vfabs_v x (unmasked) ty))

;;;; Rules for `fneg` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule 0 (lower (has_type (ty_scalar_float ty) (fneg x)))
(rv_fneg ty x))
Expand All @@ -992,9 +995,15 @@
(rv_vfneg_v x (unmasked) ty))

;;;; Rules for `fcopysign` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type ty (fcopysign x y)))
(rule 0 (lower (has_type (ty_scalar_float ty) (fcopysign x y)))
(rv_fsgnj ty x y))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (fcopysign x y)))
(rv_vfsgnj_vv x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (fcopysign x (splat y))))
(rv_vfsgnj_vf x y (unmasked) ty))

;;;; Rules for `fma` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type ty (fma x y z)))
(rv_fmadd ty x y z))
Expand Down Expand Up @@ -1169,24 +1178,60 @@
(rule 3 (lower (has_type (ty_vec_fits_in_register ty) (fdiv (splat x) y)))
(rv_vfrdiv_vf y x (unmasked) ty))

;;;; Rules for `fmin/fmax` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;; Rules for `fmin` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule
(lower (has_type ty (fmin x y)))
(rule 0 (lower (has_type (ty_scalar_float ty) (fmin x y)))
(gen_float_select (FloatSelectOP.Min) x y ty))

(rule
(lower (has_type ty (fmin_pseudo x y)))
;; vfmin does almost the right thing, but it does not handle NaN's correctly.
;; We should return a NaN if any of the inputs is a NaN, but vfmin returns the
;; number input instead.
;;
;; TODO: We can improve this by using a masked `fmin` instruction that modifies
;; the canonical nan register. That way we could avoid the `vmerge.vv` instruction.
(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (fmin x y)))
(let ((is_not_nan VReg (gen_fcmp_mask ty (FloatCC.Ordered) x y))
(nan XReg (imm $I64 (canonical_nan_u64 (lane_type ty))))
(vec_nan VReg (rv_vmv_vx nan ty))
(min VReg (rv_vfmin_vv x y (unmasked) ty)))
(rv_vmerge_vvm vec_nan min is_not_nan ty)))

;;;; Rules for `fmax` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule 0 (lower (has_type (ty_scalar_float ty) (fmax x y)))
(gen_float_select (FloatSelectOP.Max) x y ty))

;; vfmax does almost the right thing, but it does not handle NaN's correctly.
;; We should return a NaN if any of the inputs is a NaN, but vfmax returns the
;; number input instead.
;;
;; TODO: We can improve this by using a masked `fmax` instruction that modifies
;; the canonical nan register. That way we could avoid the `vmerge.vv` instruction.
(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (fmax x y)))
(let ((is_not_nan VReg (gen_fcmp_mask ty (FloatCC.Ordered) x y))
(nan XReg (imm $I64 (canonical_nan_u64 (lane_type ty))))
(vec_nan VReg (rv_vmv_vx nan ty))
(max VReg (rv_vfmax_vv x y (unmasked) ty)))
(rv_vmerge_vvm vec_nan max is_not_nan ty)))

;;;; Rules for `fmin_pseudo` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule 0 (lower (has_type (ty_scalar_float ty) (fmin_pseudo x y)))
(gen_float_select_pseudo (FloatSelectOP.Min) x y ty))

(rule
(lower (has_type ty (fmax x y)))
(gen_float_select (FloatSelectOP.Max) x y ty))
(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (fmin_pseudo x y)))
(let ((mask VReg (gen_fcmp_mask ty (FloatCC.LessThan) y x)))
(rv_vmerge_vvm x y mask ty)))

(rule
(lower (has_type ty (fmax_pseudo x y)))
;;;; Rules for `fmax_pseudo` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule 0 (lower (has_type (ty_scalar_float ty) (fmax_pseudo x y)))
(gen_float_select_pseudo (FloatSelectOP.Max) x y ty))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (fmax_pseudo x y)))
(let ((mask VReg (gen_fcmp_mask ty (FloatCC.LessThan) x y)))
(rv_vmerge_vvm x y mask ty)))

;;;;; Rules for `stack_addr`;;;;;;;;;
(rule
(lower (stack_addr ss offset))
Expand Down
83 changes: 83 additions & 0 deletions cranelift/filetests/filetests/isa/riscv64/simd-fabs.clif
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
test compile precise-output
set unwind_info=false
target riscv64 has_v


function %fabs_f32x4(f32x4) -> f32x4 {
block0(v0: f32x4):
v1 = fabs v0
return v1
}

; VCode:
; add sp,-16
; sd ra,8(sp)
; sd fp,0(sp)
; mv fp,sp
; block0:
; vle8.v v1,16(fp) #avl=16, #vtype=(e8, m1, ta, ma)
; vfabs.v v4,v1 #avl=4, #vtype=(e32, m1, ta, ma)
; vse8.v v4,0(a0) #avl=16, #vtype=(e8, m1, ta, ma)
; ld ra,8(sp)
; ld fp,0(sp)
; add sp,+16
; ret
;
; Disassembled:
; block0: ; offset 0x0
; addi sp, sp, -0x10
; sd ra, 8(sp)
; sd s0, 0(sp)
; ori s0, sp, 0
; block1: ; offset 0x10
; .byte 0x57, 0x70, 0x08, 0xcc
; addi t6, s0, 0x10
; .byte 0x87, 0x80, 0x0f, 0x02
; .byte 0x57, 0x70, 0x02, 0xcd
; .byte 0x57, 0x92, 0x10, 0x2a
; .byte 0x57, 0x70, 0x08, 0xcc
; .byte 0x27, 0x02, 0x05, 0x02
; ld ra, 8(sp)
; ld s0, 0(sp)
; addi sp, sp, 0x10
; ret

function %fabs_f64x2(f64x2) -> f64x2 {
block0(v0: f64x2):
v1 = fabs v0
return v1
}

; VCode:
; add sp,-16
; sd ra,8(sp)
; sd fp,0(sp)
; mv fp,sp
; block0:
; vle8.v v1,16(fp) #avl=16, #vtype=(e8, m1, ta, ma)
; vfabs.v v4,v1 #avl=2, #vtype=(e64, m1, ta, ma)
; vse8.v v4,0(a0) #avl=16, #vtype=(e8, m1, ta, ma)
; ld ra,8(sp)
; ld fp,0(sp)
; add sp,+16
; ret
;
; Disassembled:
; block0: ; offset 0x0
; addi sp, sp, -0x10
; sd ra, 8(sp)
; sd s0, 0(sp)
; ori s0, sp, 0
; block1: ; offset 0x10
; .byte 0x57, 0x70, 0x08, 0xcc
; addi t6, s0, 0x10
; .byte 0x87, 0x80, 0x0f, 0x02
; .byte 0x57, 0x70, 0x81, 0xcd
; .byte 0x57, 0x92, 0x10, 0x2a
; .byte 0x57, 0x70, 0x08, 0xcc
; .byte 0x27, 0x02, 0x05, 0x02
; ld ra, 8(sp)
; ld s0, 0(sp)
; addi sp, sp, 0x10
; ret

Loading