Skip to content

Commit

Permalink
riscv64: Implement various SIMD float ops (#6657)
Browse files Browse the repository at this point in the history
* riscv64: Implement SIMD `fabs`

* riscv64: Implement SIMD `fcopysign`

* riscv64: Implement SIMD `f{min,max}_pseudo`

* riscv64: Implement SIMD `f{min,max}`
  • Loading branch information
afonso360 authored Jun 28, 2023
1 parent e04f766 commit 6755f35
Show file tree
Hide file tree
Showing 20 changed files with 845 additions and 20 deletions.
4 changes: 0 additions & 4 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,7 @@ fn ignore(testsuite: &str, testname: &str, strategy: &str) -> bool {
"cvt_from_uint",
"issue_3327_bnot_lowering",
"simd_conversions",
"simd_f32x4",
"simd_f32x4_pmin_pmax",
"simd_f32x4_rounding",
"simd_f64x2",
"simd_f64x2_pmin_pmax",
"simd_f64x2_rounding",
"simd_i32x4_trunc_sat_f32x4",
"simd_i32x4_trunc_sat_f64x2",
Expand Down
5 changes: 5 additions & 0 deletions cranelift/codegen/src/isa/riscv64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -1610,6 +1610,11 @@

;; Float Helpers

;; Returns the bitpattern of the Canonical NaN for the given type.
;;
;; Only `$F32` and `$F64` have rules. The constants are the positive quiet NaN
;; with an otherwise empty payload: `0x7fc00000` for 32-bit floats and
;; `0x7ff8000000000000` for 64-bit floats. The result is always a `u64`, so the
;; 32-bit pattern occupies the low 32 bits.
(decl pure canonical_nan_u64 (Type) u64)
(rule (canonical_nan_u64 $F32) 0x7fc00000)
(rule (canonical_nan_u64 $F64) 0x7ff8000000000000)

;; Takes no arguments and yields an `OptionFloatRoundingMode`.
;; Declared as an extern constructor, so the implementation lives outside of
;; ISLE. NOTE(review): by its name this presumably returns the default `frm`
;; rounding mode -- confirm against the external implementation.
(decl gen_default_frm () OptionFloatRoundingMode)
(extern constructor gen_default_frm gen_default_frm)

Expand Down
3 changes: 3 additions & 0 deletions cranelift/codegen/src/isa/riscv64/inst/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1744,6 +1744,9 @@ impl Inst {
(VecAluOpRRR::VfsgnjnVV, vs2, vs1) if vs2 == vs1 => {
format!("vfneg.v {vd_s},{vs2_s}{mask} {vstate}")
}
(VecAluOpRRR::VfsgnjxVV, vs2, vs1) if vs2 == vs1 => {
format!("vfabs.v {vd_s},{vs2_s}{mask} {vstate}")
}
(VecAluOpRRR::VmnandMM, vs2, vs1) if vs2 == vs1 => {
format!("vmnot.m {vd_s},{vs2_s}{mask} {vstate}")
}
Expand Down
9 changes: 9 additions & 0 deletions cranelift/codegen/src/isa/riscv64/inst/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,13 @@ impl VecAluOpRRR {
| VecAluOpRRR::VsadduVV
| VecAluOpRRR::VsadduVX => 0b100000,
VecAluOpRRR::VfrdivVF | VecAluOpRRR::VsaddVV | VecAluOpRRR::VsaddVX => 0b100001,
VecAluOpRRR::VfminVV => 0b000100,
VecAluOpRRR::VfmaxVV => 0b000110,
VecAluOpRRR::VssubuVV | VecAluOpRRR::VssubuVX => 0b100010,
VecAluOpRRR::VssubVV | VecAluOpRRR::VssubVX => 0b100011,
VecAluOpRRR::VfsgnjVV | VecAluOpRRR::VfsgnjVF => 0b001000,
VecAluOpRRR::VfsgnjnVV => 0b001001,
VecAluOpRRR::VfsgnjxVV => 0b001010,
VecAluOpRRR::VrgatherVV | VecAluOpRRR::VrgatherVX => 0b001100,
VecAluOpRRR::VwadduVV | VecAluOpRRR::VwadduVX => 0b110000,
VecAluOpRRR::VwaddVV | VecAluOpRRR::VwaddVX => 0b110001,
Expand Down Expand Up @@ -473,7 +477,11 @@ impl VecAluOpRRR {
| VecAluOpRRR::VfsubVV
| VecAluOpRRR::VfmulVV
| VecAluOpRRR::VfdivVV
| VecAluOpRRR::VfmaxVV
| VecAluOpRRR::VfminVV
| VecAluOpRRR::VfsgnjVV
| VecAluOpRRR::VfsgnjnVV
| VecAluOpRRR::VfsgnjxVV
| VecAluOpRRR::VmfeqVV
| VecAluOpRRR::VmfneVV
| VecAluOpRRR::VmfltVV
Expand All @@ -485,6 +493,7 @@ impl VecAluOpRRR {
| VecAluOpRRR::VfdivVF
| VecAluOpRRR::VfrdivVF
| VecAluOpRRR::VfmergeVFM
| VecAluOpRRR::VfsgnjVF
| VecAluOpRRR::VmfeqVF
| VecAluOpRRR::VmfneVF
| VecAluOpRRR::VmfltVF
Expand Down
38 changes: 38 additions & 0 deletions cranelift/codegen/src/isa/riscv64/inst_vector.isle
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@
(VfsubVV)
(VfmulVV)
(VfdivVV)
(VfminVV)
(VfmaxVV)
(VfsgnjVV)
(VfsgnjnVV)
(VfsgnjxVV)
(VmergeVVM)
(VredmaxuVS)
(VredminuVS)
Expand Down Expand Up @@ -180,6 +184,7 @@
(VfrsubVF)
(VfmulVF)
(VfdivVF)
(VfsgnjVF)
(VfrdivVF)
(VmergeVXM)
(VfmergeVFM)
Expand Down Expand Up @@ -836,6 +841,27 @@
(rule (rv_vfrdiv_vf vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VfrdivVF) vs2 vs1 mask vstate))

;; Emits a `vfmin.vv` instruction.
;; `src2` and `src1` are forwarded, in that order, as the two vector sources of
;; a `VecAluOpRRR.VfminVV` operation; `masking` and `state` pass through as-is.
(decl rv_vfmin_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vfmin_vv src2 src1 masking state)
      (vec_alu_rrr (VecAluOpRRR.VfminVV) src2 src1 masking state))

;; Emits a `vfmax.vv` instruction.
;; `src2` and `src1` are forwarded, in that order, as the two vector sources of
;; a `VecAluOpRRR.VfmaxVV` operation; `masking` and `state` pass through as-is.
(decl rv_vfmax_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vfmax_vv src2 src1 masking state)
      (vec_alu_rrr (VecAluOpRRR.VfmaxVV) src2 src1 masking state))

;; Emits a `vfsgnj.vv` ("Floating Point Sign Injection") instruction.
;; The result is `src2` with its sign bit replaced by the sign bit of `src1`.
;; (`src2`/`src1` are the instruction's `vs2`/`vs1` operands.)
(decl rv_vfsgnj_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vfsgnj_vv src2 src1 masking state)
      (vec_alu_rrr (VecAluOpRRR.VfsgnjVV) src2 src1 masking state))

;; Emits a `vfsgnj.vf` ("Floating Point Sign Injection") instruction.
;; Vector-scalar form: the first source is a vector register and the second
;; source is a scalar float register (`FReg`).
(decl rv_vfsgnj_vf (VReg FReg VecOpMasking VState) VReg)
(rule (rv_vfsgnj_vf vec_src scalar_src masking state)
      (vec_alu_rrr (VecAluOpRRR.VfsgnjVF) vec_src scalar_src masking state))

;; Helper for emitting the `vfsgnjn.vv` ("Floating Point Sign Injection Negated") instruction.
;; The output of this instruction is `vs2` with the negated sign bit from `vs1`
(decl rv_vfsgnjn_vv (VReg VReg VecOpMasking VState) VReg)
Expand All @@ -847,6 +873,18 @@
(decl rv_vfneg_v (VReg VecOpMasking VState) VReg)
;; Implemented as `vfsgnjn.vv` with the same register used for both sources.
(rule (rv_vfneg_v src masking state)
      (rv_vfsgnjn_vv src src masking state))

;; Emits a `vfsgnjx.vv` ("Floating Point Sign Injection Exclusive") instruction.
;; The result is `src2` whose sign bit is replaced by the XOR of the sign bits
;; of `src2` and `src1`. Passing the same register for both sources therefore
;; clears the sign bit, which is how `fabs` is implemented.
(decl rv_vfsgnjx_vv (VReg VReg VecOpMasking VState) VReg)
(rule (rv_vfsgnjx_vv src2 src1 masking state)
      (vec_alu_rrr (VecAluOpRRR.VfsgnjxVV) src2 src1 masking state))

;; Emits the `vfabs.v` pseudo-instruction.
;; `vfabs.v vd, vs` is an alias for `vfsgnjx.vv vd, vs, vs`, so the single
;; source register is passed as both operands.
(decl rv_vfabs_v (VReg VecOpMasking VState) VReg)
(rule (rv_vfabs_v src masking state)
      (rv_vfsgnjx_vv src src masking state))

;; Helper for emitting the `vfsqrt.v` instruction.
;; This instruction computes the square root of each element of the source vector.
(decl rv_vfsqrt_v (VReg VecOpMasking VState) VReg)
Expand Down
69 changes: 57 additions & 12 deletions cranelift/codegen/src/isa/riscv64/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -981,9 +981,12 @@


;;;; Rules for `fabs` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type ty (fabs x)))
;; Scalar floats go through the scalar `rv_fabs` helper.
(rule 0 (lower (has_type (ty_scalar_float fty) (fabs val)))
      (rv_fabs fty val))

;; Vector types accepted by `ty_vec_fits_in_register` use an unmasked `vfabs.v`.
(rule 1 (lower (has_type (ty_vec_fits_in_register vty) (fabs val)))
      (rv_vfabs_v val (unmasked) vty))

;;;; Rules for `fneg` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Scalar floats go through the scalar `rv_fneg` helper.
(rule 0 (lower (has_type (ty_scalar_float fty) (fneg val)))
      (rv_fneg fty val))
Expand All @@ -992,9 +995,15 @@
(rv_vfneg_v x (unmasked) ty))

