diff --git a/cranelift/codegen/src/isa/x64/inst/args.rs b/cranelift/codegen/src/isa/x64/inst/args.rs
index b54f1b6126fe..6e0d507ab05b 100644
--- a/cranelift/codegen/src/isa/x64/inst/args.rs
+++ b/cranelift/codegen/src/isa/x64/inst/args.rs
@@ -462,6 +462,7 @@ pub(crate) enum InstructionSet {
     BMI2,
     AVX512F,
     AVX512VL,
+    AVX512DQ,
 }

 /// Some SSE operations requiring 2 operands r/m and r.
@@ -994,6 +995,7 @@ impl fmt::Display for SseOpcode {
 #[derive(Clone)]
 pub enum Avx512Opcode {
     Vpabsq,
+    Vpmullq,
 }

 impl Avx512Opcode {
@@ -1001,6 +1003,7 @@ impl Avx512Opcode {
     pub(crate) fn available_from(&self) -> SmallVec<[InstructionSet; 2]> {
         match self {
             Avx512Opcode::Vpabsq => smallvec![InstructionSet::AVX512F, InstructionSet::AVX512VL],
+            Avx512Opcode::Vpmullq => smallvec![InstructionSet::AVX512VL, InstructionSet::AVX512DQ],
         }
     }
 }
@@ -1009,6 +1012,7 @@ impl fmt::Debug for Avx512Opcode {
     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
         let name = match self {
             Avx512Opcode::Vpabsq => "vpabsq",
+            Avx512Opcode::Vpmullq => "vpmullq",
         };
         write!(fmt, "{}", name)
     }
diff --git a/cranelift/codegen/src/isa/x64/inst/emit.rs b/cranelift/codegen/src/isa/x64/inst/emit.rs
index 0bd74ecd8ba4..134d6eafa197 100644
--- a/cranelift/codegen/src/isa/x64/inst/emit.rs
+++ b/cranelift/codegen/src/isa/x64/inst/emit.rs
@@ -128,6 +128,7 @@ pub(crate) fn emit(
             InstructionSet::BMI2 => info.isa_flags.has_bmi2(),
             InstructionSet::AVX512F => info.isa_flags.has_avx512f(),
             InstructionSet::AVX512VL => info.isa_flags.has_avx512vl(),
+            InstructionSet::AVX512DQ => info.isa_flags.has_avx512dq(),
         }
     };

@@ -1409,6 +1410,7 @@ pub(crate) fn emit(
         Inst::XmmUnaryRmREvex { op, src, dst } => {
             let opcode = match op {
                 Avx512Opcode::Vpabsq => 0x1f,
+                _ => unimplemented!("Opcode {:?} not implemented", op),
             };
             match src {
                 RegMem::Reg { reg: src } => EvexInstruction::new()
@@ -1545,6 +1547,31 @@ pub(crate) fn emit(
             }
         }

+        Inst::XmmRmREvex {
+            op,
+            src1,
+            src2,
+            dst,
+        } => {
+            let opcode = match op {
+                Avx512Opcode::Vpmullq => 0x40,
+                _ => unimplemented!("Opcode {:?} not implemented", op),
+            };
+            match src1 {
+                RegMem::Reg { reg: src } => EvexInstruction::new()
+                    .length(EvexVectorLength::V128)
+                    .prefix(LegacyPrefixes::_66)
+                    .map(OpcodeMap::_0F38)
+                    .w(true)
+                    .opcode(opcode)
+                    .reg(dst.to_reg().get_hw_encoding())
+                    .rm(src.get_hw_encoding())
+                    .vvvvv(src2.get_hw_encoding())
+                    .encode(sink),
+                _ => todo!(),
+            };
+        }
+
         Inst::XmmMinMaxSeq {
             size,
             is_min,
diff --git a/cranelift/codegen/src/isa/x64/inst/emit_tests.rs b/cranelift/codegen/src/isa/x64/inst/emit_tests.rs
index f03762b97bab..1d0dd4aba5df 100644
--- a/cranelift/codegen/src/isa/x64/inst/emit_tests.rs
+++ b/cranelift/codegen/src/isa/x64/inst/emit_tests.rs
@@ -3555,6 +3555,12 @@ fn test_x64_emit() {
         "pmullw %xmm14, %xmm1",
     ));

+    insns.push((
+        Inst::xmm_rm_r_evex(Avx512Opcode::Vpmullq, RegMem::reg(xmm14), xmm10, w_xmm1),
+        "62D2AD0840CE",
+        "vpmullq %xmm14, %xmm10, %xmm1",
+    ));
+
     insns.push((
         Inst::xmm_rm_r(SseOpcode::Pmuludq, RegMem::reg(xmm8), w_xmm9),
         "66450FF4C8",
@@ -4283,6 +4289,7 @@ fn test_x64_emit() {
     isa_flag_builder.enable("has_ssse3").unwrap();
     isa_flag_builder.enable("has_sse41").unwrap();
     isa_flag_builder.enable("has_avx512f").unwrap();
+    isa_flag_builder.enable("has_avx512dq").unwrap();
     let isa_flags = x64::settings::Flags::new(&flags, isa_flag_builder);

     let rru = regs::create_reg_universe_systemv(&flags);
diff --git a/cranelift/codegen/src/isa/x64/inst/mod.rs b/cranelift/codegen/src/isa/x64/inst/mod.rs
index fe89ac4c9009..547d8413cbfe 100644
--- a/cranelift/codegen/src/isa/x64/inst/mod.rs
+++ b/cranelift/codegen/src/isa/x64/inst/mod.rs
@@ -212,6 +212,13 @@ pub enum Inst {
         dst: Writable<Reg>,
     },

+    XmmRmREvex {
+        op: Avx512Opcode,
+        src1: RegMem,
+        src2: Reg,
+        dst: Writable<Reg>,
+    },
+
     /// XMM (scalar or vector) unary op: mov between XMM registers (32 64) (reg addr) reg, sqrt,
     /// etc.
     ///
@@ -577,7 +584,7 @@ impl Inst {
             | Inst::XmmToGpr { op, .. }
             | Inst::XmmUnaryRmR { op, .. } => smallvec![op.available_from()],

-            Inst::XmmUnaryRmREvex { op, .. } => op.available_from(),
+            Inst::XmmUnaryRmREvex { op, .. } | Inst::XmmRmREvex { op, .. } => op.available_from(),
         }
     }
 }
@@ -724,6 +731,23 @@ impl Inst {
         Inst::XmmRmR { op, src, dst }
     }

+    pub(crate) fn xmm_rm_r_evex(
+        op: Avx512Opcode,
+        src1: RegMem,
+        src2: Reg,
+        dst: Writable<Reg>,
+    ) -> Self {
+        src1.assert_regclass_is(RegClass::V128);
+        debug_assert!(src2.get_class() == RegClass::V128);
+        debug_assert!(dst.to_reg().get_class() == RegClass::V128);
+        Inst::XmmRmREvex {
+            op,
+            src1,
+            src2,
+            dst,
+        }
+    }
+
     pub(crate) fn xmm_uninit_value(dst: Writable<Reg>) -> Self {
         debug_assert!(dst.to_reg().get_class() == RegClass::V128);
         Inst::XmmUninitializedValue { dst }
@@ -1425,6 +1449,20 @@ impl PrettyPrint for Inst {
                 show_ireg_sized(dst.to_reg(), mb_rru, 8),
             ),

+            Inst::XmmRmREvex {
+                op,
+                src1,
+                src2,
+                dst,
+                ..
+            } => format!(
+                "{} {}, {}, {}",
+                ljustify(op.to_string()),
+                src1.show_rru_sized(mb_rru, 8),
+                show_ireg_sized(*src2, mb_rru, 8),
+                show_ireg_sized(dst.to_reg(), mb_rru, 8),
+            ),
+
             Inst::XmmMinMaxSeq {
                 lhs,
                 rhs_dst,
@@ -1898,6 +1936,13 @@ fn x64_get_regs(inst: &Inst, collector: &mut RegUsageCollector) {
                 collector.add_mod(*dst);
             }
         }
+        Inst::XmmRmREvex {
+            src1, src2, dst, ..
+        } => {
+            src1.get_regs_as_uses(collector);
+            collector.add_use(*src2);
+            collector.add_def(*dst);
+        }
         Inst::XmmRmRImm { op, src, dst, .. } => {
             if inst.produces_const() {
                 // No need to account for src, since src == dst.
@@ -2283,6 +2328,16 @@ fn x64_map_regs<RUM: RegUsageMapper>(inst: &mut Inst, mapper: &RUM) {
                 map_mod(mapper, dst);
             }
         }
+        Inst::XmmRmREvex {
+            ref mut src1,
+            ref mut src2,
+            ref mut dst,
+            ..
+        } => {
+            src1.map_uses(mapper);
+            map_use(mapper, src2);
+            map_def(mapper, dst);
+        }
         Inst::XmmRmiReg {
             ref mut src,
             ref mut dst,
diff --git a/cranelift/codegen/src/isa/x64/lower.rs b/cranelift/codegen/src/isa/x64/lower.rs
index 7a24b7327960..ffbe574f5ccf 100644
--- a/cranelift/codegen/src/isa/x64/lower.rs
+++ b/cranelift/codegen/src/isa/x64/lower.rs
@@ -1663,105 +1663,116 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
         Opcode::Imul => {
             let ty = ty.unwrap();
             if ty == types::I64X2 {
-                // For I64X2 multiplication we describe a lane A as being
-                // composed of a 32-bit upper half "Ah" and a 32-bit lower half
-                // "Al". The 32-bit long hand multiplication can then be written
-                // as:
-                //    Ah Al
-                // *  Bh Bl
-                //    -----
-                //    Al * Bl
-                //  + (Ah * Bl) << 32
-                //  + (Al * Bh) << 32
-                //
-                // So for each lane we will compute:
-                //   A * B = (Al * Bl) + ((Ah * Bl) + (Al * Bh)) << 32
-                //
-                // Note, the algorithm will use pmuldq which operates directly
-                // on the lower 32-bit (Al or Bl) of a lane and writes the
-                // result to the full 64-bits of the lane of the destination.
-                // For this reason we don't need shifts to isolate the lower
-                // 32-bits, however, we will need to use shifts to isolate the
-                // high 32-bits when doing calculations, i.e., Ah == A >> 32.
-                //
-                // The full sequence then is as follows:
-                //   A' = A
-                //   A' = A' >> 32
-                //   A' = Ah' * Bl
-                //   B' = B
-                //   B' = B' >> 32
-                //   B' = Bh' * Al
-                //   B' = B' + A'
-                //   B' = B' << 32
-                //   A' = A
-                //   A' = Al' * Bl
-                //   A' = A' + B'
-                //   dst = A'
-
-                // Get inputs rhs=A and lhs=B and the dst register
+                // Eventually one of these should be `input_to_reg_mem` (TODO).
                 let lhs = put_input_in_reg(ctx, inputs[0]);
                 let rhs = put_input_in_reg(ctx, inputs[1]);
                 let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();

-                // A' = A
-                let rhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap();
-                ctx.emit(Inst::gen_move(rhs_1, rhs, ty));
-
-                // A' = A' >> 32
-                // A' = Ah' * Bl
-                ctx.emit(Inst::xmm_rmi_reg(
-                    SseOpcode::Psrlq,
-                    RegMemImm::imm(32),
-                    rhs_1,
-                ));
-                ctx.emit(Inst::xmm_rm_r(
-                    SseOpcode::Pmuludq,
-                    RegMem::reg(lhs.clone()),
-                    rhs_1,
-                ));
+                if isa_flags.use_avx512vl_simd() && isa_flags.has_avx512dq() {
+                    // With the right AVX512 features (VL, DQ) this operation
+                    // can lower to a single operation.
+                    ctx.emit(Inst::xmm_rm_r_evex(
+                        Avx512Opcode::Vpmullq,
+                        RegMem::reg(rhs),
+                        lhs,
+                        dst,
+                    ));
+                } else {
+                    // Otherwise, for I64X2 multiplication we describe a lane A as being
+                    // composed of a 32-bit upper half "Ah" and a 32-bit lower half
+                    // "Al". The 32-bit long hand multiplication can then be written
+                    // as:
+                    //    Ah Al
+                    // *  Bh Bl
+                    //    -----
+                    //    Al * Bl
+                    //  + (Ah * Bl) << 32
+                    //  + (Al * Bh) << 32
+                    //
+                    // So for each lane we will compute:
+                    //   A * B = (Al * Bl) + ((Ah * Bl) + (Al * Bh)) << 32
+                    //
+                    // Note, the algorithm will use pmuludq which operates directly
+                    // on the lower 32-bit (Al or Bl) of a lane and writes the
+                    // result to the full 64-bits of the lane of the destination.
+                    // For this reason we don't need shifts to isolate the lower
+                    // 32-bits, however, we will need to use shifts to isolate the
+                    // high 32-bits when doing calculations, i.e., Ah == A >> 32.
+                    //
+                    // The full sequence then is as follows:
+                    //   A' = A
+                    //   A' = A' >> 32
+                    //   A' = Ah' * Bl
+                    //   B' = B
+                    //   B' = B' >> 32
+                    //   B' = Bh' * Al
+                    //   B' = B' + A'
+                    //   B' = B' << 32
+                    //   A' = A
+                    //   A' = Al' * Bl
+                    //   A' = A' + B'
+                    //   dst = A'
+
+                    // A' = A
+                    let rhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap();
+                    ctx.emit(Inst::gen_move(rhs_1, rhs, ty));
+
+                    // A' = A' >> 32
+                    // A' = Ah' * Bl
+                    ctx.emit(Inst::xmm_rmi_reg(
+                        SseOpcode::Psrlq,
+                        RegMemImm::imm(32),
+                        rhs_1,
+                    ));
+                    ctx.emit(Inst::xmm_rm_r(
+                        SseOpcode::Pmuludq,
+                        RegMem::reg(lhs.clone()),
+                        rhs_1,
+                    ));

-                // B' = B
-                let lhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap();
-                ctx.emit(Inst::gen_move(lhs_1, lhs, ty));
+                    // B' = B
+                    let lhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap();
+                    ctx.emit(Inst::gen_move(lhs_1, lhs, ty));

-                // B' = B' >> 32
-                // B' = Bh' * Al
-                ctx.emit(Inst::xmm_rmi_reg(
-                    SseOpcode::Psrlq,
-                    RegMemImm::imm(32),
-                    lhs_1,
-                ));
-                ctx.emit(Inst::xmm_rm_r(SseOpcode::Pmuludq, RegMem::reg(rhs), lhs_1));
+                    // B' = B' >> 32
+                    // B' = Bh' * Al
+                    ctx.emit(Inst::xmm_rmi_reg(
+                        SseOpcode::Psrlq,
+                        RegMemImm::imm(32),
+                        lhs_1,
+                    ));
+                    ctx.emit(Inst::xmm_rm_r(SseOpcode::Pmuludq, RegMem::reg(rhs), lhs_1));

-                // B' = B' + A'
-                // B' = B' << 32
-                ctx.emit(Inst::xmm_rm_r(
-                    SseOpcode::Paddq,
-                    RegMem::reg(rhs_1.to_reg()),
-                    lhs_1,
-                ));
-                ctx.emit(Inst::xmm_rmi_reg(
-                    SseOpcode::Psllq,
-                    RegMemImm::imm(32),
-                    lhs_1,
-                ));
+                    // B' = B' + A'
+                    // B' = B' << 32
+                    ctx.emit(Inst::xmm_rm_r(
+                        SseOpcode::Paddq,
+                        RegMem::reg(rhs_1.to_reg()),
+                        lhs_1,
+                    ));
+                    ctx.emit(Inst::xmm_rmi_reg(
+                        SseOpcode::Psllq,
+                        RegMemImm::imm(32),
+                        lhs_1,
+                    ));

-                // A' = A
-                // A' = Al' * Bl
-                // A' = A' + B'
-                // dst = A'
-                ctx.emit(Inst::gen_move(rhs_1, rhs, ty));
-                ctx.emit(Inst::xmm_rm_r(
-                    SseOpcode::Pmuludq,
-                    RegMem::reg(lhs.clone()),
-                    rhs_1,
-                ));
-                ctx.emit(Inst::xmm_rm_r(
-                    SseOpcode::Paddq,
-                    RegMem::reg(lhs_1.to_reg()),
-                    rhs_1,
-                ));
-                ctx.emit(Inst::gen_move(dst, rhs_1.to_reg(), ty));
+                    // A' = A
+                    // A' = Al' * Bl
+                    // A' = A' + B'
+                    // dst = A'
+                    ctx.emit(Inst::gen_move(rhs_1, rhs, ty));
+                    ctx.emit(Inst::xmm_rm_r(
+                        SseOpcode::Pmuludq,
+                        RegMem::reg(lhs.clone()),
+                        rhs_1,
+                    ));
+                    ctx.emit(Inst::xmm_rm_r(
+                        SseOpcode::Paddq,
+                        RegMem::reg(lhs_1.to_reg()),
+                        rhs_1,
+                    ));
+                    ctx.emit(Inst::gen_move(dst, rhs_1.to_reg(), ty));
+                }
             } else if ty.lane_count() > 1 {
                 // Emit single instruction lowerings for the remaining vector
                 // multiplications.