Skip to content

Commit

Permalink
8341137: Optimize long vector multiplication using x86 VPMULDQ instru…
Browse files Browse the repository at this point in the history
…ction
  • Loading branch information
jatin-bhateja committed Sep 28, 2024
1 parent bc7c0dc commit be2da71
Show file tree
Hide file tree
Showing 3 changed files with 290 additions and 5 deletions.
64 changes: 59 additions & 5 deletions src/hotspot/cpu/x86/x86.ad
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,44 @@ static inline bool is_clz_non_subword_predicate_evex(BasicType bt, int vlen_byte
(VM_Version::supports_avx512vl() || vlen_bytes == 64);
}

static inline bool is_muldq_pattern(const Node* n) {
if (n->Opcode() == Op_MulVL) {

auto is_clear_upper_double_word_uright_shift_op = [](const Node* n) {
return n->Opcode() == Op_URShiftVL &&
n->in(2)->Opcode() == Op_RShiftCntV && n->in(2)->in(1)->is_Con() &&
n->in(2)->in(1)->bottom_type()->isa_int() &&
n->in(2)->in(1)->bottom_type()->is_int()->get_con() == 32L;
};

auto is_lower_double_word_and_mask_op = [](const Node* n) {
if (n->Opcode() == Op_AndV) {
Node* replicate_operand = n->in(1)->Opcode() == Op_Replicate ? n->in(1) :
n->in(2)->Opcode() == Op_Replicate ? n->in(2) : nullptr;
if (replicate_operand) {
return replicate_operand->in(1)->is_Con() &&
replicate_operand->in(1)->bottom_type()->isa_long() &&
replicate_operand->in(1)->bottom_type()->is_long()->get_con() == 4294967295L;
} else {
return false; // Replication match failed
}
} else {
return false; // AndV match failed
}
};

return (is_lower_double_word_and_mask_op(n->in(1)) ||
is_lower_double_word_and_mask_op(n->in(1)) ||
is_clear_upper_double_word_uright_shift_op(n->in(1)) ||
is_clear_upper_double_word_uright_shift_op(n->in(1))) &&
(is_clear_upper_double_word_uright_shift_op(n->in(2)) ||
is_clear_upper_double_word_uright_shift_op(n->in(2)) ||
is_lower_double_word_and_mask_op(n->in(2)) ||
is_lower_double_word_and_mask_op(n->in(2)));
}
return false;
}

