Skip to content

Commit

Permalink
Modify TreeTraversal usage to match new format
Browse files Browse the repository at this point in the history
Signed-off-by: Finn Carroll <carrofin@amazon.com>
  • Loading branch information
finnegancarroll committed Aug 1, 2024
1 parent 7b1f8e4 commit 9ab1447
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.mapper.NumericPointEncoder;
import org.opensearch.search.optimization.filterrewrite.Ranges;
import org.opensearch.search.optimization.filterrewrite.TreeTraversal;

import java.util.*;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -141,16 +142,15 @@ public void tearDown() throws IOException {
public Map<Integer, List<Integer>> multiRangeTraverseTree(treeState state) throws Exception {
Map<Integer, List<Integer>> mockIDCollect = new HashMap<>();

BiConsumer<Integer, List<Integer>> collectRangeIDs = (activeIndex, docIDs) -> {
TreeTraversal.RangeAwareIntersectVisitor treeVisitor = new TreeTraversal.DocCollectRangeAwareIntersectVisitor(state.pointTree, state.ranges, state.maxNumNonZeroRanges, (activeIndex, docID) -> {
if (mockIDCollect.containsKey(activeIndex)) {
mockIDCollect.get(activeIndex).addAll(docIDs);
mockIDCollect.get(activeIndex).add(docID);
} else {
mockIDCollect.put(activeIndex, docIDs);
mockIDCollect.put(activeIndex, List.of(docID));
}
};

multiRangesTraverse(state.pointTree, state.ranges, collectRangeIDs, state.maxNumNonZeroRanges);
});

multiRangesTraverse(treeVisitor);
return mockIDCollect;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,29 @@ private boolean canOptimize(boolean missing, boolean hasScript, MappedFieldType
@Override
public final void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount, final LeafBucketCollector sub) throws IOException {
DateFieldMapper.DateFieldType fieldType = getFieldType();
BiConsumer<Integer, List<Integer>> collectRangeIDs = (activeIndex, docIDs) -> {
long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0);
rangeStart = fieldType.convertNanosToMillis(rangeStart);
long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart));
incrementDocCount.accept(ord, (long) docIDs.size());
};

optimizationContext.consumeDebugInfo(multiRangesTraverse(values.getPointTree(), optimizationContext.getRanges(), collectRangeIDs, getSize()));
TreeTraversal.RangeAwareIntersectVisitor treeVisitor;
if (sub != null) {
treeVisitor = new TreeTraversal.DocCollectRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), getSize(), (activeIndex, docID) -> {
long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0);
rangeStart = fieldType.convertNanosToMillis(rangeStart);
long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart));

try {
incrementDocCount.accept(ord, (long) 1);
sub.collect(docID, activeIndex);
} catch ( IOException ioe) {
throw new RuntimeException(ioe);
}
});
} else {
treeVisitor = new TreeTraversal.DocCountRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), getSize(), (activeIndex, docCount) -> {
long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0);
rangeStart = fieldType.convertNanosToMillis(rangeStart);
long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart));
incrementDocCount.accept(ord, (long) docCount);
});
}

optimizationContext.consumeDebugInfo(multiRangesTraverse(treeVisitor));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,25 +126,31 @@ protected int getSize() {

@Override
public void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount, final LeafBucketCollector sub) throws IOException {
int size = getSize();

DateFieldMapper.DateFieldType fieldType = getFieldType();
BiConsumer<Integer, List<Integer>> collectRangeIDs = (activeIndex, docIDs) -> {
long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0);
rangeStart = fieldType.convertNanosToMillis(rangeStart);
long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart));
incrementDocCount.accept(ord, (long) docIDs.size());

try {
for (int docID : docIDs) {
sub.collect(docID, ord);
TreeTraversal.RangeAwareIntersectVisitor treeVisitor;
if (sub != null) {
treeVisitor = new TreeTraversal.DocCollectRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), getSize(), (activeIndex, docID) -> {
long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0);
rangeStart = fieldType.convertNanosToMillis(rangeStart);
long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart));

try {
incrementDocCount.accept(ord, (long) 1);
sub.collect(docID, activeIndex);
} catch ( IOException ioe) {
throw new RuntimeException(ioe);
}
} catch ( IOException ioe) {
throw new RuntimeException(ioe);
}
};
});
} else {
treeVisitor = new TreeTraversal.DocCountRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), getSize(), (activeIndex, docCount) -> {
long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0);
rangeStart = fieldType.convertNanosToMillis(rangeStart);
long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart));
incrementDocCount.accept(ord, (long) docCount);
});
}

optimizationContext.consumeDebugInfo(multiRangesTraverse(values.getPointTree(), optimizationContext.getRanges(), collectRangeIDs, size));
optimizationContext.consumeDebugInfo(multiRangesTraverse(treeVisitor));
}

protected static long getBucketOrd(long bucketOrd) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.io.IOException;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.List;

import static org.opensearch.search.optimization.filterrewrite.TreeTraversal.multiRangesTraverse;

Expand Down Expand Up @@ -77,22 +76,26 @@ public void prepareFromSegment(LeafReaderContext leaf) {

@Override
public final void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount, final LeafBucketCollector sub) throws IOException {
TreeTraversal.RangeAwareIntersectVisitor treeVisitor;
if (sub != null) {
treeVisitor = new TreeTraversal.DocCollectRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), Integer.MAX_VALUE, (activeIndex, docID) -> {
long ord = bucketOrdProducer().apply(activeIndex);


BiConsumer<Integer, List<Integer>> collectRangeIDs = (activeIndex, docIDs) -> {
long ord = bucketOrdProducer().apply(activeIndex);
incrementDocCount.accept(ord, (long) docIDs.size());

try {
for (int docID : docIDs) {
try {
incrementDocCount.accept(ord, (long) 1);
sub.collect(docID, activeIndex);
} catch ( IOException ioe) {
throw new RuntimeException(ioe);
}
} catch ( IOException ioe) {
throw new RuntimeException(ioe);
}
};
});
} else {
treeVisitor = new TreeTraversal.DocCountRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), Integer.MAX_VALUE, (activeIndex, docCount) -> {
long ord = bucketOrdProducer().apply(activeIndex);
incrementDocCount.accept(ord, (long) docCount);
});
}

optimizationContext.consumeDebugInfo(multiRangesTraverse(values.getPointTree(), optimizationContext.getRanges(), collectRangeIDs, Integer.MAX_VALUE));
optimizationContext.consumeDebugInfo(multiRangesTraverse(treeVisitor));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ protected boolean iterateRangeEnd(byte[] packedValue) {
* 1.) activeIndex for range in which document(s) reside
* 2.) total documents counted
*/
private static abstract class DocCountRangeAwareIntersectVisitor extends RangeAwareIntersectVisitor {
public static class DocCountRangeAwareIntersectVisitor extends RangeAwareIntersectVisitor {
BiConsumer<Integer, Integer> countDocs;

public DocCountRangeAwareIntersectVisitor(
Expand Down Expand Up @@ -220,7 +220,7 @@ protected void consumeCrossedNode(PointValues.PointTree pointTree) throws IOExce
* 1.) activeIndex for range in which document(s) reside
* 2.) document id to collect
*/
private static abstract class DocCollectRangeAwareIntersectVisitor extends RangeAwareIntersectVisitor {
public static class DocCollectRangeAwareIntersectVisitor extends RangeAwareIntersectVisitor {
BiConsumer<Integer, Integer> collectDocs;

public DocCollectRangeAwareIntersectVisitor(
Expand Down

0 comments on commit 9ab1447

Please sign in to comment.