;;;; Rules for `fcopysign` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type ty (fcopysign x y)))
;; Scalar floats: sign injection of `sign` into `val` via `rv_fsgnj`.
(rule 0 (lower (has_type (ty_scalar_float fty) (fcopysign val sign)))
      (rv_fsgnj fty val sign))

;; Vectors: unmasked `vfsgnj.vv`, taking the sign bits from `sign`.
(rule 1 (lower (has_type (ty_vec_fits_in_register vty) (fcopysign val sign)))
      (rv_vfsgnj_vv val sign (unmasked) vty))

;; Vectors where the sign operand is a splat: use the vector-scalar form
;; `vfsgnj.vf` directly on the scalar. This rule has a higher priority than
;; the plain vector rule so that it wins when both patterns match.
(rule 2 (lower (has_type (ty_vec_fits_in_register vty) (fcopysign val (splat sign))))
      (rv_vfsgnj_vf val sign (unmasked) vty))

;;;; Rules for `fma` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Lowers `fma` to `rv_fmadd`.
;; NOTE(review): unlike the neighbouring rules, this one is not restricted to
;; `ty_scalar_float`, so it matches any result type -- confirm this is intended.
(rule (lower (has_type ty (fma a b c)))
      (rv_fmadd ty a b c))
Expand Down Expand Up @@ -1169,24 +1178,60 @@
;; The dividend is a splat: use the reversed-operand divide `vfrdiv.vf`.
;; Note that the operands are swapped when calling the helper: the vector
;; divisor is passed first and the scalar dividend second.
(rule 3 (lower (has_type (ty_vec_fits_in_register vty) (fdiv (splat dividend) divisor)))
      (rv_vfrdiv_vf divisor dividend (unmasked) vty))

