Skip to content

Commit

Permalink
x64: Lower fcopysign, ceil, floor, nearest, and trunc in ISLE (#4730)
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottt authored Aug 22, 2022
1 parent bb0b6da commit cee4b20
Show file tree
Hide file tree
Showing 14 changed files with 605 additions and 111 deletions.
61 changes: 60 additions & 1 deletion cranelift/codegen/src/isa/x64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,15 @@
(decl encode_fcmp_imm (FcmpImm) u8)
(extern constructor encode_fcmp_imm encode_fcmp_imm)

(type RoundImm extern
(enum RoundNearest
RoundDown
RoundUp
RoundZero))

(decl encode_round_imm (RoundImm) u8)
(extern constructor encode_round_imm encode_round_imm)

;;;; Newtypes for Different Register Classes ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(type Gpr (primitive Gpr))
Expand Down Expand Up @@ -1394,6 +1403,9 @@
(decl use_fma () Type)
(extern extractor use_fma use_fma)

(decl use_sse41 () Type)
(extern extractor use_sse41 use_sse41)

;;;; Helpers for Merging and Sinking Immediates/Loads ;;;;;;;;;;;;;;;;;;;;;;;;;

;; Extract a constant `Imm8Reg.Imm8` from a value operand.
Expand Down Expand Up @@ -2575,6 +2587,42 @@
lane
size))

;; Helper for creating `roundss` instructions.
(decl x64_roundss (Xmm RoundImm) Xmm)
(rule (x64_roundss src1 round)
(xmm_rm_r_imm (SseOpcode.Roundss)
src1
src1
(encode_round_imm round)
(OperandSize.Size32)))

;; Helper for creating `roundsd` instructions.
(decl x64_roundsd (Xmm RoundImm) Xmm)
(rule (x64_roundsd src1 round)
(xmm_rm_r_imm (SseOpcode.Roundsd)
src1
src1
(encode_round_imm round)
(OperandSize.Size32)))

;; Helper for creating `roundps` instructions.
(decl x64_roundps (Xmm RoundImm) Xmm)
(rule (x64_roundps src1 round)
(xmm_rm_r_imm (SseOpcode.Roundps)
src1
src1
(encode_round_imm round)
(OperandSize.Size32)))

;; Helper for creating `roundpd` instructions.
(decl x64_roundpd (Xmm RoundImm) Xmm)
(rule (x64_roundpd src1 round)
(xmm_rm_r_imm (SseOpcode.Roundpd)
src1
src1
(encode_round_imm round)
(OperandSize.Size32)))

;; Helper for creating `pmaddwd` instructions.
(decl x64_pmaddwd (Xmm XmmMem) Xmm)
(rule (x64_pmaddwd src1 src2)
Expand Down Expand Up @@ -3659,7 +3707,18 @@
(type LibCall extern
(enum
FmaF32
FmaF64))
FmaF64
CeilF32
CeilF64
FloorF32
FloorF64
NearestF32
NearestF64
TruncF32
TruncF64))

(decl libcall_1 (LibCall Reg) Reg)
(extern constructor libcall_1 libcall_1)

(decl libcall_3 (LibCall Reg Reg Reg) Reg)
(extern constructor libcall_3 libcall_3)
3 changes: 2 additions & 1 deletion cranelift/codegen/src/isa/x64/inst/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1770,7 +1770,8 @@ impl From<FloatCC> for FcmpImm {
/// However the rounding immediate which this field helps make up, also includes
/// bits 3 and 4 which define the rounding select and precision mask respectively.
/// These two bits are not defined here and are implictly set to zero when encoded.
pub(crate) enum RoundImm {
#[derive(Clone, Copy)]
pub enum RoundImm {
RoundNearest = 0x00,
RoundDown = 0x01,
RoundUp = 0x02,
Expand Down
94 changes: 94 additions & 0 deletions cranelift/codegen/src/isa/x64/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -3332,3 +3332,97 @@

(rule (lower (has_type $F64 (bitcast src @ (value_type $I64))))
(bitcast_gpr_to_xmm $I64 src))

;; Rules for `fcopysign` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule (lower (has_type $F32 (fcopysign a @ (value_type $F32) b)))
(let ((sign_bit Xmm (imm $F32 0x80000000)))
(x64_orps
(x64_andnps sign_bit a)
(x64_andps sign_bit b))))

(rule (lower (has_type $F64 (fcopysign a @ (value_type $F64) b)))
(let ((sign_bit Xmm (imm $F64 0x8000000000000000)))
(x64_orpd
(x64_andnpd sign_bit a)
(x64_andpd sign_bit b))))

;; Rules for `ceil` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule (lower (has_type (use_sse41) (ceil a @ (value_type $F32))))
(x64_roundss a (RoundImm.RoundUp)))

(rule (lower (ceil a @ (value_type $F32)))
(libcall_1 (LibCall.CeilF32) a))

(rule (lower (has_type (use_sse41) (ceil a @ (value_type $F64))))
(x64_roundsd a (RoundImm.RoundUp)))

(rule (lower (ceil a @ (value_type $F64)))
(libcall_1 (LibCall.CeilF64) a))

(rule (lower (has_type (use_sse41) (ceil a @ (value_type $F32X4))))
(x64_roundps a (RoundImm.RoundUp)))

(rule (lower (has_type (use_sse41) (ceil a @ (value_type $F64X2))))
(x64_roundpd a (RoundImm.RoundUp)))

;; Rules for `floor` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule (lower (has_type (use_sse41) (floor a @ (value_type $F32))))
(x64_roundss a (RoundImm.RoundDown)))

