diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDReader.java index 6b989497cb39..8fd13c94380c 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDReader.java @@ -333,17 +333,11 @@ private int getNumLeavesSlow(int node) { @Override public void visitDocIDs(PointValues.IntersectVisitor visitor) throws IOException { - addAll(visitor, false); + visitor.grow((int) Math.min(docCount, size())); + addAll(visitor); } - public void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws IOException { - if (grown == false) { - final long size = size(); - if (size <= Integer.MAX_VALUE) { - visitor.grow((int) size); - grown = true; - } - } + public void addAll(PointValues.IntersectVisitor visitor) throws IOException { if (isLeafNode()) { // Leaf node BytesRefBuilder scratch = new BytesRefBuilder(); @@ -356,10 +350,10 @@ public void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws I } } else { pushLeft(); - addAll(visitor, grown); + addAll(visitor); pop(true); pushRight(); - addAll(visitor, grown); + addAll(visitor); pop(false); } } diff --git a/lucene/core/src/java/org/apache/lucene/index/PointValues.java b/lucene/core/src/java/org/apache/lucene/index/PointValues.java index 64229d18936c..74188905c1cb 100644 --- a/lucene/core/src/java/org/apache/lucene/index/PointValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/PointValues.java @@ -332,7 +332,7 @@ default void grow(int count) {} * Finds all documents and points matching the provided visitor. This method does not enforce live * documents, so it's up to the caller to test whether each document is deleted, if necessary. */ - public final void intersect(IntersectVisitor visitor) throws IOException { + public void intersect(IntersectVisitor visitor) throws IOException { final PointTree pointTree = getPointTree(); intersect(visitor, pointTree); assert pointTree.moveToParent() == false; diff --git a/lucene/core/src/java/org/apache/lucene/util/DocIdSetBuilder.java b/lucene/core/src/java/org/apache/lucene/util/DocIdSetBuilder.java index 67b3dde9f200..b5fc9f41ea7b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/DocIdSetBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/DocIdSetBuilder.java @@ -102,7 +102,7 @@ public void add(int doc) { private final int threshold; // pkg-private for testing final boolean multivalued; - final double numValuesPerDoc; + final int docCount; private List buffers = new ArrayList<>(); private int totalAllocated; // accumulated size of the allocated buffers @@ -136,16 +136,17 @@ public DocIdSetBuilder(int maxDoc, PointValues values, String field) throws IOEx DocIdSetBuilder(int maxDoc, int docCount, long valueCount) { this.maxDoc = maxDoc; this.multivalued = docCount < 0 || docCount != valueCount; - if (docCount <= 0 || valueCount < 0) { - // assume one value per doc, this means the cost will be overestimated + + if (docCount < 0) { + // this means the cost will be overestimated // if the docs are actually multi-valued - this.numValuesPerDoc = 1; + this.docCount = Integer.MAX_VALUE; } else { // otherwise compute from index stats - this.numValuesPerDoc = (double) valueCount / docCount; + this.docCount = docCount; } - assert numValuesPerDoc >= 1 : "valueCount=" + valueCount + " docCount=" + docCount; + assert this.docCount >= 0 : "valueCount=" + valueCount + " docCount=" + docCount; // For ridiculously small sets, we'll just use a sorted int[] // maxDoc >>> 7 is a good value if you want to save memory, lower values @@ -267,8 +268,7 @@ public DocIdSet build() { try { if (bitSet != null) { assert counter >= 0; - final long cost = Math.round(counter / numValuesPerDoc); - return new BitDocIdSet(bitSet, cost); + return new BitDocIdSet(bitSet, Math.min(counter, docCount)); } else { Buffer concatenated = concat(buffers); LSBRadixSorter sorter = new LSBRadixSorter(); diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java index d2d326b3a156..54a765671689 100644 --- a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java +++ b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java @@ -164,6 +164,7 @@ public PointTree getPointTree() throws IOException { numLeaves, version, pointCount, + docCount, minPackedValue, maxPackedValue, isTreeBalanced); @@ -207,6 +208,8 @@ private static class BKDPointTree implements PointTree { private final int version; // total number of points final long pointCount; + // total number of docs + final int docCount; // last node might not be fully populated private final int lastLeafNodePointCount; // right most leaf node ID @@ -228,6 +231,7 @@ private BKDPointTree( int numLeaves, int version, long pointCount, + int docCount, byte[] minPackedValue, byte[] maxPackedValue, boolean isTreeBalanced) @@ -239,6 +243,7 @@ private BKDPointTree( numLeaves, version, pointCount, + docCount, 1, 1, minPackedValue, @@ -260,6 +265,7 @@ private BKDPointTree( int numLeaves, int version, long pointCount, + int docCount, int nodeID, int level, byte[] minPackedValue, @@ -293,6 +299,7 @@ private BKDPointTree( negativeDeltas = new boolean[config.numIndexDims * treeDepth]; // information about the unbalance of the tree so we can report the exact size below a node this.pointCount = pointCount; + this.docCount = docCount; rightMostLeafNode = (1 << treeDepth - 1) - 1; int lastLeafNodePointCount = Math.toIntExact(pointCount % config.maxPointsInLeafNode); this.lastLeafNodePointCount = @@ -317,6 +324,7 @@ public PointTree clone() { leafNodeOffset, version, pointCount, + docCount, nodeID, level, minPackedValue, @@ -555,17 +563,11 @@ private int balanceTreeNodePosition( @Override public void visitDocIDs(PointValues.IntersectVisitor visitor) throws IOException { resetNodeDataPosition(); - addAll(visitor, false); + visitor.grow((int) Math.min(docCount, size())); + addAll(visitor); } - public void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws IOException { - if (grown == false) { - final long size = size(); - if (size <= Integer.MAX_VALUE) { - visitor.grow((int) size); - grown = true; - } - } + public void addAll(PointValues.IntersectVisitor visitor) throws IOException { if (isLeafNode()) { // Leaf node leafNodes.seek(getLeafBlockFP()); @@ -575,10 +577,10 @@ public void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws I docIdsWriter.readInts(leafNodes, count, visitor); } else { pushLeft(); - addAll(visitor, grown); + addAll(visitor); pop(); pushRight(); - addAll(visitor, grown); + addAll(visitor); pop(); } } diff --git a/lucene/core/src/test/org/apache/lucene/util/TestDocIdSetBuilder.java b/lucene/core/src/test/org/apache/lucene/util/TestDocIdSetBuilder.java index 2fa146581c68..856f44b6b95b 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestDocIdSetBuilder.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestDocIdSetBuilder.java @@ -168,14 +168,14 @@ public void testMisleadingDISICost() throws IOException { public void testEmptyPoints() throws IOException { PointValues values = new DummyPointValues(0, 0); DocIdSetBuilder builder = new DocIdSetBuilder(1, values, "foo"); - assertEquals(1d, builder.numValuesPerDoc, 0d); + assertEquals(0, builder.docCount); } public void testLeverageStats() throws IOException { // single-valued points PointValues values = new DummyPointValues(42, 42); DocIdSetBuilder builder = new DocIdSetBuilder(100, values, "foo"); - assertEquals(1d, builder.numValuesPerDoc, 0d); + assertEquals(42, builder.docCount); assertFalse(builder.multivalued); DocIdSetBuilder.BulkAdder adder = builder.grow(2); adder.add(5); @@ -187,30 +187,30 @@ public void testLeverageStats() throws IOException { // multi-valued points values = new DummyPointValues(42, 63); builder = new DocIdSetBuilder(100, values, "foo"); - assertEquals(1.5, builder.numValuesPerDoc, 0d); + assertEquals(42, builder.docCount); assertTrue(builder.multivalued); - adder = builder.grow(2); + adder = builder.grow(100); adder.add(5); adder.add(7); set = builder.build(); assertTrue(set instanceof BitDocIdSet); - assertEquals(1, set.iterator().cost()); // it thinks the same doc was added twice + assertEquals(42, set.iterator().cost()); // it thinks all docs have been added // incomplete stats values = new DummyPointValues(42, -1); builder = new DocIdSetBuilder(100, values, "foo"); - assertEquals(1d, builder.numValuesPerDoc, 0d); + assertEquals(1d, builder.docCount, 42); assertTrue(builder.multivalued); values = new DummyPointValues(-1, 84); builder = new DocIdSetBuilder(100, values, "foo"); - assertEquals(1d, builder.numValuesPerDoc, 0d); + assertEquals(Integer.MAX_VALUE, builder.docCount); assertTrue(builder.multivalued); // single-valued terms Terms terms = new DummyTerms(42, 42); builder = new DocIdSetBuilder(100, terms); - assertEquals(1d, builder.numValuesPerDoc, 0d); + assertEquals(42, builder.docCount); assertFalse(builder.multivalued); adder = builder.grow(2); adder.add(5); @@ -222,24 +222,24 @@ public void testLeverageStats() throws IOException { // multi-valued terms terms = new DummyTerms(42, 63); builder = new DocIdSetBuilder(100, terms); - assertEquals(1.5, builder.numValuesPerDoc, 0d); + assertEquals(42, builder.docCount); assertTrue(builder.multivalued); - adder = builder.grow(2); + adder = builder.grow(100); adder.add(5); adder.add(7); set = builder.build(); assertTrue(set instanceof BitDocIdSet); - assertEquals(1, set.iterator().cost()); // it thinks the same doc was added twice + assertEquals(42, set.iterator().cost()); // it thinks all docs have been added // incomplete stats terms = new DummyTerms(42, -1); builder = new DocIdSetBuilder(100, terms); - assertEquals(1d, builder.numValuesPerDoc, 0d); + assertEquals(42, builder.docCount, 0d); assertTrue(builder.multivalued); terms = new DummyTerms(-1, 84); builder = new DocIdSetBuilder(100, terms); - assertEquals(1d, builder.numValuesPerDoc, 0d); + assertEquals(Integer.MAX_VALUE, builder.docCount); assertTrue(builder.multivalued); } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java index a3f8b9650a1a..9c542efb4a62 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java @@ -47,6 +47,7 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.VirtualMethod; import org.apache.lucene.util.automaton.CompiledAutomaton; @@ -1096,10 +1097,12 @@ public long lookupTerm(BytesRef key) throws IOException { public static class AssertingPointValues extends PointValues { private final Thread creationThread = Thread.currentThread(); private final PointValues in; + private final int maxDoc; /** Sole constructor. */ public AssertingPointValues(PointValues in, int maxDoc) { this.in = in; + this.maxDoc = maxDoc; assertStats(maxDoc); } @@ -1120,6 +1123,17 @@ public PointTree getPointTree() throws IOException { return new AssertingPointTree(in, in.getPointTree()); } + @Override + public void intersect(IntersectVisitor visitor) throws IOException { + in.intersect( + new AssertingIntersectVisitor( + maxDoc, + getNumDimensions(), + getNumIndexDimensions(), + getBytesPerDimension(), + visitor)); + } + @Override public byte[] getMinPackedValue() throws IOException { assertThread("Points", creationThread); @@ -1212,22 +1226,12 @@ public long size() { @Override public void visitDocIDs(IntersectVisitor visitor) throws IOException { - in.visitDocIDs( - new AssertingIntersectVisitor( - pointValues.getNumDimensions(), - pointValues.getNumIndexDimensions(), - pointValues.getBytesPerDimension(), - visitor)); + in.visitDocIDs(visitor); } @Override public void visitDocValues(IntersectVisitor visitor) throws IOException { - in.visitDocValues( - new AssertingIntersectVisitor( - pointValues.getNumDimensions(), - pointValues.getNumIndexDimensions(), - pointValues.getBytesPerDimension(), - visitor)); + in.visitDocValues(visitor); } } @@ -1244,12 +1248,16 @@ static class AssertingIntersectVisitor implements IntersectVisitor { final byte[] lastMinPackedValue; final byte[] lastMaxPackedValue; private Relation lastCompareResult; + private final FixedBitSet docCounter; + private final int maxDoc; private int lastDocID = -1; private int docBudget; AssertingIntersectVisitor( - int numDataDims, int numIndexDims, int bytesPerDim, IntersectVisitor in) { + int maxDoc, int numDataDims, int numIndexDims, int bytesPerDim, IntersectVisitor in) { this.in = in; + this.docCounter = new FixedBitSet(maxDoc); + this.maxDoc = maxDoc; this.numDataDims = numDataDims; this.numIndexDims = numIndexDims; this.bytesPerDim = bytesPerDim; @@ -1264,8 +1272,9 @@ static class AssertingIntersectVisitor implements IntersectVisitor { @Override public void visit(int docID) throws IOException { - assert --docBudget >= 0 : "called add() more times than the last call to grow() reserved"; - + if (docCounter.getAndSet(docID) == false) { + assert --docBudget >= 0 : "called add() more times than the last call to grow() reserved"; + } // This method, not filtering each hit, should only be invoked when the cell is inside the // query shape: assert lastCompareResult == null || lastCompareResult == Relation.CELL_INSIDE_QUERY; @@ -1274,7 +1283,9 @@ public void visit(int docID) throws IOException { @Override public void visit(int docID, byte[] packedValue) throws IOException { - assert --docBudget >= 0 : "called add() more times than the last call to grow() reserved"; + if (docCounter.getAndSet(docID) == false) { + assert --docBudget >= 0 : "called add() more times than the last call to grow() reserved"; + } // This method, to filter each doc's value, should only be invoked when the cell crosses the // query shape: @@ -1329,6 +1340,7 @@ public void visit(int docID, byte[] packedValue) throws IOException { public void grow(int count) { in.grow(count); docBudget = count; + docCounter.clear(0, maxDoc); } @Override