Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
import org.graalvm.compiler.lir.amd64.vector.AMD64VectorCompareOp;
import org.graalvm.compiler.lir.gen.LIRGenerationResult;
import org.graalvm.compiler.lir.gen.LIRGenerator;
import org.graalvm.compiler.lir.hashing.Hasher;
import org.graalvm.compiler.lir.hashing.IntHasher;
import org.graalvm.compiler.phases.util.Providers;

import jdk.vm.ci.amd64.AMD64;
Expand Down Expand Up @@ -686,16 +686,31 @@ protected void emitTableSwitch(int lowKey, LabelRef defaultTarget, LabelRef[] ta
}

@Override
protected Optional<Hasher> hasherFor(JavaConstant[] keyConstants, double minDensity) {
return Hasher.forKeys(keyConstants, minDensity);
protected Optional<IntHasher> hasherFor(JavaConstant[] keyConstants, double minDensity) {
int[] keys = new int[keyConstants.length];
for (int i = 0; i < keyConstants.length; i++) {
keys[i] = keyConstants[i].asInt();
}
return IntHasher.forKeys(keys);
}

@Override
protected void emitHashTableSwitch(Hasher hasher, JavaConstant[] keys, LabelRef defaultTarget, LabelRef[] targets, Value value) {
Value index = hasher.hash(value, arithmeticLIRGen);
protected void emitHashTableSwitch(IntHasher hasher, JavaConstant[] keys, LabelRef defaultTarget, LabelRef[] targets, Value value) {
Value hash = value;
if (hasher.factor > 1) {
Value factor = emitJavaConstant(JavaConstant.forShort(hasher.factor));
hash = arithmeticLIRGen.emitMul(hash, factor, false);
}
if (hasher.shift > 0) {
Value shift = emitJavaConstant(JavaConstant.forByte(hasher.shift));
hash = arithmeticLIRGen.emitShr(hash, shift);
}
Value cardinalityAnd = emitJavaConstant(JavaConstant.forInt(hasher.cardinality - 1));
hash = arithmeticLIRGen.emitAnd(hash, cardinalityAnd);

Variable scratch = newVariable(LIRKind.value(target().arch.getWordKind()));
Variable entryScratch = newVariable(LIRKind.value(target().arch.getWordKind()));
append(new HashTableSwitchOp(keys, defaultTarget, targets, value, index, scratch, entryScratch));
append(new HashTableSwitchOp(keys, defaultTarget, targets, value, hash, scratch, entryScratch));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,15 @@
*/
package org.graalvm.compiler.jtt.optimize;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;

import org.graalvm.compiler.jtt.JTTTest;
import org.graalvm.compiler.lir.hashing.HashFunction;
import org.graalvm.compiler.lir.hashing.Hasher;
import org.junit.Test;

import jdk.vm.ci.meta.JavaConstant;

/*
* Tests optimization of hash table switches.
* Code generated by `SwitchHashTableTest.TestGenerator.main`
*/
public class SwitchHashTableTest extends JTTTest {
@Test
public void checkHashFunctionInstances() {
List<String> coveredByTestCases = Arrays.asList("val >> min", "val", "val >> (val & min)", "(val >> min) ^ val", "val - min", "rotateRight(val, prime)", "rotateRight(val, prime) ^ val",
"rotateRight(val, prime) + val", "(val >> min) * val", "(val * prime) >> min");
Set<String> functions = HashFunction.instances().stream().map(Object::toString).collect(Collectors.toSet());
functions.removeAll(coveredByTestCases);
assertTrue("The following hash functions are not covered by the `Switch03` test: " + functions +
". Re-run the `Switch03.TestGenerator.main` and update the test class.", functions.isEmpty());
}

// Hasher[function=rotateRight(val, prime), effort=4, cardinality=16]
public static int test1(int arg) {
switch (arg) {
case 3080012:
Expand Down Expand Up @@ -103,7 +81,6 @@ public void run1() throws Throwable {
runTest("test1", 3080013); // miss
}

// Hasher[function=rotateRight(val, prime) ^ val, effort=5, cardinality=28]
public static int test2(int arg) {
switch (arg) {
case 718707335:
Expand Down Expand Up @@ -152,7 +129,6 @@ public void run2() throws Throwable {
runTest("test2", 718707337); // miss
}

// Hasher[function=(val * prime) >> min, effort=4, cardinality=16]
public static int test3(int arg) {
switch (arg) {
case 880488712:
Expand Down Expand Up @@ -201,7 +177,6 @@ public void run3() throws Throwable {
runTest("test3", 880488713); // miss
}

// Hasher[function=rotateRight(val, prime) + val, effort=5, cardinality=28]
public static int test4(int arg) {
switch (arg) {
case 189404658:
Expand Down Expand Up @@ -250,7 +225,6 @@ public void run4() throws Throwable {
runTest("test4", 189404659); // miss
}

// Hasher[function=val - min, effort=2, cardinality=24]
public static int test5(int arg) {
switch (arg) {
case 527674226:
Expand Down Expand Up @@ -299,7 +273,6 @@ public void run5() throws Throwable {
runTest("test5", 527674227); // miss
}

// Hasher[function=val, effort=1, cardinality=24]
public static int test6(int arg) {
switch (arg) {
case 676979121:
Expand Down Expand Up @@ -348,7 +321,6 @@ public void run6() throws Throwable {
runTest("test6", 676979122); // miss
}

// Hasher[function=(val >> min) ^ val, effort=3, cardinality=16]
public static int test7(int arg) {
switch (arg) {
case 634218696:
Expand Down Expand Up @@ -397,7 +369,6 @@ public void run7() throws Throwable {
runTest("test7", 634218697); // miss
}

// Hasher[function=val >> min, effort=2, cardinality=16]
public static int test8(int arg) {
switch (arg) {
case 473982403:
Expand Down Expand Up @@ -446,7 +417,6 @@ public void run8() throws Throwable {
runTest("test8", 473982404); // miss
}

// Hasher[function=val >> (val & min), effort=3, cardinality=16]
public static int test9(int arg) {
switch (arg) {
case 15745090:
Expand Down Expand Up @@ -495,7 +465,6 @@ public void run9() throws Throwable {
runTest("test9", 15745091); // miss
}

// Hasher[function=(val >> min) * val, effort=4, cardinality=28]
public static int test10(int arg) {
switch (arg) {
case 989358996:
Expand Down Expand Up @@ -543,93 +512,4 @@ public void run10() throws Throwable {
runTest("test10", 989359109); // above
runTest("test10", 989358997); // miss
}

public static class TestGenerator {

private static int nextId = 0;
private static final int size = 15;
private static double minDensity = 0.5;

// test code generator
public static void main(String[] args) {

Random r = new Random(0);
Set<String> seen = new HashSet<>();
Set<String> all = HashFunction.instances().stream().map(Object::toString).collect(Collectors.toSet());

println("@Test");
println("public void checkHashFunctionInstances() {");
println(" List<String> coveredByTestCases = Arrays.asList(" + String.join(", ", all.stream().map(s -> "\"" + s + "\"").collect(Collectors.toSet())) + ");");
println(" Set<String> functions = HashFunction.instances().stream().map(Object::toString).collect(Collectors.toSet());");
println(" functions.removeAll(coveredByTestCases);");
println(" assertTrue(\"The following hash functions are not covered by the `Switch03` test: \" + functions +");
println(" \". Re-run the `Switch03.TestGenerator.main` and update the test class.\", functions.isEmpty());");
println("}");

while (seen.size() < all.size()) {
int v = r.nextInt(Integer.MAX_VALUE / 2);
List<Integer> keys = new ArrayList<>();
while (keys.size() < 15) {
keys.add(v);
v += r.nextInt(15);
}
keys.sort(Integer::compare);
double density = ((double) keys.size() + 1) / (keys.get(keys.size() - 1) - keys.get(0));
if (density < minDensity) {
Hasher.forKeys(toConstants(keys), minDensity).ifPresent(h -> {
String f = h.function().toString();
if (!seen.contains(f)) {
gen(keys, h);
seen.add(f);
}
});
}
}
}

private static void gen(List<Integer> keys, Hasher hasher) {
int id = ++nextId;

println("// " + hasher + "");
println("public static int test" + id + "(int arg) {");
println(" switch (arg) {");

for (Integer key : keys) {
println(" case " + key + ": return " + key + ";");
}

println(" default: return -1;");
println(" }");
println("}");

int miss = keys.get(0) + 1;
while (keys.contains(miss)) {
miss++;
}

println("@Test");
println("public void run" + id + "() throws Throwable {");
println(" runTest(\"test" + id + "\", 0); // zero ");
println(" runTest(\"test" + id + "\", " + (keys.get(0) - 1) + "); // bellow ");
println(" runTest(\"test" + id + "\", " + keys.get(0) + "); // first ");
println(" runTest(\"test" + id + "\", " + keys.get(size / 2) + "); // middle ");
println(" runTest(\"test" + id + "\", " + keys.get(size - 1) + "); // last ");
println(" runTest(\"test" + id + "\", " + (keys.get(size - 1) + 1) + "); // above ");
println(" runTest(\"test" + id + "\", " + miss + "); // miss ");
println("}");
}

private static void println(String s) {
System.out.println(s);
}

private static JavaConstant[] toConstants(List<Integer> keys) {
JavaConstant[] ckeys = new JavaConstant[keys.size()];

for (int i = 0; i < keys.size(); i++) {
ckeys[i] = JavaConstant.forInt(keys.get(i));
}
return ckeys;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -786,11 +786,13 @@ public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
masm.jmp(scratchReg);

// Inserting padding so that jump the table address is aligned
int entrySize;
if (defaultTarget != null) {
masm.align(8);
entrySize = 8;
} else {
masm.align(4);
entrySize = 4;
}
masm.align(entrySize);

// Patch LEA instruction above now that we know the position of the jump table
// this is ugly but there is no better way to do this given the assembler API
Expand Down Expand Up @@ -818,7 +820,7 @@ public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
}
}

JumpTable jt = new JumpTable(jumpTablePos, keys[0].asInt(), keys[keys.length - 1].asInt(), 4);
JumpTable jt = new JumpTable(jumpTablePos, keys[0].asInt(), keys[keys.length - 1].asInt(), entrySize);
crb.compilationResult.addAnnotation(jt);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
import org.graalvm.compiler.lir.StandardOp.ZapRegistersOp;
import org.graalvm.compiler.lir.SwitchStrategy;
import org.graalvm.compiler.lir.Variable;
import org.graalvm.compiler.lir.hashing.Hasher;
import org.graalvm.compiler.lir.hashing.IntHasher;
import org.graalvm.compiler.options.Option;
import org.graalvm.compiler.options.OptionKey;
import org.graalvm.compiler.options.OptionType;
Expand Down Expand Up @@ -482,8 +482,8 @@ public void emitStrategySwitch(JavaConstant[] keyConstants, double[] keyProbabil

int keyCount = keyConstants.length;
double minDensity = 1 / Math.sqrt(strategy.getAverageEffort());
Optional<Hasher> hasher = hasherFor(keyConstants, minDensity);
double hashTableSwitchDensity = hasher.map(h -> keyCount / (double) h.cardinality()).orElse(0d);
Optional<IntHasher> hasher = hasherFor(keyConstants, minDensity);
double hashTableSwitchDensity = hasher.map(h -> (double) keyCount / h.cardinality).orElse(0d);
// The value range computation below may overflow, so compute it as a long.
long valueRange = (long) keyConstants[keyCount - 1].asInt() - (long) keyConstants[0].asInt() + 1;
double tableSwitchDensity = keyCount / (double) valueRange;
Expand All @@ -498,11 +498,10 @@ public void emitStrategySwitch(JavaConstant[] keyConstants, double[] keyProbabil
emitStrategySwitch(strategy, value, keyTargets, defaultTarget);
} else {
if (hashTableSwitchDensity > tableSwitchDensity) {
Hasher h = hasher.get();
int cardinality = h.cardinality();
LabelRef[] targets = new LabelRef[cardinality];
JavaConstant[] keys = new JavaConstant[cardinality];
for (int i = 0; i < cardinality; i++) {
IntHasher h = hasher.get();
LabelRef[] targets = new LabelRef[h.cardinality];
JavaConstant[] keys = new JavaConstant[h.cardinality];
for (int i = 0; i < h.cardinality; i++) {
keys[i] = JavaConstant.INT_0;
targets[i] = defaultTarget;
}
Expand Down Expand Up @@ -532,12 +531,12 @@ public void emitStrategySwitch(JavaConstant[] keyConstants, double[] keyProbabil
protected abstract void emitTableSwitch(int lowKey, LabelRef defaultTarget, LabelRef[] targets, Value key);

@SuppressWarnings("unused")
protected Optional<Hasher> hasherFor(JavaConstant[] keyConstants, double minDensity) {
protected Optional<IntHasher> hasherFor(JavaConstant[] keyConstants, double minDensity) {
return Optional.empty();
}

@SuppressWarnings("unused")
protected void emitHashTableSwitch(Hasher hasher, JavaConstant[] keys, LabelRef defaultTarget, LabelRef[] targets, Value value) {
protected void emitHashTableSwitch(IntHasher hasher, JavaConstant[] keys, LabelRef defaultTarget, LabelRef[] targets, Value value) {
throw new UnsupportedOperationException(getClass().getSimpleName() + " doesn't support hash table switches");
}

Expand Down
Loading