(rule (lower (floor a @ (value_type $F32)))
(libcall_1 (LibCall.FloorF32) a))

(rule (lower (has_type (use_sse41) (floor a @ (value_type $F64))))
(x64_roundsd a (RoundImm.RoundDown)))

(rule (lower (floor a @ (value_type $F64)))
(libcall_1 (LibCall.FloorF64) a))

(rule (lower (has_type (use_sse41) (floor a @ (value_type $F32X4))))
(x64_roundps a (RoundImm.RoundDown)))

(rule (lower (has_type (use_sse41) (floor a @ (value_type $F64X2))))
(x64_roundpd a (RoundImm.RoundDown)))

;; Rules for `nearest` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule (lower (has_type (use_sse41) (nearest a @ (value_type $F32))))
(x64_roundss a (RoundImm.RoundNearest)))

(rule (lower (nearest a @ (value_type $F32)))
(libcall_1 (LibCall.NearestF32) a))

(rule (lower (has_type (use_sse41) (nearest a @ (value_type $F64))))
(x64_roundsd a (RoundImm.RoundNearest)))

(rule (lower (nearest a @ (value_type $F64)))
(libcall_1 (LibCall.NearestF64) a))

(rule (lower (has_type (use_sse41) (nearest a @ (value_type $F32X4))))
(x64_roundps a (RoundImm.RoundNearest)))

(rule (lower (has_type (use_sse41) (nearest a @ (value_type $F64X2))))
(x64_roundpd a (RoundImm.RoundNearest)))

;; Rules for `trunc` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule (lower (has_type (use_sse41) (trunc a @ (value_type $F32))))
(x64_roundss a (RoundImm.RoundZero)))

(rule (lower (trunc a @ (value_type $F32)))
(libcall_1 (LibCall.TruncF32) a))

(rule (lower (has_type (use_sse41) (trunc a @ (value_type $F64))))
(x64_roundsd a (RoundImm.RoundZero)))

(rule (lower (trunc a @ (value_type $F64)))
(libcall_1 (LibCall.TruncF64) a))

(rule (lower (has_type (use_sse41) (trunc a @ (value_type $F32X4))))
(x64_roundps a (RoundImm.RoundZero)))

