Skip to content

Commit 2cf3e61

Browse files
author
Flavio Brasil
committed
optimize hash switch generation
1 parent d3820cb commit 2cf3e61

File tree

7 files changed

+158
-460
lines changed

7 files changed

+158
-460
lines changed

compiler/src/org.graalvm.compiler.core.amd64/src/org/graalvm/compiler/core/amd64/AMD64LIRGenerator.java

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@
107107
import org.graalvm.compiler.lir.amd64.vector.AMD64VectorCompareOp;
108108
import org.graalvm.compiler.lir.gen.LIRGenerationResult;
109109
import org.graalvm.compiler.lir.gen.LIRGenerator;
110-
import org.graalvm.compiler.lir.hashing.Hasher;
110+
import org.graalvm.compiler.lir.hashing.IntHasher;
111111
import org.graalvm.compiler.phases.util.Providers;
112112

113113
import jdk.vm.ci.amd64.AMD64;
@@ -686,16 +686,31 @@ protected void emitTableSwitch(int lowKey, LabelRef defaultTarget, LabelRef[] ta
686686
}
687687

688688
@Override
689-
protected Optional<Hasher> hasherFor(JavaConstant[] keyConstants, double minDensity) {
690-
return Hasher.forKeys(keyConstants, minDensity);
689+
protected Optional<IntHasher> hasherFor(JavaConstant[] keyConstants, double minDensity) {
690+
int[] keys = new int[keyConstants.length];
691+
for (int i = 0; i < keyConstants.length; i++) {
692+
keys[i] = keyConstants[i].asInt();
693+
}
694+
return IntHasher.forKeys(keys);
691695
}
692696

693697
@Override
694-
protected void emitHashTableSwitch(Hasher hasher, JavaConstant[] keys, LabelRef defaultTarget, LabelRef[] targets, Value value) {
695-
Value index = hasher.hash(value, arithmeticLIRGen);
698+
protected void emitHashTableSwitch(IntHasher hasher, JavaConstant[] keys, LabelRef defaultTarget, LabelRef[] targets, Value value) {
699+
Value hash = value;
700+
if (hasher.factor > 1) {
701+
Value factor = emitJavaConstant(JavaConstant.forShort(hasher.factor));
702+
hash = arithmeticLIRGen.emitMul(hash, factor, false);
703+
}
704+
if (hasher.shift > 0) {
705+
Value shift = emitJavaConstant(JavaConstant.forByte(hasher.shift));
706+
hash = arithmeticLIRGen.emitShr(hash, shift);
707+
}
708+
Value cardinalityAnd = emitJavaConstant(JavaConstant.forInt(hasher.cardinality - 1));
709+
hash = arithmeticLIRGen.emitAnd(hash, cardinalityAnd);
710+
696711
Variable scratch = newVariable(LIRKind.value(target().arch.getWordKind()));
697712
Variable entryScratch = newVariable(LIRKind.value(target().arch.getWordKind()));
698-
append(new HashTableSwitchOp(keys, defaultTarget, targets, value, index, scratch, entryScratch));
713+
append(new HashTableSwitchOp(keys, defaultTarget, targets, value, hash, scratch, entryScratch));
699714
}
700715

701716
@Override

compiler/src/org.graalvm.compiler.jtt/src/org/graalvm/compiler/jtt/optimize/SwitchHashTableTest.java

Lines changed: 0 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333
import java.util.stream.Collectors;
3434

3535
import org.graalvm.compiler.jtt.JTTTest;
36-
import org.graalvm.compiler.lir.hashing.HashFunction;
37-
import org.graalvm.compiler.lir.hashing.Hasher;
3836
import org.junit.Test;
3937

4038
import jdk.vm.ci.meta.JavaConstant;
@@ -44,17 +42,7 @@
4442
* Code generated by `SwitchHashTableTest.TestGenerator.main`
4543
*/
4644
public class SwitchHashTableTest extends JTTTest {
47-
@Test
48-
public void checkHashFunctionInstances() {
49-
List<String> coveredByTestCases = Arrays.asList("val >> min", "val", "val >> (val & min)", "(val >> min) ^ val", "val - min", "rotateRight(val, prime)", "rotateRight(val, prime) ^ val",
50-
"rotateRight(val, prime) + val", "(val >> min) * val", "(val * prime) >> min");
51-
Set<String> functions = HashFunction.instances().stream().map(Object::toString).collect(Collectors.toSet());
52-
functions.removeAll(coveredByTestCases);
53-
assertTrue("The following hash functions are not covered by the `Switch03` test: " + functions +
54-
". Re-run the `Switch03.TestGenerator.main` and update the test class.", functions.isEmpty());
55-
}
5645

57-
// Hasher[function=rotateRight(val, prime), effort=4, cardinality=16]
5846
public static int test1(int arg) {
5947
switch (arg) {
6048
case 3080012:
@@ -103,7 +91,6 @@ public void run1() throws Throwable {
10391
runTest("test1", 3080013); // miss
10492
}
10593

106-
// Hasher[function=rotateRight(val, prime) ^ val, effort=5, cardinality=28]
10794
public static int test2(int arg) {
10895
switch (arg) {
10996
case 718707335:
@@ -152,7 +139,6 @@ public void run2() throws Throwable {
152139
runTest("test2", 718707337); // miss
153140
}
154141

155-
// Hasher[function=(val * prime) >> min, effort=4, cardinality=16]
156142
public static int test3(int arg) {
157143
switch (arg) {
158144
case 880488712:
@@ -201,7 +187,6 @@ public void run3() throws Throwable {
201187
runTest("test3", 880488713); // miss
202188
}
203189

204-
// Hasher[function=rotateRight(val, prime) + val, effort=5, cardinality=28]
205190
public static int test4(int arg) {
206191
switch (arg) {
207192
case 189404658:
@@ -250,7 +235,6 @@ public void run4() throws Throwable {
250235
runTest("test4", 189404659); // miss
251236
}
252237

253-
// Hasher[function=val - min, effort=2, cardinality=24]
254238
public static int test5(int arg) {
255239
switch (arg) {
256240
case 527674226:
@@ -299,7 +283,6 @@ public void run5() throws Throwable {
299283
runTest("test5", 527674227); // miss
300284
}
301285

302-
// Hasher[function=val, effort=1, cardinality=24]
303286
public static int test6(int arg) {
304287
switch (arg) {
305288
case 676979121:
@@ -348,7 +331,6 @@ public void run6() throws Throwable {
348331
runTest("test6", 676979122); // miss
349332
}
350333

351-
// Hasher[function=(val >> min) ^ val, effort=3, cardinality=16]
352334
public static int test7(int arg) {
353335
switch (arg) {
354336
case 634218696:
@@ -397,7 +379,6 @@ public void run7() throws Throwable {
397379
runTest("test7", 634218697); // miss
398380
}
399381

400-
// Hasher[function=val >> min, effort=2, cardinality=16]
401382
public static int test8(int arg) {
402383
switch (arg) {
403384
case 473982403:
@@ -446,7 +427,6 @@ public void run8() throws Throwable {
446427
runTest("test8", 473982404); // miss
447428
}
448429

449-
// Hasher[function=val >> (val & min), effort=3, cardinality=16]
450430
public static int test9(int arg) {
451431
switch (arg) {
452432
case 15745090:
@@ -495,7 +475,6 @@ public void run9() throws Throwable {
495475
runTest("test9", 15745091); // miss
496476
}
497477

498-
// Hasher[function=(val >> min) * val, effort=4, cardinality=28]
499478
public static int test10(int arg) {
500479
switch (arg) {
501480
case 989358996:
@@ -543,93 +522,4 @@ public void run10() throws Throwable {
543522
runTest("test10", 989359109); // above
544523
runTest("test10", 989358997); // miss
545524
}
546-
547-
public static class TestGenerator {
548-
549-
private static int nextId = 0;
550-
private static final int size = 15;
551-
private static double minDensity = 0.5;
552-
553-
// test code generator
554-
public static void main(String[] args) {
555-
556-
Random r = new Random(0);
557-
Set<String> seen = new HashSet<>();
558-
Set<String> all = HashFunction.instances().stream().map(Object::toString).collect(Collectors.toSet());
559-
560-
println("@Test");
561-
println("public void checkHashFunctionInstances() {");
562-
println(" List<String> coveredByTestCases = Arrays.asList(" + String.join(", ", all.stream().map(s -> "\"" + s + "\"").collect(Collectors.toSet())) + ");");
563-
println(" Set<String> functions = HashFunction.instances().stream().map(Object::toString).collect(Collectors.toSet());");
564-
println(" functions.removeAll(coveredByTestCases);");
565-
println(" assertTrue(\"The following hash functions are not covered by the `Switch03` test: \" + functions +");
566-
println(" \". Re-run the `Switch03.TestGenerator.main` and update the test class.\", functions.isEmpty());");
567-
println("}");
568-
569-
while (seen.size() < all.size()) {
570-
int v = r.nextInt(Integer.MAX_VALUE / 2);
571-
List<Integer> keys = new ArrayList<>();
572-
while (keys.size() < 15) {
573-
keys.add(v);
574-
v += r.nextInt(15);
575-
}
576-
keys.sort(Integer::compare);
577-
double density = ((double) keys.size() + 1) / (keys.get(keys.size() - 1) - keys.get(0));
578-
if (density < minDensity) {
579-
Hasher.forKeys(toConstants(keys), minDensity).ifPresent(h -> {
580-
String f = h.function().toString();
581-
if (!seen.contains(f)) {
582-
gen(keys, h);
583-
seen.add(f);
584-
}
585-
});
586-
}
587-
}
588-
}
589-
590-
private static void gen(List<Integer> keys, Hasher hasher) {
591-
int id = ++nextId;
592-
593-
println("// " + hasher + "");
594-
println("public static int test" + id + "(int arg) {");
595-
println(" switch (arg) {");
596-
597-
for (Integer key : keys) {
598-
println(" case " + key + ": return " + key + ";");
599-
}
600-
601-
println(" default: return -1;");
602-
println(" }");
603-
println("}");
604-
605-
int miss = keys.get(0) + 1;
606-
while (keys.contains(miss)) {
607-
miss++;
608-
}
609-
610-
println("@Test");
611-
println("public void run" + id + "() throws Throwable {");
612-
println(" runTest(\"test" + id + "\", 0); // zero ");
613-
println(" runTest(\"test" + id + "\", " + (keys.get(0) - 1) + "); // bellow ");
614-
println(" runTest(\"test" + id + "\", " + keys.get(0) + "); // first ");
615-
println(" runTest(\"test" + id + "\", " + keys.get(size / 2) + "); // middle ");
616-
println(" runTest(\"test" + id + "\", " + keys.get(size - 1) + "); // last ");
617-
println(" runTest(\"test" + id + "\", " + (keys.get(size - 1) + 1) + "); // above ");
618-
println(" runTest(\"test" + id + "\", " + miss + "); // miss ");
619-
println("}");
620-
}
621-
622-
private static void println(String s) {
623-
System.out.println(s);
624-
}
625-
626-
private static JavaConstant[] toConstants(List<Integer> keys) {
627-
JavaConstant[] ckeys = new JavaConstant[keys.size()];
628-
629-
for (int i = 0; i < keys.size(); i++) {
630-
ckeys[i] = JavaConstant.forInt(keys.get(i));
631-
}
632-
return ckeys;
633-
}
634-
}
635525
}

compiler/src/org.graalvm.compiler.lir.amd64/src/org/graalvm/compiler/lir/amd64/AMD64ControlFlow.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -786,11 +786,14 @@ public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
786786
masm.jmp(scratchReg);
787787

788788
// Inserting padding so that jump the table address is aligned
789+
int entrySize;
789790
if (defaultTarget != null) {
790-
masm.align(8);
791+
entrySize = 8;
791792
} else {
792-
masm.align(4);
793+
entrySize = 4;
793794
}
795+
masm.align(entrySize);
796+
794797

795798
// Patch LEA instruction above now that we know the position of the jump table
796799
// this is ugly but there is no better way to do this given the assembler API
@@ -818,7 +821,7 @@ public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
818821
}
819822
}
820823

821-
JumpTable jt = new JumpTable(jumpTablePos, keys[0].asInt(), keys[keys.length - 1].asInt(), 4);
824+
JumpTable jt = new JumpTable(jumpTablePos, keys[0].asInt(), keys[keys.length - 1].asInt(), entrySize);
822825
crb.compilationResult.addAnnotation(jt);
823826
}
824827
}

compiler/src/org.graalvm.compiler.lir/src/org/graalvm/compiler/lir/gen/LIRGenerator.java

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
import org.graalvm.compiler.lir.StandardOp.ZapRegistersOp;
6363
import org.graalvm.compiler.lir.SwitchStrategy;
6464
import org.graalvm.compiler.lir.Variable;
65-
import org.graalvm.compiler.lir.hashing.Hasher;
65+
import org.graalvm.compiler.lir.hashing.IntHasher;
6666
import org.graalvm.compiler.options.Option;
6767
import org.graalvm.compiler.options.OptionKey;
6868
import org.graalvm.compiler.options.OptionType;
@@ -494,8 +494,8 @@ public void emitStrategySwitch(JavaConstant[] keyConstants, double[] keyProbabil
494494

495495
int keyCount = keyConstants.length;
496496
double minDensity = 1 / Math.sqrt(strategy.getAverageEffort());
497-
Optional<Hasher> hasher = hasherFor(keyConstants, minDensity);
498-
double hashTableSwitchDensity = hasher.map(h -> keyCount / (double) h.cardinality()).orElse(0d);
497+
Optional<IntHasher> hasher = hasherFor(keyConstants, minDensity);
498+
double hashTableSwitchDensity = hasher.map(h -> (double) keyCount / h.cardinality).orElse(0d);
499499
// The value range computation below may overflow, so compute it as a long.
500500
long valueRange = (long) keyConstants[keyCount - 1].asInt() - (long) keyConstants[0].asInt() + 1;
501501
double tableSwitchDensity = keyCount / (double) valueRange;
@@ -510,11 +510,10 @@ public void emitStrategySwitch(JavaConstant[] keyConstants, double[] keyProbabil
510510
emitStrategySwitch(strategy, value, keyTargets, defaultTarget);
511511
} else {
512512
if (hashTableSwitchDensity > tableSwitchDensity) {
513-
Hasher h = hasher.get();
514-
int cardinality = h.cardinality();
515-
LabelRef[] targets = new LabelRef[cardinality];
516-
JavaConstant[] keys = new JavaConstant[cardinality];
517-
for (int i = 0; i < cardinality; i++) {
513+
IntHasher h = hasher.get();
514+
LabelRef[] targets = new LabelRef[h.cardinality];
515+
JavaConstant[] keys = new JavaConstant[h.cardinality];
516+
for (int i = 0; i < h.cardinality; i++) {
518517
keys[i] = JavaConstant.INT_0;
519518
targets[i] = defaultTarget;
520519
}
@@ -545,12 +544,12 @@ public void emitStrategySwitch(JavaConstant[] keyConstants, double[] keyProbabil
545544
protected abstract void emitTableSwitch(int lowKey, LabelRef defaultTarget, LabelRef[] targets, Value key);
546545

547546
@SuppressWarnings("unused")
548-
protected Optional<Hasher> hasherFor(JavaConstant[] keyConstants, double minDensity) {
547+
protected Optional<IntHasher> hasherFor(JavaConstant[] keyConstants, double minDensity) {
549548
return Optional.empty();
550549
}
551550

552551
@SuppressWarnings("unused")
553-
protected void emitHashTableSwitch(Hasher hasher, JavaConstant[] keys, LabelRef defaultTarget, LabelRef[] targets, Value value) {
552+
protected void emitHashTableSwitch(IntHasher hasher, JavaConstant[] keys, LabelRef defaultTarget, LabelRef[] targets, Value value) {
554553
throw new UnsupportedOperationException(getClass().getSimpleName() + " doesn't support hash table switches");
555554
}
556555

0 commit comments

Comments
 (0)