;;;; Rules for `fmin/fmax` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;; Rules for `fmin` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule
(lower (has_type ty (fmin x y)))
;; Scalar floats: delegate to `gen_float_select` with the `Min` operation.
(rule 0 (lower (has_type (ty_scalar_float fty) (fmin lhs rhs)))
      (gen_float_select (FloatSelectOP.Min) lhs rhs fty))

(rule
(lower (has_type ty (fmin_pseudo x y)))
;; vfmin does almost the right thing, but it does not handle NaNs correctly.
;; We should return a NaN if any of the inputs is a NaN, but vfmin returns the
;; number input instead.
;;
;; To fix this up we build a mask of the lanes where both inputs are ordered
;; (i.e. neither is a NaN), splat the canonical NaN bit pattern for the lane
;; type into a vector, and then select the `vfmin` result for the ordered lanes
;; and the canonical NaN for all other lanes.
;;
;; TODO: We can improve this by using a masked `fmin` instruction that modifies
;; the canonical nan register. That way we could avoid the `vmerge.vvm` instruction.
(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (fmin x y)))
(let ((is_not_nan VReg (gen_fcmp_mask ty (FloatCC.Ordered) x y))
(nan XReg (imm $I64 (canonical_nan_u64 (lane_type ty))))
(vec_nan VReg (rv_vmv_vx nan ty))
(min VReg (rv_vfmin_vv x y (unmasked) ty)))
(rv_vmerge_vvm vec_nan min is_not_nan ty)))