(rule (lower (has_type (use_sse41) (trunc a @ (value_type $F64X2))))
(x64_roundpd a (RoundImm.RoundZero)))
115 changes: 6 additions & 109 deletions cranelift/codegen/src/isa/x64/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -569,118 +569,15 @@ fn lower_insn_to_regs(
| Opcode::Unarrow
| Opcode::Bitcast
| Opcode::Fabs
| Opcode::Fneg => {
| Opcode::Fneg
| Opcode::Fcopysign
| Opcode::Ceil
| Opcode::Floor
| Opcode::Nearest
| Opcode::Trunc => {
implemented_in_isle(ctx);
}

Opcode::Fcopysign => {
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
let lhs = put_input_in_reg(ctx, inputs[0]);
let rhs = put_input_in_reg(ctx, inputs[1]);

let ty = ty.unwrap();

// We're going to generate the following sequence:
//
// movabs $INT_MIN, tmp_gpr1
// mov{d,q} tmp_gpr1, tmp_xmm1
// movap{s,d} tmp_xmm1, dst
// andnp{s,d} src_1, dst
// movap{s,d} src_2, tmp_xmm2
// andp{s,d} tmp_xmm1, tmp_xmm2
// orp{s,d} tmp_xmm2, dst

let tmp_xmm1 = ctx.alloc_tmp(types::F32).only_reg().unwrap();
let tmp_xmm2 = ctx.alloc_tmp(types::F32).only_reg().unwrap();

let (sign_bit_cst, mov_op, and_not_op, and_op, or_op) = match ty {
types::F32 => (
0x8000_0000,
SseOpcode::Movaps,
SseOpcode::Andnps,
SseOpcode::Andps,
SseOpcode::Orps,
),
types::F64 => (
0x8000_0000_0000_0000,
SseOpcode::Movapd,
SseOpcode::Andnpd,
SseOpcode::Andpd,
SseOpcode::Orpd,
),
_ => {
panic!("unexpected type {:?} for copysign", ty);
}
};

for inst in Inst::gen_constant(ValueRegs::one(tmp_xmm1), sign_bit_cst, ty, |ty| {
ctx.alloc_tmp(ty).only_reg().unwrap()
}) {
ctx.emit(inst);
}
ctx.emit(Inst::xmm_mov(mov_op, RegMem::reg(tmp_xmm1.to_reg()), dst));
ctx.emit(Inst::xmm_rm_r(and_not_op, RegMem::reg(lhs), dst));
ctx.emit(Inst::xmm_mov(mov_op, RegMem::reg(rhs), tmp_xmm2));
ctx.emit(Inst::xmm_rm_r(
and_op,
RegMem::reg(tmp_xmm1.to_reg()),
tmp_xmm2,
));
ctx.emit(Inst::xmm_rm_r(or_op, RegMem::reg(tmp_xmm2.to_reg()), dst));
}

Opcode::Ceil | Opcode::Floor | Opcode::Nearest | Opcode::Trunc => {
let ty = ty.unwrap();
if isa_flags.use_sse41() {
let mode = match op {
Opcode::Ceil => RoundImm::RoundUp,
Opcode::Floor => RoundImm::RoundDown,
Opcode::Nearest => RoundImm::RoundNearest,
Opcode::Trunc => RoundImm::RoundZero,
_ => panic!("unexpected opcode {:?} in Ceil/Floor/Nearest/Trunc", op),
};
let op = match ty {
types::F32 => SseOpcode::Roundss,
types::F64 => SseOpcode::Roundsd,
types::F32X4 => SseOpcode::Roundps,
types::F64X2 => SseOpcode::Roundpd,
_ => panic!("unexpected type {:?} in Ceil/Floor/Nearest/Trunc", ty),
};
let src = input_to_reg_mem(ctx, inputs[0]);
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
ctx.emit(Inst::xmm_rm_r_imm(
op,
src,
dst,
mode.encode(),
OperandSize::Size32,
));
} else {
// Lower to VM calls when there's no access to SSE4.1.
// Note, for vector types on platforms that don't support sse41
// the execution will panic here.
let libcall = match (op, ty) {
(Opcode::Ceil, types::F32) => LibCall::CeilF32,
(Opcode::Ceil, types::F64) => LibCall::CeilF64,
(Opcode::Floor, types::F32) => LibCall::FloorF32,
(Opcode::Floor, types::F64) => LibCall::FloorF64,
(Opcode::Nearest, types::F32) => LibCall::NearestF32,
(Opcode::Nearest, types::F64) => LibCall::NearestF64,
(Opcode::Trunc, types::F32) => LibCall::TruncF32,
(Opcode::Trunc, types::F64) => LibCall::TruncF64,
_ => panic!(
"unexpected type/opcode {:?}/{:?} in Ceil/Floor/Nearest/Trunc",
ty, op
),
};

let input = put_input_in_reg(ctx, inputs[0]);
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();

emit_vm_call(ctx, flags, triple, libcall, &[input], &[dst])?;
}
}

Opcode::DynamicStackAddr => unimplemented!("DynamicStackAddr"),

Opcode::StackAddr => {
Expand Down
32 changes: 32 additions & 0 deletions cranelift/codegen/src/isa/x64/lower/isle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ impl Context for IsleContext<'_, '_, MInst, Flags, IsaFlags, 6> {
imm.encode()
}

#[inline]
fn encode_round_imm(&mut self, imm: &RoundImm) -> u8 {
imm.encode()
}

#[inline]
fn avx512vl_enabled(&mut self, _: Type) -> Option<()> {
if self.isa_flags.use_avx512vl_simd() {
Expand Down Expand Up @@ -248,6 +253,15 @@ impl Context for IsleContext<'_, '_, MInst, Flags, IsaFlags, 6> {
}
}

#[inline]
fn use_sse41(&mut self, _: Type) -> Option<()> {
if self.isa_flags.use_sse41() {
Some(())
} else {
None
}
}

#[inline]
fn imm8_from_value(&mut self, val: Value) -> Option<Imm8Reg> {
let inst = self.lower_ctx.dfg().value_def(val).inst()?;
Expand Down Expand Up @@ -715,6 +729,24 @@ impl Context for IsleContext<'_, '_, MInst, Flags, IsaFlags, 6> {
regs::rsp().to_real_reg().unwrap().into()
}

fn libcall_1(&mut self, libcall: &LibCall, a: Reg) -> Reg {
let call_conv = self.lower_ctx.abi().call_conv();
let ret_ty = libcall.signature(call_conv).returns[0].value_type;
let output_reg = self.lower_ctx.alloc_tmp(ret_ty).only_reg().unwrap();

emit_vm_call(
self.lower_ctx,
self.flags,
self.triple,
libcall.clone(),
&[a],
&[output_reg],
)
.expect("Failed to emit LibCall");

output_reg.to_reg()
}

fn libcall_3(&mut self, libcall: &LibCall, a: Reg, b: Reg, c: Reg) -> Reg {
let call_conv = self.lower_ctx.abi().call_conv();
let ret_ty = libcall.signature(call_conv).returns[0].value_type;
Expand Down
Loading

0 comments on commit cee4b20

Please sign in to comment.