Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance: use Lucene IntIntHashMap to count hits for approximate queries #598

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,15 @@ lazy val `elastiknn-models` = project

lazy val `elastiknn-models-benchmarks` = project
.in(file("elastiknn-models-benchmarks"))
.dependsOn(`elastiknn-models`, `elastiknn-api4s`)
.dependsOn(`elastiknn-models`, `elastiknn-api4s`, `elastiknn-lucene`)
.enablePlugins(JmhPlugin)
.settings(
Jmh / javaOptions ++= Seq("--add-modules", "jdk.incubator.vector"),
TpolecatSettings
TpolecatSettings,
libraryDependencies ++= Seq(
"org.eclipse.collections" % "eclipse-collections" % "11.1.0",
"org.eclipse.collections" % "eclipse-collections-api" % "11.1.0"
)
)

lazy val `elastiknn-plugin` = project
Expand Down
1,547 changes: 781 additions & 766 deletions docs/pages/performance/fashion-mnist/plot.b64

Large diffs are not rendered by default.

Binary file modified docs/pages/performance/fashion-mnist/plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 8 additions & 8 deletions docs/pages/performance/fashion-mnist/results.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
|Model|Parameters|Recall|Queries per Second|
|---|---|---|---|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=0|0.378|363.121|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=0|0.446|299.144|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=3|0.634|270.522|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=3|0.716|240.419|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=0|0.768|280.053|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=0|0.847|240.014|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=3|0.922|186.668|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=3|0.960|166.241|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=0|0.379|272.975|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=0|0.447|257.790|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=3|0.634|192.264|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=3|0.717|177.546|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=0|0.767|179.336|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=0|0.847|165.566|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=3|0.922|113.075|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=3|0.960|106.883|
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ public int maxKey() {

@Override
public KthGreatest.Result kthGreatest(int k) {
return KthGreatest.kthGreatest(counts, Math.min(k, counts.length - 1));
int[] ints = new int[counts.length];
for (int i = 0; i < counts.length; i++) ints[i] = counts[i];
return KthGreatest.kthGreatest(ints, Math.min(k, counts.length - 1));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.klibisz.elastiknn.search;

import org.apache.lucene.search.KthGreatest;

public final class EmptyHitCounter implements HitCounter {
public EmptyHitCounter() { }

@Override
public void increment(int key, short count) { }

@Override
public void increment(int key, int count) { }

@Override
public boolean isEmpty() {
return true;
}

@Override
public short get(int key) {
return 0;
}

@Override
public int numHits() {
return 0;
}

@Override
public int capacity() {
return 0;
}

@Override
public int minKey() {
return 0;
}

@Override
public int maxKey() {
return 0;
}

@Override
public KthGreatest.Result kthGreatest(int k) {
return new KthGreatest.Result(0, 0, 0);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package com.klibisz.elastiknn.search;

import org.apache.lucene.search.KthGreatest;
import org.apache.lucene.util.hppc.IntIntHashMap;

public final class HashMapHitCounter implements HitCounter {

private final IntIntHashMap hashMap;
private final int capacity;
private int minKey;
private int maxKey;

public HashMapHitCounter(int capacity, int expectedElements, float loadFactor) {
this.capacity = capacity;
hashMap = new IntIntHashMap(expectedElements, loadFactor);
minKey = capacity;
maxKey = 0;
}


@Override
public void increment(int key, short count) {
hashMap.putOrAdd(key, count, count);
}

@Override
public void increment(int key, int count) {
minKey = Math.min(key, minKey);
maxKey = Math.max(key, maxKey);
hashMap.putOrAdd(key, count, count);
}

@Override
public boolean isEmpty() {
return hashMap.isEmpty();
}

@Override
public short get(int key) {
return (short) hashMap.get(key);
}

@Override
public int numHits() {
return hashMap.size();
}

@Override
public int capacity() {
return capacity;
}

@Override
public int minKey() {
return minKey;
}

@Override
public int maxKey() {
return maxKey;
}

@Override
public KthGreatest.Result kthGreatest(int k) {
return KthGreatest.kthGreatest(hashMap.values, Math.min(k, hashMap.size() - 1));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.klibisz.elastiknn.search;

public class TestCasting {

public static void main(String[] args) {
int[] ints = {1,2,3,4,5};
short[] shorts = new short[ints.length];
System.arraycopy(ints, 0, shorts, 0, shorts.length);
for (int i = 0; i < shorts.length; i++) {
System.out.println(shorts[i]);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
public class KthGreatest {

public static class Result {
public final short kthGreatest;
public final int kthGreatest;
public final int numGreaterThan;
public final int numNonZero;
public Result(short kthGreatest, int numGreaterThan, int numNonZero) {
public Result(int kthGreatest, int numGreaterThan, int numNonZero) {
this.kthGreatest = kthGreatest;
this.numGreaterThan = numGreaterThan;
this.numNonZero = numNonZero;
Expand All @@ -23,7 +23,7 @@ public Result(short kthGreatest, int numGreaterThan, int numNonZero) {
* @param k the desired largest value.
* @return the kth largest value.
*/
public static Result kthGreatest(short[] arr, int k) {
public static Result kthGreatest(int[] arr, int k) {
if (arr.length == 0) {
throw new IllegalArgumentException("Array must be non-empty");
} else if (k < 0 || k >= arr.length) {
Expand All @@ -33,24 +33,24 @@ public static Result kthGreatest(short[] arr, int k) {
));
} else {
// Find the min and max values.
short max = arr[0];
short min = arr[0];
for (short a: arr) {
int max = arr[0];
int min = arr[0];
for (int a: arr) {
if (a > max) max = a;
else if (a < min) min = a;
}

// Build and populate a histogram for non-zero values.
int[] hist = new int[max - min + 1];
int numNonZero = 0;
for (short a: arr) {
for (int a: arr) {
hist[a - min] += 1;
if (a > 0) numNonZero++;
}

// Find the kth largest value by iterating from the end of the histogram.
int numGreaterEqual = 0;
short kthGreatest = max;
int kthGreatest = max;
while (kthGreatest >= min) {
numGreaterEqual += hist[kthGreatest - min];;
if (numGreaterEqual > k) break;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package org.apache.lucene.search;

import com.klibisz.elastiknn.models.HashAndFreq;
import com.klibisz.elastiknn.search.ArrayHitCounter;
import com.klibisz.elastiknn.search.EmptyHitCounter;
import com.klibisz.elastiknn.search.HashMapHitCounter;
import com.klibisz.elastiknn.search.HitCounter;
import org.apache.lucene.index.*;
import org.apache.lucene.util.BytesRef;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

import static java.lang.Math.min;
Expand Down Expand Up @@ -58,11 +58,11 @@ private HitCounter countHits(LeafReader reader) throws IOException {
Terms terms = reader.terms(field);
// terms seem to be null after deleting docs. https://github.com/alexklibisz/elastiknn/issues/158
if (terms == null) {
return new ArrayHitCounter(0);
return new EmptyHitCounter();
} else {
TermsEnum termsEnum = terms.iterator();
PostingsEnum docs = null;
HitCounter counter = new ArrayHitCounter(reader.maxDoc());
HitCounter counter = new HashMapHitCounter(reader.maxDoc(), candidates * 10, 0.9f);
double counterLimit = counter.capacity() + 1;
// TODO: Is this the right place to use the live docs bitset to check for deleted docs?
// Bits liveDocs = reader.getLiveDocs();
Expand All @@ -79,12 +79,6 @@ private HitCounter countHits(LeafReader reader) throws IOException {
}

private DocIdSetIterator buildDocIdSetIterator(HitCounter counter) {
// TODO: Add back this logging once log4j mess has settled.
// if (counter.numHits() < candidates) {
// logger.warn(String.format(
// "Found fewer approximate matches [%d] than the requested number of candidates [%d]",
// counter.numHits(), candidates));
// }
if (counter.isEmpty()) return DocIdSetIterator.empty();
else {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class KthGreatestSuite extends AnyFunSuite with Matchers {
}

test("example") {
val counts: Array[Short] = Array(2, 2, 8, 7, 4, 4)
val counts: Array[Int] = Array(2, 2, 8, 7, 4, 4)
val res = KthGreatest.kthGreatest(counts, 3)
res.kthGreatest shouldBe 4
res.numGreaterThan shouldBe 2
Expand All @@ -33,7 +33,7 @@ class KthGreatestSuite extends AnyFunSuite with Matchers {
val rng = new Random(seed)
info(s"Using seed $seed")
for (_ <- 0 until 999) {
val counts = (0 until (rng.nextInt(10000) + 1)).map(_ => rng.nextInt(Short.MaxValue).toShort).toArray
val counts = (0 until (rng.nextInt(10000) + 1)).map(_ => rng.nextInt(Short.MaxValue)).toArray
val k = rng.nextInt(counts.length)
val res = KthGreatest.kthGreatest(counts, k)
res.kthGreatest shouldBe counts.sorted.reverse(k)
Expand All @@ -43,15 +43,15 @@ class KthGreatestSuite extends AnyFunSuite with Matchers {
}

test("all zero except one") {
val counts = Array[Short](50, 0, 0, 0, 0, 0, 0, 0, 0, 0)
val counts = Array[Int](50, 0, 0, 0, 0, 0, 0, 0, 0, 0)
val res = KthGreatest.kthGreatest(counts, 3)
res.kthGreatest shouldBe 0
res.numGreaterThan shouldBe 1
res.numNonZero shouldBe 1
}

test("all zero") {
val counts = Array[Short](0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
val counts = Array[Int](0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
val res = KthGreatest.kthGreatest(counts, 3)
res.kthGreatest shouldBe 0
res.numGreaterThan shouldBe 0
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package com.klibisz.elastiknn.microbenchmarks

import org.openjdk.jmh.annotations._
import org.apache.lucene.util.hppc.IntIntHashMap
import org.eclipse.collections.impl.map.mutable.primitive.IntShortHashMap

import scala.util.Random

@State(Scope.Benchmark)
class HitCounterBenchmarksFixtures {
val rng = new Random(0)
val numDocs = 600000
val numHits = 6000
val candidates = 500
val docs = (1 to numHits).map(_ => rng.nextInt(numDocs)).toArray
}

class HitCounterBenchmarks {

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 5)
@Measurement(time = 5, iterations = 5)
def arrayCountBaseline(f: HitCounterBenchmarksFixtures): Unit = {
val arr = new Array[Int](f.numDocs)
for (d <- f.docs) arr.update(d, arr(d) + 1)
()
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 5)
@Measurement(time = 5, iterations = 5)
def hashMapGetOrDefault(f: HitCounterBenchmarksFixtures): Unit = {
val h = new java.util.HashMap[Int, Int](f.candidates * 10, 0.99f)
for (d <- f.docs) h.put(d, h.getOrDefault(d, 0) + 1)
()
}


@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 5)
@Measurement(time = 5, iterations = 5)
def luceneIntIntHashMap(f: HitCounterBenchmarksFixtures): Unit = {
val m = new IntIntHashMap(f.candidates * 10, 0.99d)
for (d <- f.docs) m.putOrAdd(d, 1, 1)
()
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 5)
@Measurement(time = 5, iterations = 5)
def eclipseIntShortHashMapAddToValue(f: HitCounterBenchmarksFixtures): Unit = {
val m = new IntShortHashMap(f.candidates * 10)
for (d <- f.docs) m.addToValue(d, 1)
()
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.klibisz.elastiknn.vectors
package com.klibisz.elastiknn.microbenchmarks

import com.klibisz.elastiknn.api.Vec
import com.klibisz.elastiknn.vectors._
import org.openjdk.jmh.annotations._

import scala.util.Random
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import com.klibisz.elastiknn.vectors.FloatVectorOps;
import jdk.internal.vm.annotation.ForceInline;

import java.util.Arrays;

public class ExactModel {

@ForceInline
Expand Down
Loading