;;;; Rules for `fmax` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; Scalar floats: delegate to `gen_float_select` with the `Max` operation.
(rule 0 (lower (has_type (ty_scalar_float fty) (fmax lhs rhs)))
      (gen_float_select (FloatSelectOP.Max) lhs rhs fty))

;; vfmax does almost the right thing, but it does not handle NaNs correctly.
;; We should return a NaN if any of the inputs is a NaN, but vfmax returns the
;; number input instead.
;;
;; To fix this up we build a mask of the lanes where both inputs are ordered
;; (i.e. neither is a NaN), splat the canonical NaN bit pattern for the lane
;; type into a vector, and then select the `vfmax` result for the ordered lanes
;; and the canonical NaN for all other lanes.
;;
;; TODO: We can improve this by using a masked `fmax` instruction that modifies
;; the canonical nan register. That way we could avoid the `vmerge.vvm` instruction.
(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (fmax x y)))
(let ((is_not_nan VReg (gen_fcmp_mask ty (FloatCC.Ordered) x y))
(nan XReg (imm $I64 (canonical_nan_u64 (lane_type ty))))
(vec_nan VReg (rv_vmv_vx nan ty))
(max VReg (rv_vfmax_vv x y (unmasked) ty)))
(rv_vmerge_vvm vec_nan max is_not_nan ty)))

;;;; Rules for `fmin_pseudo` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; Scalar floats: delegate to `gen_float_select_pseudo` with the `Min` operation.
(rule 0 (lower (has_type (ty_scalar_float fty) (fmin_pseudo lhs rhs)))
      (gen_float_select_pseudo (FloatSelectOP.Min) lhs rhs fty))

(rule
(lower (has_type ty (fmax x y)))
(gen_float_select (FloatSelectOP.Max) x y ty))
;; Vector `fmin_pseudo x y` selects `y` in lanes where `y < x` and `x` otherwise.
;; The mask is set for lanes where `y < x`; the merge then picks `y` for those
;; lanes and `x` for the rest. Comparisons involving a NaN are false, so such
;; lanes yield `x`.
(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (fmin_pseudo x y)))
(let ((mask VReg (gen_fcmp_mask ty (FloatCC.LessThan) y x)))
(rv_vmerge_vvm x y mask ty)))

