Skip to content

Commit

Permalink
LUCENE-10315: Speed up BKD leaf block ids codec by a 512 ints ForUtil
Browse files Browse the repository at this point in the history
  • Loading branch information
gf2121 authored Feb 7, 2022
1 parent 9174720 commit 28ba89b
Show file tree
Hide file tree
Showing 11 changed files with 367 additions and 94 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ Optimizations

* LUCENE-10388: Remove MultiLevelSkipListReader#SkipBuffer to make JVM less confused. (Guo Feng)

* LUCENE-10315: Use SIMD instructions to decode BKD doc IDs. (Guo Feng, Adrien Grand, Ignacio Vera)

Changes in runtime behavior
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ public void readLongs(long[] dst, int offset, int length) throws IOException {
}
}

/**
 * Bulk-reads {@code length} ints via {@code in.readInts} and then reverses the byte order of
 * each int that was just stored.
 *
 * @param dst the array receiving the ints
 * @param offset index in {@code dst} of the first int to store
 * @param length number of ints to read
 */
@Override
public void readInts(int[] dst, int offset, int length) throws IOException {
  in.readInts(dst, offset, length);
  // Walk the freshly filled slots and flip the byte order of each one in place.
  int slot = offset;
  for (int remaining = length; remaining > 0; --remaining) {
    dst[slot] = Integer.reverseBytes(dst[slot]);
    ++slot;
  }
}