class Node::PD {
public:
enum NodeFlags {
Expand Down Expand Up @@ -6128,11 +6166,13 @@ instruct vmulI_mem(vec dst, vec src, memory mem) %{
ins_pipe( pipe_slow );
%}


// Longs vector mul
instruct evmulL_reg(vec dst, vec src1, vec src2) %{
predicate((Matcher::vector_length_in_bytes(n) == 64 &&
predicate(!is_muldq_pattern(n) &&
((Matcher::vector_length_in_bytes(n) == 64 &&
VM_Version::supports_avx512dq()) ||
VM_Version::supports_avx512vldq());
VM_Version::supports_avx512vldq()));
match(Set dst (MulVL src1 src2));
format %{ "evpmullq $dst,$src1,$src2\t! mul packedL" %}
ins_encode %{
Expand All @@ -6144,10 +6184,11 @@ instruct evmulL_reg(vec dst, vec src1, vec src2) %{
%}

instruct evmulL_mem(vec dst, vec src, memory mem) %{
predicate((Matcher::vector_length_in_bytes(n) == 64 &&
predicate(!is_muldq_pattern(n) &&
((Matcher::vector_length_in_bytes(n) == 64 &&
VM_Version::supports_avx512dq()) ||
(Matcher::vector_length_in_bytes(n) > 8 &&
VM_Version::supports_avx512vldq()));
VM_Version::supports_avx512vldq())));
match(Set dst (MulVL src (LoadVector mem)));
format %{ "evpmullq $dst,$src,$mem\t! mul packedL" %}
ins_encode %{
Expand All @@ -6159,7 +6200,7 @@ instruct evmulL_mem(vec dst, vec src, memory mem) %{
%}

instruct vmulL(vec dst, vec src1, vec src2, vec xtmp) %{
predicate(UseAVX == 0);
predicate(UseAVX == 0 && !is_muldq_pattern(n));
match(Set dst (MulVL src1 src2));
effect(TEMP dst, TEMP xtmp);
format %{ "mulVL $dst, $src1, $src2\t! using $xtmp as TEMP" %}
Expand All @@ -6181,6 +6222,7 @@ instruct vmulL(vec dst, vec src1, vec src2, vec xtmp) %{

instruct vmulL_reg(vec dst, vec src1, vec src2, vec xtmp1, vec xtmp2) %{
predicate(UseAVX > 0 &&
!is_muldq_pattern(n) &&
((Matcher::vector_length_in_bytes(n) == 64 &&
!VM_Version::supports_avx512dq()) ||
(Matcher::vector_length_in_bytes(n) < 64 &&
Expand All @@ -6203,6 +6245,18 @@ instruct vmulL_reg(vec dst, vec src1, vec src2, vec xtmp1, vec xtmp2) %{
ins_pipe( pipe_slow );
%}

instruct vmuludq_reg(vec dst, vec src1, vec src2) %{
predicate(UseAVX > 0 && is_muldq_pattern(n));
match(Set dst (MulVL src1 src2));
ins_cost(100);
format %{ "vpmuludq $dst,$src1,$src2\t! muldq packedL" %}
ins_encode %{
int vlen_enc = vector_length_encoding(this);
__ vpmuludq($dst$$XMMRegister, $src1$$XMMRegister, $src2$$XMMRegister, vlen_enc);
%}
ins_pipe( pipe_slow );
%}

// Floats vector mul
instruct vmulF(vec dst, vec src) %{
predicate(UseAVX == 0);
Expand Down
146 changes: 146 additions & 0 deletions test/hotspot/jtreg/compiler/vectorapi/VectorMultiplyOpt.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/

package compiler.vectorapi;

import jdk.incubator.vector.*;
import java.util.Random;
import java.util.stream.IntStream;

/**
* @test
* @bug 8341137
* @key randomness
* @requires vm.cpu.features ~= ".*avx.*"
* @summary Optimize long vector multiplication using x86 VPMULDQ instruction.
* @modules jdk.incubator.vector
*
* @run driver compiler.vectorapi.VectorMultiplyOpt
*/

public class VectorMultiplyOpt {

public static long [] src1;
public static long [] src2;
public static long [] res;

public static final int SIZE = 4095;
public static final Random r = new Random(1024);
public static final VectorSpecies<Long> LSP = LongVector.SPECIES_PREFERRED;

public static void pattern1(long [] res, long [] src1, long [] src2) {
int i = 0;
for (; i < LSP.loopBound(res.length); i += LSP.length()) {
LongVector vsrc1 = LongVector.fromArray(LSP, src1, i);
LongVector vsrc2 = LongVector.fromArray(LSP, src2, i);
vsrc1.lanewise(VectorOperators.AND, 0xFFFFFFFFL)
.lanewise(VectorOperators.MUL, vsrc2.lanewise(VectorOperators.AND, 0xFFFFFFFFL))
.intoArray(res, i);
}
for (; i < res.length; i++) {
res[i] = (src1[i] & 0xFFFFFFFFL) * (src2[i] & 0xFFFFFFFFL);
}
}

public static void pattern2(long [] res, long [] src1, long [] src2) {
int i = 0;
for (; i < LSP.loopBound(res.length); i += LSP.length()) {
LongVector vsrc1 = LongVector.fromArray(LSP, src1, i);
LongVector vsrc2 = LongVector.fromArray(LSP, src2, i);
vsrc1.lanewise(VectorOperators.AND, 0xFFFFFFFFL)
.lanewise(VectorOperators.MUL, vsrc2.lanewise(VectorOperators.LSHR, 32))
.intoArray(res, i);
}
for (; i < res.length; i++) {
res[i] = (src1[i] & 0xFFFFFFFFL) * (src2[i] >>> 32);
}
}

public static void pattern3(long [] res, long [] src1, long [] src2) {
int i = 0;
for (; i < LSP.loopBound(res.length); i += LSP.length()) {
LongVector vsrc1 = LongVector.fromArray(LSP, src1, i);
LongVector vsrc2 = LongVector.fromArray(LSP, src2, i);
vsrc1.lanewise(VectorOperators.LSHR, 32)
.lanewise(VectorOperators.MUL, vsrc2.lanewise(VectorOperators.LSHR, 32))
.intoArray(res, i);
}
for (; i < res.length; i++) {
res[i] = (src1[i] >>> 32) * (src2[i] >>> 32);
}
}

public static void pattern4(long [] res, long [] src1, long [] src2) {
int i = 0;
for (; i < LSP.loopBound(res.length); i += LSP.length()) {
LongVector vsrc1 = LongVector.fromArray(LSP, src1, i);
LongVector vsrc2 = LongVector.fromArray(LSP, src2, i);
vsrc1.lanewise(VectorOperators.LSHR, 32)
.lanewise(VectorOperators.MUL, vsrc2.lanewise(VectorOperators.AND, 0xFFFFFFFFL))
.intoArray(res, i);
}
for (; i < res.length; i++) {
res[i] = (src1[i] >>> 32) * (src2[i] & 0xFFFFFFFFL);
}
}

interface Validator {
public long apply(long src1, long src2);
}

public static void validate(String msg, long [] actual, long [] src1, long [] src2, Validator func) {
for (int i = 0; i < actual.length; i++) {
if (actual[i] != func.apply(src1[i], src2[i])) {
throw new AssertionError(msg + "index " + i + ": src1 = " + src1[i] + " src2 = " +
src2[i] + " actual = " + actual[i] + " expected = " +
func.apply(src1[i], src2[i]));
}
}
}

public static void setup() {
src1 = new long[SIZE];
src2 = new long[SIZE];
res = new long[SIZE];
IntStream.range(0, SIZE).forEach(i -> { src1[i] = Long.MAX_VALUE * r.nextLong(); });
IntStream.range(0, SIZE).forEach(i -> { src2[i] = Long.MAX_VALUE * r.nextLong(); });
}

public static void main(String[] args) {
setup();
for (int ic = 0; ic < 1000; ic++) {
pattern1(res, src1, src2);
validate("pattern1 ", res, src1, src2, (src1, src2) -> (src1 & 0xFFFFFFFFL) * (src2 & 0xFFFFFFFFL));

pattern2(res, src1, src2);
validate("pattern2 ", res, src1, src2, (src1, src2) -> (src1 & 0xFFFFFFFFL) * (src2 >>> 32));

pattern3(res, src1, src2);
validate("pattern3 ", res, src1, src2, (src1, src2) -> (src1 >>> 32) * (src2 >>> 32));

pattern4(res, src1, src2);
validate("pattern4 ", res, src1, src2, (src1, src2) -> (src1 >>> 32) * (src2 & 0xFFFFFFFFL));
}
System.out.println("PASSED");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.jdk.incubator.vector;

import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.*;
import jdk.incubator.vector.*;
import java.util.stream.*;

@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 3, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public class VectorXXH3HashingBenchmark {
@Param({"1024", "2048", "4096", "8192"})
private int SIZE;
private long [] accumulators;
private byte [] input;
private byte [] SECRET;

private static final VectorShuffle<Long> LONG_SHUFFLE_PREFERRED = VectorShuffle.fromOp(LongVector.SPECIES_PREFERRED, i -> i ^ 1);

@Setup(Level.Trial)
public void Setup() {
accumulators = new long[SIZE];
input = new byte[SIZE * 8];
SECRET = new byte[SIZE*8];
IntStream.range(0, SIZE*8).forEach(
i -> {
input[i] = (byte)i;
SECRET[i] = (byte)-i;
}
);
}

@Benchmark
public void hashingKernel() {
for (int block = 0; block < input.length / 1024; block++) {
for (int stripe = 0; stripe < 16; stripe++) {
int inputOffset = block * 1024 + stripe * 64;
int secretOffset = stripe * 8;

for (int i = 0; i < 8; i += LongVector.SPECIES_PREFERRED.length()) {
LongVector accumulatorsVector = LongVector.fromArray(LongVector.SPECIES_PREFERRED, accumulators, i);
LongVector inputVector = ByteVector.fromArray(ByteVector.SPECIES_PREFERRED, input, inputOffset + i * 8).reinterpretAsLongs();
LongVector secretVector = ByteVector.fromArray(ByteVector.SPECIES_PREFERRED, SECRET, secretOffset + i * 8).reinterpretAsLongs();

LongVector key = inputVector
.lanewise(VectorOperators.XOR, secretVector)
.reinterpretAsLongs();

LongVector low = key.and(0xFFFF_FFFFL);
LongVector high = key.lanewise(VectorOperators.LSHR, 32);

accumulatorsVector
.add(inputVector.rearrange(LONG_SHUFFLE_PREFERRED))
.add(high.mul(low))
.intoArray(accumulators, i);
}
}
}
}
}

0 comments on commit be2da71

Please sign in to comment.