(rule
(lower (has_type ty (fmax_pseudo x y)))
;;;; Rules for `fmax_pseudo` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; Scalar floats: delegate to `gen_float_select_pseudo` with the `Max` operation.
(rule 0 (lower (has_type (ty_scalar_float fty) (fmax_pseudo lhs rhs)))
      (gen_float_select_pseudo (FloatSelectOP.Max) lhs rhs fty))

;; Vector `fmax_pseudo x y` selects `y` in lanes where `x < y` and `x` otherwise.
;; The mask is set for lanes where `x < y`; the merge then picks `y` for those
;; lanes and `x` for the rest. Comparisons involving a NaN are false, so such
;; lanes yield `x`.
(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (fmax_pseudo x y)))
(let ((mask VReg (gen_fcmp_mask ty (FloatCC.LessThan) x y)))
(rv_vmerge_vvm x y mask ty)))

;;;;; Rules for `stack_addr`;;;;;;;;;
(rule
(lower (stack_addr ss offset))
Expand Down
83 changes: 83 additions & 0 deletions cranelift/filetests/filetests/isa/riscv64/simd-fabs.clif
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
test compile precise-output
set unwind_info=false
target riscv64 has_v


function %fabs_f32x4(f32x4) -> f32x4 {
block0(v0: f32x4):
v1 = fabs v0
return v1
}

; VCode:
; add sp,-16
; sd ra,8(sp)
; sd fp,0(sp)
; mv fp,sp
; block0:
; vle8.v v1,16(fp) #avl=16, #vtype=(e8, m1, ta, ma)
; vfabs.v v4,v1 #avl=4, #vtype=(e32, m1, ta, ma)
; vse8.v v4,0(a0) #avl=16, #vtype=(e8, m1, ta, ma)
; ld ra,8(sp)
; ld fp,0(sp)
; add sp,+16
; ret
;
; Disassembled:
; block0: ; offset 0x0
; addi sp, sp, -0x10
; sd ra, 8(sp)
; sd s0, 0(sp)
; ori s0, sp, 0
; block1: ; offset 0x10
; .byte 0x57, 0x70, 0x08, 0xcc
; addi t6, s0, 0x10
; .byte 0x87, 0x80, 0x0f, 0x02
; .byte 0x57, 0x70, 0x02, 0xcd
; .byte 0x57, 0x92, 0x10, 0x2a
; .byte 0x57, 0x70, 0x08, 0xcc
; .byte 0x27, 0x02, 0x05, 0x02
; ld ra, 8(sp)
; ld s0, 0(sp)
; addi sp, sp, 0x10
; ret

function %fabs_f64x2(f64x2) -> f64x2 {
block0(v0: f64x2):
v1 = fabs v0
return v1
}

; VCode:
; add sp,-16
; sd ra,8(sp)
; sd fp,0(sp)
; mv fp,sp
; block0:
; vle8.v v1,16(fp) #avl=16, #vtype=(e8, m1, ta, ma)
; vfabs.v v4,v1 #avl=2, #vtype=(e64, m1, ta, ma)
; vse8.v v4,0(a0) #avl=16, #vtype=(e8, m1, ta, ma)
; ld ra,8(sp)
; ld fp,0(sp)
; add sp,+16
; ret
;
; Disassembled:
; block0: ; offset 0x0
; addi sp, sp, -0x10
; sd ra, 8(sp)
; sd s0, 0(sp)
; ori s0, sp, 0
; block1: ; offset 0x10
; .byte 0x57, 0x70, 0x08, 0xcc
; addi t6, s0, 0x10
; .byte 0x87, 0x80, 0x0f, 0x02
; .byte 0x57, 0x70, 0x81, 0xcd
; .byte 0x57, 0x92, 0x10, 0x2a
; .byte 0x57, 0x70, 0x08, 0xcc
; .byte 0x27, 0x02, 0x05, 0x02
; ld ra, 8(sp)
; ld s0, 0(sp)
; addi sp, sp, 0x10
; ret

Loading

0 comments on commit 6755f35

Please sign in to comment.