@Override
public void readFloats(float[] dst, int offset, int length) throws IOException {
in.readFloats(dst, offset, length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.concurrent.atomic.AtomicInteger;

Expand Down Expand Up @@ -138,6 +139,11 @@ public void getLongs(LongBuffer receiver, long[] dst, int offset, int length) {
receiver.get(dst, offset, length);
}

/**
 * Copies {@code length} ints from {@code receiver} into {@code dst}, starting at index
 * {@code offset}, after calling {@code ensureValid()}.
 *
 * <p>NOTE(review): {@code ensureValid()} is defined outside this view; presumably it rejects
 * access once the underlying buffer has been invalidated. Confirm against its definition.
 *
 * @param receiver the buffer to read from; its position advances as ints are consumed
 * @param dst the destination array
 * @param offset index in {@code dst} of the first int to store
 * @param length number of ints to copy
 */
public void getInts(IntBuffer receiver, int[] dst, int offset, int length) {
ensureValid();
// Relative bulk get: throws BufferUnderflowException if fewer than length ints remain.
receiver.get(dst, offset, length);
}

public void getFloats(FloatBuffer receiver, float[] dst, int offset, int length) {
ensureValid();
receiver.get(dst, offset, length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;

/**
Expand All @@ -36,6 +37,7 @@
public abstract class ByteBufferIndexInput extends IndexInput implements RandomAccessInput {
private static final FloatBuffer EMPTY_FLOATBUFFER = FloatBuffer.allocate(0);
private static final LongBuffer EMPTY_LONGBUFFER = LongBuffer.allocate(0);
private static final IntBuffer EMPTY_INTBUFFER = IntBuffer.allocate(0);

protected final long length;
protected final long chunkSizeMask;
Expand All @@ -46,6 +48,7 @@ public abstract class ByteBufferIndexInput extends IndexInput implements RandomA
protected int curBufIndex = -1;
protected ByteBuffer curBuf; // redundant for speed: buffers[curBufIndex]
private LongBuffer[] curLongBufferViews;
private IntBuffer[] curIntBufferViews;
private FloatBuffer[] curFloatBufferViews;

protected boolean isClone = false;
Expand Down Expand Up @@ -83,6 +86,7 @@ protected void setCurBuf(ByteBuffer curBuf) {
this.curBuf = curBuf;
curLongBufferViews = null;
curFloatBufferViews = null;
curIntBufferViews = null;
}

@Override
Expand Down Expand Up @@ -176,6 +180,37 @@ public void readLongs(long[] dst, int offset, int length) throws IOException {
}
}

/**
 * Reads {@code length} little-endian ints from the current buffer into {@code dst}, starting at
 * index {@code offset}.
 *
 * <p>One {@link IntBuffer} view of the current buffer is created per possible byte alignment
 * (0 to 3) on first use and cached in {@code curIntBufferViews}; the view matching the current
 * position's alignment is then used for a single bulk read. If the current buffer does not
 * hold enough ints, the read falls back to {@code super.readInts}.
 *
 * @param dst the array receiving the ints
 * @param offset index in {@code dst} of the first int to store
 * @param length number of ints to read
 * @throws AlreadyClosedException if a {@link NullPointerException} is raised while accessing the
 *     current buffer, which this method treats as the input having been closed
 */
@Override
public void readInts(int[] dst, int offset, int length) throws IOException {
  try {
    // See notes about readLongs above
    if (curIntBufferViews == null) {
      // Built inside the try block on purpose: when the buffers have been unset, curBuf is null
      // and the dereference below must surface as AlreadyClosedException, not as a raw NPE.
      final IntBuffer[] views = new IntBuffer[Integer.BYTES];
      for (int i = 0; i < Integer.BYTES; ++i) {
        if (i < curBuf.limit()) {
          views[i] = curBuf.duplicate().position(i).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer();
        } else {
          // The buffer is too short to start a view at this alignment.
          views[i] = EMPTY_INTBUFFER;
        }
      }
      // Publish only once every alignment has been computed.
      curIntBufferViews = views;
    }
    final int position = curBuf.position();
    // Low two bits select the alignment view; the remaining bits are the int index within it.
    guard.getInts(
        curIntBufferViews[position & 0x03].position(position >>> 2), dst, offset, length);
    // if the above call succeeded, then we know the below sum cannot overflow
    curBuf.position(position + (length << 2));
  } catch (
      @SuppressWarnings("unused")
      BufferUnderflowException e) {
    // Not enough ints left in the current buffer: let the superclass implementation read them.
    super.readInts(dst, offset, length);
  } catch (
      @SuppressWarnings("unused")
      NullPointerException npe) {
    throw new AlreadyClosedException("Already closed: " + this);
  }
}

@Override
public final void readFloats(float[] floats, int offset, int len) throws IOException {
// See notes about readLongs above
Expand Down Expand Up @@ -503,6 +538,7 @@ private void unsetBuffers() {
curBuf = null;
curBufIndex = 0;
curLongBufferViews = null;
curIntBufferViews = null;
}

/** Optimization of ByteBufferIndexInput for when there is only one buffer */
Expand Down
14 changes: 14 additions & 0 deletions lucene/core/src/java/org/apache/lucene/store/DataInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,20 @@ public void readLongs(long[] dst, int offset, int length) throws IOException {
}
}

/**
 * Reads {@code length} ints, one {@code readInt()} call at a time, and stores them in
 * {@code dst} beginning at index {@code offset}.
 *
 * @param dst the array to read ints into
 * @param offset the offset in the array to start storing ints
 * @param length the number of ints to read
 * @throws IndexOutOfBoundsException if {@code offset} and {@code length} do not describe a
 *     valid range of {@code dst}
 */
public void readInts(int[] dst, int offset, int length) throws IOException {
  Objects.checkFromIndexSize(offset, length, dst.length);
  // The range was validated above, so offset + length cannot overflow.
  final int end = offset + length;
  for (int slot = offset; slot < end; slot++) {
    dst[slot] = readInt();
  }
}

/**
* Reads a specified number of floats into an array at the specified offset.
*
Expand Down
108 changes: 108 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/bkd/BKDForUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.bkd;

import java.io.IOException;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.DataOutput;

/**
 * Packs and unpacks blocks of ints using 16, 24 or 32 bits per value.
 *
 * <p>The 16- and 24-bit encodings combine several values into each written int so that decoding
 * can start with a single bulk {@code readInts} call followed by simple shift/mask loops. None of
 * the methods validate that the values actually fit in the chosen width.
 */
final class BKDForUtil {

  /** Scratch buffer used by {@link #encode24} and {@link #decode24}. */
  private final int[] tmp;

  /**
   * Creates a codec whose scratch buffer is sized for blocks of up to {@code maxPointsInLeaf}
   * values.
   */
  BKDForUtil(int maxPointsInLeaf) {
    // For encode16/decode16, we do not need to use tmp array.
    // For encode24/decode24, we need a (3/4 * maxPointsInLeaf) length tmp array.
    // For encode32/decode32, we reuse the scratch in DocIdsWriter.
    // So (3/4 * maxPointsInLeaf) is enough here.
    final int len = (maxPointsInLeaf >>> 2) * 3;
    tmp = new int[len];
  }

  /**
   * Writes the first {@code len} entries of {@code ints} using 16 bits per value.
   *
   * <p>Entry {@code i} of the first half is placed in the upper 16 bits of an int and entry
   * {@code len/2 + i} in the lower 16 bits. When {@code len} is odd, the last entry is written
   * separately as a short. The first {@code len/2} entries of {@code ints} are overwritten with
   * the packed values.
   *
   * @param len number of values to encode, read from index 0 of {@code ints}
   * @param ints the values; modified by this call
   * @param out destination of the encoded data
   */
  void encode16(int len, int[] ints, DataOutput out) throws IOException {
    final int halfLen = len >>> 1;
    for (int i = 0; i < halfLen; ++i) {
      ints[i] = ints[halfLen + i] | (ints[i] << 16);
    }
    for (int i = 0; i < halfLen; i++) {
      out.writeInt(ints[i]);
    }
    if ((len & 1) == 1) {
      out.writeShort((short) ints[len - 1]);
    }
  }

  /**
   * Writes {@code len} values of {@code ints}, starting at index {@code off}, as one int each.
   *
   * @param off index of the first value to write
   * @param len number of values to write
   * @param ints the values; not modified
   * @param out destination of the encoded data
   */
  void encode32(int off, int len, int[] ints, DataOutput out) throws IOException {
    for (int i = 0; i < len; i++) {
      out.writeInt(ints[off + i]);
    }
  }

  /**
   * Writes {@code len} values of {@code ints}, starting at index {@code off}, using 24 bits per
   * value for the largest multiple of four values, followed by any leftover values as full ints.
   *
   * <p>With {@code q = len / 4}: the first {@code 3q} values each occupy the upper 24 bits of one
   * written int, and each of the next {@code q} values is split into three bytes that fill the
   * low byte of three of those ints. The remaining {@code len % 4} values are written unpacked.
   *
   * @param off index of the first value to write
   * @param len number of values to write
   * @param ints the values; not modified
   * @param out destination of the encoded data
   */
  void encode24(int off, int len, int[] ints, DataOutput out) throws IOException {
    final int quarterLen = len >>> 2;
    final int quarterLen3 = quarterLen * 3;
    for (int i = 0; i < quarterLen3; ++i) {
      tmp[i] = ints[off + i] << 8;
    }
    for (int i = 0; i < quarterLen; i++) {
      final int longIdx = off + i + quarterLen3;
      tmp[i] |= ints[longIdx] >>> 16;
      tmp[i + quarterLen] |= (ints[longIdx] >>> 8) & 0xFF;
      tmp[i + quarterLen * 2] |= ints[longIdx] & 0xFF;
    }
    for (int i = 0; i < quarterLen3; ++i) {
      out.writeInt(tmp[i]);
    }

    // The leftover values sit right after the packed region, which itself starts at off; the
    // index must therefore include off just like every other read in this method.
    final int remainderStart = off + (quarterLen << 2);
    final int remainder = len & 0x3;
    for (int i = 0; i < remainder; i++) {
      out.writeInt(ints[remainderStart + i]);
    }
  }

  /**
   * Reads {@code len} values in the layout produced by {@link #encode16}, adds {@code base} to
   * each, and stores them in {@code ints} starting at index 0.
   *
   * @param in source of the encoded data
   * @param ints receives the decoded values
   * @param len number of values to decode
   * @param base value added to every decoded entry
   */
  void decode16(DataInput in, int[] ints, int len, final int base) throws IOException {
    final int halfLen = len >>> 1;
    in.readInts(ints, 0, halfLen);
    for (int i = 0; i < halfLen; ++i) {
      int l = ints[i];
      // Upper 16 bits hold entry i, lower 16 bits hold entry halfLen + i.
      ints[i] = (l >>> 16) + base;
      ints[halfLen + i] = (l & 0xFFFF) + base;
    }
    if ((len & 1) == 1) {
      ints[len - 1] = Short.toUnsignedInt(in.readShort()) + base;
    }
  }

  /**
   * Reads {@code len} values in the layout produced by {@link #encode24} and stores them in
   * {@code ints} starting at index 0.
   *
   * @param in source of the encoded data
   * @param ints receives the decoded values
   * @param len number of values to decode
   */
  void decode24(DataInput in, int[] ints, int len) throws IOException {
    final int quarterLen = len >>> 2;
    final int quarterLen3 = quarterLen * 3;
    in.readInts(tmp, 0, quarterLen3);
    for (int i = 0; i < quarterLen3; ++i) {
      ints[i] = tmp[i] >>> 8;
    }
    for (int i = 0; i < quarterLen; i++) {
      // Reassemble one value from the low bytes of three packed ints.
      ints[i + quarterLen3] =
          ((tmp[i] & 0xFF) << 16)
              | ((tmp[i + quarterLen] & 0xFF) << 8)
              | (tmp[i + quarterLen * 2] & 0xFF);
    }
    int remainder = len & 0x3;
    if (remainder > 0) {
      in.readInts(ints, quarterLen << 2, remainder);
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ private static class BKDPointTree implements PointTree {
scratchMaxIndexPackedValue;
private final int[] commonPrefixLengths;
private final BKDReaderDocIDSetIterator scratchIterator;
private final DocIdsWriter docIdsWriter;
// if true the tree is balanced, otherwise unbalanced
private final boolean isTreeBalanced;

Expand Down Expand Up @@ -303,6 +304,7 @@ private BKDPointTree(
this.scratchDataPackedValue = scratchDataPackedValue;
this.scratchMinIndexPackedValue = scratchMinIndexPackedValue;
this.scratchMaxIndexPackedValue = scratchMaxIndexPackedValue;
this.docIdsWriter = scratchIterator.docIdsWriter;
}

@Override
Expand Down Expand Up @@ -570,7 +572,7 @@ public void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws I
// How many points are stored in this leaf cell:
int count = leafNodes.readVInt();
// No need to call grow(), it has been called up-front
DocIdsWriter.readInts(leafNodes, count, visitor);
docIdsWriter.readInts(leafNodes, count, visitor);
} else {
pushLeft();
addAll(visitor, grown);
Expand Down Expand Up @@ -633,7 +635,7 @@ private int readDocIDs(IndexInput in, long blockFP, BKDReaderDocIDSetIterator it
// How many points are stored in this leaf cell:
int count = in.readVInt();

DocIdsWriter.readInts(in, count, iterator.docIDs);
docIdsWriter.readInts(in, count, iterator.docIDs);

return count;
}
Expand Down Expand Up @@ -1002,9 +1004,11 @@ private static class BKDReaderDocIDSetIterator extends DocIdSetIterator {
private int offset;
private int docID;
final int[] docIDs;
private final DocIdsWriter docIdsWriter;

/**
 * Allocates a doc ID buffer of {@code maxPointsInLeafNode} entries and creates the
 * {@code DocIdsWriter} for this iterator, passing it the same {@code maxPointsInLeafNode}.
 */
public BKDReaderDocIDSetIterator(int maxPointsInLeafNode) {
  docIDs = new int[maxPointsInLeafNode];
  docIdsWriter = new DocIdsWriter(maxPointsInLeafNode);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ public class BKDWriter implements Closeable {
private final long totalPointCount;

private final int maxDoc;
private final DocIdsWriter docIdsWriter;

public BKDWriter(
int maxDoc,
Expand Down Expand Up @@ -165,7 +166,7 @@ public BKDWriter(

// Maximum number of points we hold in memory at any time
maxPointsSortInHeap = (int) ((maxMBSortInHeap * 1024 * 1024) / (config.bytesPerDoc));

docIdsWriter = new DocIdsWriter(config.maxPointsInLeafNode);
// Finally, we must be able to hold at least the leaf node in heap during build:
if (maxPointsSortInHeap < config.maxPointsInLeafNode) {
throw new IllegalArgumentException(
Expand Down Expand Up @@ -1288,7 +1289,7 @@ private void writeLeafBlockDocs(DataOutput out, int[] docIDs, int start, int cou
throws IOException {
assert count > 0 : "config.maxPointsInLeafNode=" + config.maxPointsInLeafNode;
out.writeVInt(count);
DocIdsWriter.writeDocIds(docIDs, start, count, out);
docIdsWriter.writeDocIds(docIDs, start, count, out);
}

private void writeLeafBlockPackedValues(
Expand Down
Loading

0 comments on commit 28ba89b

Please sign in to comment.