Skip to content

Commit 367720c

Browse files
author
Flavio Brasil
committed
optimize hash switch generation
1 parent fb376ef commit 367720c

File tree

7 files changed

+175
-470
lines changed

7 files changed

+175
-470
lines changed

compiler/src/org.graalvm.compiler.core.amd64/src/org/graalvm/compiler/core/amd64/AMD64LIRGenerator.java

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@
107107
import org.graalvm.compiler.lir.amd64.vector.AMD64VectorCompareOp;
108108
import org.graalvm.compiler.lir.gen.LIRGenerationResult;
109109
import org.graalvm.compiler.lir.gen.LIRGenerator;
110-
import org.graalvm.compiler.lir.hashing.Hasher;
110+
import org.graalvm.compiler.lir.hashing.IntHasher;
111111
import org.graalvm.compiler.phases.util.Providers;
112112

113113
import jdk.vm.ci.amd64.AMD64;
@@ -686,16 +686,31 @@ protected void emitTableSwitch(int lowKey, LabelRef defaultTarget, LabelRef[] ta
686686
}
687687

688688
@Override
689-
protected Optional<Hasher> hasherFor(JavaConstant[] keyConstants, double minDensity) {
690-
return Hasher.forKeys(keyConstants, minDensity);
689+
protected Optional<IntHasher> hasherFor(JavaConstant[] keyConstants, double minDensity) {
690+
int[] keys = new int[keyConstants.length];
691+
for (int i = 0; i < keyConstants.length; i++) {
692+
keys[i] = keyConstants[i].asInt();
693+
}
694+
return IntHasher.forKeys(keys);
691695
}
692696

693697
@Override
694-
protected void emitHashTableSwitch(Hasher hasher, JavaConstant[] keys, LabelRef defaultTarget, LabelRef[] targets, Value value) {
695-
Value index = hasher.hash(value, arithmeticLIRGen);
698+
protected void emitHashTableSwitch(IntHasher hasher, JavaConstant[] keys, LabelRef defaultTarget, LabelRef[] targets, Value value) {
699+
Value hash = value;
700+
if (hasher.factor > 1) {
701+
Value factor = emitJavaConstant(JavaConstant.forShort(hasher.factor));
702+
hash = arithmeticLIRGen.emitMul(hash, factor, false);
703+
}
704+
if (hasher.shift > 0) {
705+
Value shift = emitJavaConstant(JavaConstant.forByte(hasher.shift));
706+
hash = arithmeticLIRGen.emitShr(hash, shift);
707+
}
708+
Value cardinalityAnd = emitJavaConstant(JavaConstant.forInt(hasher.cardinality - 1));
709+
hash = arithmeticLIRGen.emitAnd(hash, cardinalityAnd);
710+
696711
Variable scratch = newVariable(LIRKind.value(target().arch.getWordKind()));
697712
Variable entryScratch = newVariable(LIRKind.value(target().arch.getWordKind()));
698-
append(new HashTableSwitchOp(keys, defaultTarget, targets, value, index, scratch, entryScratch));
713+
append(new HashTableSwitchOp(keys, defaultTarget, targets, value, hash, scratch, entryScratch));
699714
}
700715

701716
@Override

compiler/src/org.graalvm.compiler.jtt/src/org/graalvm/compiler/jtt/optimize/SwitchHashTableTest.java

Lines changed: 0 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -24,37 +24,15 @@
2424
*/
2525
package org.graalvm.compiler.jtt.optimize;
2626

27-
import java.util.ArrayList;
28-
import java.util.Arrays;
29-
import java.util.HashSet;
30-
import java.util.List;
31-
import java.util.Random;
32-
import java.util.Set;
33-
import java.util.stream.Collectors;
34-
3527
import org.graalvm.compiler.jtt.JTTTest;
36-
import org.graalvm.compiler.lir.hashing.HashFunction;
37-
import org.graalvm.compiler.lir.hashing.Hasher;
3828
import org.junit.Test;
3929

40-
import jdk.vm.ci.meta.JavaConstant;
41-
4230
/*
4331
* Tests optimization of hash table switches.
4432
* Code generated by `SwitchHashTableTest.TestGenerator.main`
4533
*/
4634
public class SwitchHashTableTest extends JTTTest {
47-
@Test
48-
public void checkHashFunctionInstances() {
49-
List<String> coveredByTestCases = Arrays.asList("val >> min", "val", "val >> (val & min)", "(val >> min) ^ val", "val - min", "rotateRight(val, prime)", "rotateRight(val, prime) ^ val",
50-
"rotateRight(val, prime) + val", "(val >> min) * val", "(val * prime) >> min");
51-
Set<String> functions = HashFunction.instances().stream().map(Object::toString).collect(Collectors.toSet());
52-
functions.removeAll(coveredByTestCases);
53-
assertTrue("The following hash functions are not covered by the `Switch03` test: " + functions +
54-
". Re-run the `Switch03.TestGenerator.main` and update the test class.", functions.isEmpty());
55-
}
5635

57-
// Hasher[function=rotateRight(val, prime), effort=4, cardinality=16]
5836
public static int test1(int arg) {
5937
switch (arg) {
6038
case 3080012:
@@ -103,7 +81,6 @@ public void run1() throws Throwable {
10381
runTest("test1", 3080013); // miss
10482
}
10583

106-
// Hasher[function=rotateRight(val, prime) ^ val, effort=5, cardinality=28]
10784
public static int test2(int arg) {
10885
switch (arg) {
10986
case 718707335:
@@ -152,7 +129,6 @@ public void run2() throws Throwable {
152129
runTest("test2", 718707337); // miss
153130
}
154131

155-
// Hasher[function=(val * prime) >> min, effort=4, cardinality=16]
156132
public static int test3(int arg) {
157133
switch (arg) {
158134
case 880488712:
@@ -201,7 +177,6 @@ public void run3() throws Throwable {
201177
runTest("test3", 880488713); // miss
202178
}
203179

204-
// Hasher[function=rotateRight(val, prime) + val, effort=5, cardinality=28]
205180
public static int test4(int arg) {
206181
switch (arg) {
207182
case 189404658:
@@ -250,7 +225,6 @@ public void run4() throws Throwable {
250225
runTest("test4", 189404659); // miss
251226
}
252227

253-
// Hasher[function=val - min, effort=2, cardinality=24]
254228
public static int test5(int arg) {
255229
switch (arg) {
256230
case 527674226:
@@ -299,7 +273,6 @@ public void run5() throws Throwable {
299273
runTest("test5", 527674227); // miss
300274
}
301275

302-
// Hasher[function=val, effort=1, cardinality=24]
303276
public static int test6(int arg) {
304277
switch (arg) {
305278
case 676979121:
@@ -348,7 +321,6 @@ public void run6() throws Throwable {
348321
runTest("test6", 676979122); // miss
349322
}
350323

351-
// Hasher[function=(val >> min) ^ val, effort=3, cardinality=16]
352324
public static int test7(int arg) {
353325
switch (arg) {
354326
case 634218696:
@@ -397,7 +369,6 @@ public void run7() throws Throwable {
397369
runTest("test7", 634218697); // miss
398370
}
399371

400-
// Hasher[function=val >> min, effort=2, cardinality=16]
401372
public static int test8(int arg) {
402373
switch (arg) {
403374
case 473982403:
@@ -446,7 +417,6 @@ public void run8() throws Throwable {
446417
runTest("test8", 473982404); // miss
447418
}
448419

449-
// Hasher[function=val >> (val & min), effort=3, cardinality=16]
450420
public static int test9(int arg) {
451421
switch (arg) {
452422
case 15745090:
@@ -495,7 +465,6 @@ public void run9() throws Throwable {
495465
runTest("test9", 15745091); // miss
496466
}
497467

498-
// Hasher[function=(val >> min) * val, effort=4, cardinality=28]
499468
public static int test10(int arg) {
500469
switch (arg) {
501470
case 989358996:
@@ -543,93 +512,4 @@ public void run10() throws Throwable {
543512
runTest("test10", 989359109); // above
544513
runTest("test10", 989358997); // miss
545514
}
546-
547-
public static class TestGenerator {
548-
549-
private static int nextId = 0;
550-
private static final int size = 15;
551-
private static double minDensity = 0.5;
552-
553-
// test code generator
554-
public static void main(String[] args) {
555-
556-
Random r = new Random(0);
557-
Set<String> seen = new HashSet<>();
558-
Set<String> all = HashFunction.instances().stream().map(Object::toString).collect(Collectors.toSet());
559-
560-
println("@Test");
561-
println("public void checkHashFunctionInstances() {");
562-
println(" List<String> coveredByTestCases = Arrays.asList(" + String.join(", ", all.stream().map(s -> "\"" + s + "\"").collect(Collectors.toSet())) + ");");
563-
println(" Set<String> functions = HashFunction.instances().stream().map(Object::toString).collect(Collectors.toSet());");
564-
println(" functions.removeAll(coveredByTestCases);");
565-
println(" assertTrue(\"The following hash functions are not covered by the `Switch03` test: \" + functions +");
566-
println(" \". Re-run the `Switch03.TestGenerator.main` and update the test class.\", functions.isEmpty());");
567-
println("}");
568-
569-
while (seen.size() < all.size()) {
570-
int v = r.nextInt(Integer.MAX_VALUE / 2);
571-
List<Integer> keys = new ArrayList<>();
572-
while (keys.size() < 15) {
573-
keys.add(v);
574-
v += r.nextInt(15);
575-
}
576-
keys.sort(Integer::compare);
577-
double density = ((double) keys.size() + 1) / (keys.get(keys.size() - 1) - keys.get(0));
578-
if (density < minDensity) {
579-
Hasher.forKeys(toConstants(keys), minDensity).ifPresent(h -> {
580-
String f = h.function().toString();
581-
if (!seen.contains(f)) {
582-
gen(keys, h);
583-
seen.add(f);
584-
}
585-
});
586-
}
587-
}
588-
}
589-
590-
private static void gen(List<Integer> keys, Hasher hasher) {
591-
int id = ++nextId;
592-
593-
println("// " + hasher + "");
594-
println("public static int test" + id + "(int arg) {");
595-
println(" switch (arg) {");
596-
597-
for (Integer key : keys) {
598-
println(" case " + key + ": return " + key + ";");
599-
}
600-
601-
println(" default: return -1;");
602-
println(" }");
603-
println("}");
604-
605-
int miss = keys.get(0) + 1;
606-
while (keys.contains(miss)) {
607-
miss++;
608-
}
609-
610-
println("@Test");
611-
println("public void run" + id + "() throws Throwable {");
612-
println(" runTest(\"test" + id + "\", 0); // zero ");
613-
println(" runTest(\"test" + id + "\", " + (keys.get(0) - 1) + "); // bellow ");
614-
println(" runTest(\"test" + id + "\", " + keys.get(0) + "); // first ");
615-
println(" runTest(\"test" + id + "\", " + keys.get(size / 2) + "); // middle ");
616-
println(" runTest(\"test" + id + "\", " + keys.get(size - 1) + "); // last ");
617-
println(" runTest(\"test" + id + "\", " + (keys.get(size - 1) + 1) + "); // above ");
618-
println(" runTest(\"test" + id + "\", " + miss + "); // miss ");
619-
println("}");
620-
}
621-
622-
private static void println(String s) {
623-
System.out.println(s);
624-
}
625-
626-
private static JavaConstant[] toConstants(List<Integer> keys) {
627-
JavaConstant[] ckeys = new JavaConstant[keys.size()];
628-
629-
for (int i = 0; i < keys.size(); i++) {
630-
ckeys[i] = JavaConstant.forInt(keys.get(i));
631-
}
632-
return ckeys;
633-
}
634-
}
635515
}

compiler/src/org.graalvm.compiler.lir.amd64/src/org/graalvm/compiler/lir/amd64/AMD64ControlFlow.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -786,11 +786,13 @@ public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
786786
masm.jmp(scratchReg);
787787

788788
// Inserting padding so that jump the table address is aligned
789+
int entrySize;
789790
if (defaultTarget != null) {
790-
masm.align(8);
791+
entrySize = 8;
791792
} else {
792-
masm.align(4);
793+
entrySize = 4;
793794
}
795+
masm.align(entrySize);
794796

795797
// Patch LEA instruction above now that we know the position of the jump table
796798
// this is ugly but there is no better way to do this given the assembler API
@@ -818,7 +820,7 @@ public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
818820
}
819821
}
820822

821-
JumpTable jt = new JumpTable(jumpTablePos, keys[0].asInt(), keys[keys.length - 1].asInt(), 4);
823+
JumpTable jt = new JumpTable(jumpTablePos, keys[0].asInt(), keys[keys.length - 1].asInt(), entrySize);
822824
crb.compilationResult.addAnnotation(jt);
823825
}
824826
}

compiler/src/org.graalvm.compiler.lir/src/org/graalvm/compiler/lir/gen/LIRGenerator.java

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
import org.graalvm.compiler.lir.StandardOp.ZapRegistersOp;
6363
import org.graalvm.compiler.lir.SwitchStrategy;
6464
import org.graalvm.compiler.lir.Variable;
65-
import org.graalvm.compiler.lir.hashing.Hasher;
65+
import org.graalvm.compiler.lir.hashing.IntHasher;
6666
import org.graalvm.compiler.options.Option;
6767
import org.graalvm.compiler.options.OptionKey;
6868
import org.graalvm.compiler.options.OptionType;
@@ -482,8 +482,8 @@ public void emitStrategySwitch(JavaConstant[] keyConstants, double[] keyProbabil
482482

483483
int keyCount = keyConstants.length;
484484
double minDensity = 1 / Math.sqrt(strategy.getAverageEffort());
485-
Optional<Hasher> hasher = hasherFor(keyConstants, minDensity);
486-
double hashTableSwitchDensity = hasher.map(h -> keyCount / (double) h.cardinality()).orElse(0d);
485+
Optional<IntHasher> hasher = hasherFor(keyConstants, minDensity);
486+
double hashTableSwitchDensity = hasher.map(h -> (double) keyCount / h.cardinality).orElse(0d);
487487
// The value range computation below may overflow, so compute it as a long.
488488
long valueRange = (long) keyConstants[keyCount - 1].asInt() - (long) keyConstants[0].asInt() + 1;
489489
double tableSwitchDensity = keyCount / (double) valueRange;
@@ -498,11 +498,10 @@ public void emitStrategySwitch(JavaConstant[] keyConstants, double[] keyProbabil
498498
emitStrategySwitch(strategy, value, keyTargets, defaultTarget);
499499
} else {
500500
if (hashTableSwitchDensity > tableSwitchDensity) {
501-
Hasher h = hasher.get();
502-
int cardinality = h.cardinality();
503-
LabelRef[] targets = new LabelRef[cardinality];
504-
JavaConstant[] keys = new JavaConstant[cardinality];
505-
for (int i = 0; i < cardinality; i++) {
501+
IntHasher h = hasher.get();
502+
LabelRef[] targets = new LabelRef[h.cardinality];
503+
JavaConstant[] keys = new JavaConstant[h.cardinality];
504+
for (int i = 0; i < h.cardinality; i++) {
506505
keys[i] = JavaConstant.INT_0;
507506
targets[i] = defaultTarget;
508507
}
@@ -532,12 +531,12 @@ public void emitStrategySwitch(JavaConstant[] keyConstants, double[] keyProbabil
532531
protected abstract void emitTableSwitch(int lowKey, LabelRef defaultTarget, LabelRef[] targets, Value key);
533532

534533
@SuppressWarnings("unused")
535-
protected Optional<Hasher> hasherFor(JavaConstant[] keyConstants, double minDensity) {
534+
protected Optional<IntHasher> hasherFor(JavaConstant[] keyConstants, double minDensity) {
536535
return Optional.empty();
537536
}
538537

539538
@SuppressWarnings("unused")
540-
protected void emitHashTableSwitch(Hasher hasher, JavaConstant[] keys, LabelRef defaultTarget, LabelRef[] targets, Value value) {
539+
protected void emitHashTableSwitch(IntHasher hasher, JavaConstant[] keys, LabelRef defaultTarget, LabelRef[] targets, Value value) {
541540
throw new UnsupportedOperationException(getClass().getSimpleName() + " doesn't support hash table switches");
542541
}
543542

0 commit comments

Comments
 (0)