Skip to content

Commit

Permalink
Handle nulls in OrdinalsGroupingOperator (elastic#100117)
Browse files Browse the repository at this point in the history
This change introduces null handling in the OrdinalsGroupingOperator,
replacing the previous behavior, which skipped null keys. Ordinals are
now incremented by 1, with 0 used to represent the null ordinal.

Closes elastic#100109
  • Loading branch information
dnhatn authored and piergm committed Oct 2, 2023
1 parent e5f0cfe commit 290d989
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 94 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.lucene.BlockOrdinalsReader;
import org.elasticsearch.compute.lucene.ValueSourceInfo;
import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator;
import org.elasticsearch.compute.operator.HashAggregationOperator.GroupSpec;
Expand Down Expand Up @@ -234,18 +233,31 @@ protected boolean lessThan(AggregatedResultIterator a, AggregatedResultIterator
};
final List<GroupingAggregator> aggregators = createGroupingAggregators();
try {
boolean seenNulls = false;
for (OrdinalSegmentAggregator agg : ordinalAggregators.values()) {
if (agg.seenNulls()) {
seenNulls = true;
for (int i = 0; i < aggregators.size(); i++) {
aggregators.get(i).addIntermediateRow(0, agg.aggregators.get(i), 0);
}
}
}
for (OrdinalSegmentAggregator agg : ordinalAggregators.values()) {
final AggregatedResultIterator it = agg.getResultIterator();
if (it.next()) {
pq.add(it);
}
}
int position = -1;
final int startPosition = seenNulls ? 0 : -1;
int position = startPosition;
final BytesRefBuilder lastTerm = new BytesRefBuilder();
var blockBuilder = BytesRefBlock.newBlockBuilder(1);
if (seenNulls) {
blockBuilder.appendNull();
}
while (pq.size() > 0) {
final AggregatedResultIterator top = pq.top();
if (position == -1 || lastTerm.get().equals(top.currentTerm) == false) {
if (position == startPosition || lastTerm.get().equals(top.currentTerm) == false) {
position++;
lastTerm.copyBytes(top.currentTerm);
blockBuilder.appendBytesRef(top.currentTerm);
Expand Down Expand Up @@ -338,20 +350,17 @@ void addInput(IntVector docs, Page page) {
if (BlockOrdinalsReader.canReuse(currentReader, docs.getInt(0)) == false) {
currentReader = new BlockOrdinalsReader(withOrdinals.ordinalsValues(leafReaderContext));
}
final IntBlock ordinals = currentReader.readOrdinals(docs);
final IntBlock ordinals = currentReader.readOrdinalsAdded1(docs);
for (int p = 0; p < ordinals.getPositionCount(); p++) {
if (ordinals.isNull(p)) {
continue;
}
int start = ordinals.getFirstValueIndex(p);
int end = start + ordinals.getValueCount(p);
for (int i = start; i < end; i++) {
long ord = ordinals.getInt(i);
visitedOrds.set(ord);
}
}
for (GroupingAggregator aggregator : aggregators) {
aggregator.prepareProcessPage(this, page).add(0, ordinals);
for (GroupingAggregatorFunction.AddInput addInput : prepared) {
addInput.add(0, ordinals);
}
} catch (IOException e) {
throw new UncheckedIOException(e);
Expand All @@ -362,6 +371,10 @@ AggregatedResultIterator getResultIterator() throws IOException {
return new AggregatedResultIterator(aggregators, visitedOrds, withOrdinals.ordinalsValues(leafReaderContext));
}

// Whether this segment saw any document without a value for the grouping field.
// Ordinal 0 is reserved for the null group: readOrdinalsAdded1 shifts real ordinals
// up by one and emits 0 for valueless documents, so bit 0 of visitedOrds records
// that a null key was encountered.
boolean seenNulls() {
    return visitedOrds.get(0);
}

@Override
public BitArray seenGroupIds(BigArrays bigArrays) {
BitArray seen = new BitArray(0, bigArrays);
Expand All @@ -377,7 +390,7 @@ public void close() {

private static class AggregatedResultIterator {
private BytesRef currentTerm;
private long currentOrd = -1;
private long currentOrd = 0;
private final List<GroupingAggregator> aggregators;
private final BitArray ords;
private final SortedSetDocValues dv;
Expand All @@ -395,8 +408,9 @@ int currentPosition() {

boolean next() throws IOException {
currentOrd = ords.nextSetBit(currentOrd + 1);
assert currentOrd > 0 : currentOrd;
if (currentOrd < Long.MAX_VALUE) {
currentTerm = dv.lookupOrd(currentOrd);
currentTerm = dv.lookupOrd(currentOrd - 1);
return true;
} else {
currentTerm = null;
Expand Down Expand Up @@ -448,4 +462,49 @@ public void close() {
Releasables.close(extractor, aggregator);
}
}

/**
 * Reads segment ordinals for a vector of doc ids, shifting each ordinal up by one so that
 * 0 is free to represent "document has no value" (the null group). Instances remember the
 * thread that created them; {@link #canReuse} refuses reuse from any other thread because
 * the wrapped doc-values iterator is advanced statefully.
 */
static final class BlockOrdinalsReader {
    private final SortedSetDocValues ordinals;
    private final Thread owner;

    BlockOrdinalsReader(SortedSetDocValues ordinals) {
        this.ordinals = ordinals;
        this.owner = Thread.currentThread();
    }

    /**
     * Builds a block holding, for every position in {@code docs}, that document's set
     * ordinals each incremented by one; a document with no value yields the single
     * ordinal 0 (the reserved null ordinal).
     *
     * @param docs doc ids to read, in non-decreasing order per {@link #canReuse}
     * @throws IOException if advancing the underlying doc values fails
     */
    IntBlock readOrdinalsAdded1(IntVector docs) throws IOException {
        final int positions = docs.getPositionCount();
        final IntBlock.Builder out = IntBlock.newBlockBuilder(positions);
        for (int pos = 0; pos < positions; pos++) {
            final int doc = docs.getInt(pos);
            if (ordinals.advanceExact(doc) == false) {
                // No value for this document: map it to the reserved null ordinal.
                out.appendInt(0);
                continue;
            }
            final int valueCount = ordinals.docValueCount();
            // TODO don't come this way if there are a zillion ords on the field
            if (valueCount == 1) {
                // Single-valued fast path: no multi-value position entry needed.
                out.appendInt(Math.toIntExact(ordinals.nextOrd() + 1));
            } else {
                out.beginPositionEntry();
                for (int v = 0; v < valueCount; v++) {
                    out.appendInt(Math.toIntExact(ordinals.nextOrd() + 1));
                }
                out.endPositionEntry();
            }
        }
        return out.build();
    }

    /** The doc id the underlying doc-values iterator is currently positioned on. */
    int docID() {
        return ordinals.docID();
    }

    /**
     * Checks if the reader can be used to read a range of documents starting with the given
     * docID by the current thread.
     */
    static boolean canReuse(BlockOrdinalsReader reader, int startingDocID) {
        return reader != null && reader.owner == Thread.currentThread() && reader.docID() <= startingDocID;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -569,3 +569,12 @@ ca:l | cx:l | l:i
1 | 1 | 5
1 | 1 | null
;

aggsWithoutStats
from employees | stats by gender | sort gender;

gender:keyword
F
M
null
;
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.elasticsearch.Build;
import org.elasticsearch.action.admin.indices.alias.IndicesAliasesRequest;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.delete.DeleteRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.support.WriteRequest;
Expand All @@ -34,6 +33,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
Expand Down Expand Up @@ -265,7 +265,7 @@ public void testFromStatsGroupingByKeywordWithNulls() {
EsqlQueryResponse results = run("from test | stats avg = avg(" + field + ") by color");
logger.info(results);
Assert.assertEquals(2, results.columns().size());
Assert.assertEquals(4, getValuesList(results).size());
Assert.assertEquals(5, getValuesList(results).size());

// assert column metadata
assertEquals("avg", results.columns().get(0).name());
Expand All @@ -276,25 +276,18 @@ record Group(String color, Double avg) {

}
List<Group> expectedGroups = List.of(
new Group(null, 120.0),
new Group("blue", 42.0),
new Group("green", 44.0),
new Group("red", 43.0),
new Group("yellow", null)
);
List<Group> actualGroups = getValuesList(results).stream()
.map(l -> new Group((String) l.get(1), (Double) l.get(0)))
.sorted(comparing(c -> c.color))
.sorted(Comparator.comparing(c -> c.color, Comparator.nullsFirst(String::compareTo)))
.toList();
assertThat(actualGroups, equalTo(expectedGroups));
}
for (int i = 0; i < 5; i++) {
client().prepareBulk()
.add(new DeleteRequest("test").id("no_color_" + i))
.add(new DeleteRequest("test").id("no_count_red_" + i))
.add(new DeleteRequest("test").id("no_count_yellow_" + i))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get();
}
}

public void testFromStatsMultipleAggs() {
Expand Down Expand Up @@ -562,11 +555,6 @@ public void testFilterWithNullAndEvalFromIndex() {
assertThat(results.columns(), hasItem(equalTo(new ColumnInfo("data", "long"))));
assertThat(results.columns(), hasItem(equalTo(new ColumnInfo("data_d", "double"))));
assertThat(results.columns(), hasItem(equalTo(new ColumnInfo("time", "long"))));

// restore index to original pre-test state
client().prepareBulk().add(new DeleteRequest("test").id("no_count")).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).get();
results = run("from test");
Assert.assertEquals(40, getValuesList(results).size());
}

public void testMultiConditionalWhere() {
Expand Down Expand Up @@ -963,9 +951,6 @@ public void testInWithNullValue() {
}

public void testTopNPushedToLucene() {
BulkRequestBuilder bulkDelete = client().prepareBulk();
bulkDelete.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

for (int i = 5; i < 11; i++) {
var yellowDocId = "yellow_" + i;
var yellowNullCountDocId = "yellow_null_count_" + i;
Expand All @@ -979,11 +964,6 @@ public void testTopNPushedToLucene() {
if (randomBoolean()) {
client().admin().indices().prepareRefresh("test").get();
}

// build the cleanup request now, as well, not to miss anything ;-)
bulkDelete.add(new DeleteRequest("test").id(yellowDocId))
.add(new DeleteRequest("test").id(yellowNullCountDocId))
.add(new DeleteRequest("test").id(yellowNullDataDocId));
}
client().admin().indices().prepareRefresh("test").get();

Expand Down

0 comments on commit 290d989

Please sign in to comment.