diff --git a/parquet-column/src/main/java/org/apache/parquet/column/MinMax.java b/parquet-column/src/main/java/org/apache/parquet/column/MinMax.java new file mode 100644 index 0000000000..c97b681b5a --- /dev/null +++ b/parquet-column/src/main/java/org/apache/parquet/column/MinMax.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.parquet.column; + +import org.apache.parquet.schema.PrimitiveComparator; + +/** + * This class calculates the max and min values of an iterable collection. + */ +public final class MinMax { + private T min = null; + private T max = null; + + public MinMax(PrimitiveComparator comparator, Iterable iterable) { + getMinAndMax(comparator, iterable); + } + + public T getMin() { + return min; + } + + public T getMax() { + return max; + } + + private void getMinAndMax(PrimitiveComparator comparator, Iterable iterable) { + iterable.forEach(element -> { + if (max == null) { + max = element; + } else if (element != null && comparator.compare(max, element) < 0) { + max = element; + } + if (min == null) { + min = element; + } else if (element != null && comparator.compare(min, element) > 0) { + min = element; + } + }); + } +} diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterApi.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterApi.java index b209fc7b6f..a0490d9ac9 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterApi.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterApi.java @@ -19,6 +19,7 @@ package org.apache.parquet.filter2.predicate; import java.io.Serializable; +import java.util.Set; import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.filter2.predicate.Operators.And; @@ -30,12 +31,14 @@ import org.apache.parquet.filter2.predicate.Operators.FloatColumn; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; +import org.apache.parquet.filter2.predicate.Operators.In; import org.apache.parquet.filter2.predicate.Operators.IntColumn; import org.apache.parquet.filter2.predicate.Operators.LongColumn; import org.apache.parquet.filter2.predicate.Operators.Lt; import org.apache.parquet.filter2.predicate.Operators.LtEq; import org.apache.parquet.filter2.predicate.Operators.Not; import org.apache.parquet.filter2.predicate.Operators.NotEq; +import org.apache.parquet.filter2.predicate.Operators.NotIn; import org.apache.parquet.filter2.predicate.Operators.Or; import org.apache.parquet.filter2.predicate.Operators.SupportsEqNotEq; import org.apache.parquet.filter2.predicate.Operators.SupportsLtGt; @@ -204,6 +207,56 @@ public static , C extends Column & SupportsLtGt> GtEq 
return new GtEq<>(column, value); } + /** + * Keeps records if their value is in the provided values. + * The provided values set cannot be null, but it may contain a null value. + *

+ * For example: + *

+   *   {@code
+   *   Set<Integer> set = new HashSet<>();
+   *   set.add(9);
+   *   set.add(null);
+   *   set.add(50);
+   *   in(column, set);}
+   * 
+ * will keep all records whose values are 9, null, or 50. + * + * @param column a column reference created by FilterApi + * @param values a set of values that match the column's type + * @param <T> the Java type of values in the column + * @param <C> the column type that corresponds to values of type T + * @return an in predicate for the given column and values + */ + public static <T extends Comparable<T>, C extends Column<T> & SupportsEqNotEq> In<T, C> in(C column, Set<T> values) { + return new In<>(column, values); + } + + /** + * Keeps records if their value is not in the provided values. + * The provided values set cannot be null, but it may contain a null value. + *

+ * For example: + *

+   *   {@code
+   *   Set<Integer> set = new HashSet<>();
+   *   set.add(9);
+   *   set.add(null);
+   *   set.add(50);
+   *   notIn(column, set);}
+   * 
+ * will keep all records whose values are not 9, null, or 50. + * + * @param column a column reference created by FilterApi + * @param values a set of values that match the column's type + * @param <T> the Java type of values in the column + * @param <C> the column type that corresponds to values of type T + * @return a notIn predicate for the given column and values + */ + public static <T extends Comparable<T>, C extends Column<T> & SupportsEqNotEq> NotIn<T, C> notIn(C column, Set<T> values) { + return new NotIn<>(column, values); + } + /** * Keeps records that pass the provided {@link UserDefinedPredicate} *

diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterPredicate.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterPredicate.java index 211c71e6d7..d9156c2544 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterPredicate.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterPredicate.java @@ -22,11 +22,13 @@ import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; +import org.apache.parquet.filter2.predicate.Operators.In; import org.apache.parquet.filter2.predicate.Operators.LogicalNotUserDefined; import org.apache.parquet.filter2.predicate.Operators.Lt; import org.apache.parquet.filter2.predicate.Operators.LtEq; import org.apache.parquet.filter2.predicate.Operators.Not; import org.apache.parquet.filter2.predicate.Operators.NotEq; +import org.apache.parquet.filter2.predicate.Operators.NotIn; import org.apache.parquet.filter2.predicate.Operators.Or; import org.apache.parquet.filter2.predicate.Operators.UserDefined; @@ -66,6 +68,12 @@ public static interface Visitor { > R visit(LtEq ltEq); > R visit(Gt gt); > R visit(GtEq gtEq); + default > R visit(In in) { + throw new UnsupportedOperationException("visit in is not supported."); + } + default > R visit(NotIn notIn) { + throw new UnsupportedOperationException("visit NotIn is not supported."); + } R visit(And and); R visit(Or or); R visit(Not not); diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverseRewriter.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverseRewriter.java index 88cb836e2c..49b862f602 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverseRewriter.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverseRewriter.java @@ -23,11 +23,13 @@ import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; +import org.apache.parquet.filter2.predicate.Operators.In; import org.apache.parquet.filter2.predicate.Operators.LogicalNotUserDefined; import org.apache.parquet.filter2.predicate.Operators.Lt; import org.apache.parquet.filter2.predicate.Operators.LtEq; import org.apache.parquet.filter2.predicate.Operators.Not; import org.apache.parquet.filter2.predicate.Operators.NotEq; +import org.apache.parquet.filter2.predicate.Operators.NotIn; import org.apache.parquet.filter2.predicate.Operators.Or; import org.apache.parquet.filter2.predicate.Operators.UserDefined; @@ -87,6 +89,16 @@ public > FilterPredicate visit(GtEq gtEq) { return gtEq; } + @Override + public > FilterPredicate visit(In in) { + return in; + } + + @Override + public > FilterPredicate visit(NotIn notIn) { + return notIn; + } + @Override public FilterPredicate visit(And and) { return and(and.getLeft().accept(this), and.getRight().accept(this)); diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverter.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverter.java index cc0186b8b7..b95d473ef2 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverter.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverter.java @@ -23,11 +23,13 @@ import 
org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; +import org.apache.parquet.filter2.predicate.Operators.In; import org.apache.parquet.filter2.predicate.Operators.LogicalNotUserDefined; import org.apache.parquet.filter2.predicate.Operators.Lt; import org.apache.parquet.filter2.predicate.Operators.LtEq; import org.apache.parquet.filter2.predicate.Operators.Not; import org.apache.parquet.filter2.predicate.Operators.NotEq; +import org.apache.parquet.filter2.predicate.Operators.NotIn; import org.apache.parquet.filter2.predicate.Operators.Or; import org.apache.parquet.filter2.predicate.Operators.UserDefined; @@ -81,6 +83,16 @@ public > FilterPredicate visit(GtEq gtEq) { return new Lt<>(gtEq.getColumn(), gtEq.getValue()); } + @Override + public > FilterPredicate visit(In in) { + return new NotIn<>(in.getColumn(), in.getValues()); + } + + @Override + public > FilterPredicate visit(NotIn notIn) { + return new In<>(notIn.getColumn(), notIn.getValues()); + } + @Override public FilterPredicate visit(And and) { return new Or(and.getLeft().accept(this), and.getRight().accept(this)); diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/Operators.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/Operators.java index 9a1696c411..d52aa92495 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/Operators.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/Operators.java @@ -21,10 +21,13 @@ import java.io.Serializable; import java.util.Locale; import java.util.Objects; +import java.util.Set; import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.io.api.Binary; +import static org.apache.parquet.Preconditions.checkArgument; + /** * These are the operators in a filter predicate expression tree. * They are constructed by using the methods in {@link FilterApi} @@ -169,7 +172,7 @@ public int hashCode() { public static final class Eq> extends ColumnFilterPredicate { // value can be null - Eq(Column column, T value) { + public Eq(Column column, T value) { super(column, value); } @@ -247,6 +250,82 @@ public R accept(Visitor visitor) { } } + /** + * Base class for {@link In} and {@link NotIn}. {@link In} is used to filter data based on a list of values. + * {@link NotIn} is used to filter data that are not in the list of values. + */ + public static abstract class SetColumnFilterPredicate> implements FilterPredicate, Serializable { + private final Column column; + private final Set values; + + protected SetColumnFilterPredicate(Column column, Set values) { + this.column = Objects.requireNonNull(column, "column cannot be null"); + this.values = Objects.requireNonNull(values, "values cannot be null"); + checkArgument(!values.isEmpty(), "values in SetColumnFilterPredicate shouldn't be empty!"); + } + + public Column getColumn() { + return column; + } + + public Set getValues() { + return values; + } + + @Override + public String toString() { + String name = getClass().getSimpleName().toLowerCase(Locale.ENGLISH); + StringBuilder str = new StringBuilder(); + str.append(name).append("(").append(column.getColumnPath().toDotString()).append(", "); + int iter = 0; + for (T value : values) { + if (iter >= 100) break; + str.append(value).append(", "); + iter++; + } + int length = str.length(); + str = values.size() <= 100 ? 
str.delete(length - 2, length) : str.append("..."); + return str.append(")").toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SetColumnFilterPredicate that = (SetColumnFilterPredicate) o; + return column.equals(that.column) && values.equals(that.values); + } + + @Override + public int hashCode() { + return Objects.hash(column, values); + } + } + + public static final class In> extends SetColumnFilterPredicate { + + public In(Column column, Set values) { + super(column, values); + } + + @Override + public R accept(Visitor visitor) { + return visitor.visit(this); + } + } + + public static final class NotIn> extends SetColumnFilterPredicate { + + NotIn(Column column, Set values) { + super(column, values); + } + + @Override + public R accept(Visitor visitor) { + return visitor.visit(this); + } + } + // base class for And, Or private static abstract class BinaryLogicalFilterPredicate implements FilterPredicate, Serializable { private final FilterPredicate left; diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/SchemaCompatibilityValidator.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/SchemaCompatibilityValidator.java index c75036bbdc..49fd10cc81 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/SchemaCompatibilityValidator.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/SchemaCompatibilityValidator.java @@ -29,12 +29,15 @@ import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; +import org.apache.parquet.filter2.predicate.Operators.In; import org.apache.parquet.filter2.predicate.Operators.LogicalNotUserDefined; import org.apache.parquet.filter2.predicate.Operators.Lt; import org.apache.parquet.filter2.predicate.Operators.LtEq; import org.apache.parquet.filter2.predicate.Operators.Not; import org.apache.parquet.filter2.predicate.Operators.NotEq; +import org.apache.parquet.filter2.predicate.Operators.NotIn; import org.apache.parquet.filter2.predicate.Operators.Or; +import org.apache.parquet.filter2.predicate.Operators.SetColumnFilterPredicate; import org.apache.parquet.filter2.predicate.Operators.UserDefined; import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.schema.MessageType; @@ -114,6 +117,18 @@ public > Void visit(GtEq pred) { return null; } + @Override + public > Void visit(In pred) { + validateColumnFilterPredicate(pred); + return null; + } + + @Override + public > Void visit(NotIn pred) { + validateColumnFilterPredicate(pred); + return null; + } + @Override public Void visit(And and) { and.getLeft().accept(this); @@ -149,6 +164,10 @@ private > void validateColumnFilterPredicate(ColumnFilte validateColumn(pred.getColumn()); } + private > void validateColumnFilterPredicate(SetColumnFilterPredicate pred) { + validateColumn(pred.getColumn()); + } + private > void validateColumn(Column column) { ColumnPath path = column.getColumnPath(); diff --git a/parquet-column/src/main/java/org/apache/parquet/internal/column/columnindex/ColumnIndexBuilder.java b/parquet-column/src/main/java/org/apache/parquet/internal/column/columnindex/ColumnIndexBuilder.java index 15be50e55d..fc3859b9ca 100644 --- a/parquet-column/src/main/java/org/apache/parquet/internal/column/columnindex/ColumnIndexBuilder.java +++ 
b/parquet-column/src/main/java/org/apache/parquet/internal/column/columnindex/ColumnIndexBuilder.java @@ -23,21 +23,28 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Formatter; +import java.util.Iterator; import java.util.List; import java.util.PrimitiveIterator; +import java.util.Set; +import java.util.function.IntConsumer; import java.util.function.IntPredicate; +import org.apache.parquet.column.MinMax; import org.apache.parquet.column.statistics.Statistics; import org.apache.parquet.filter2.predicate.Operators.And; import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; +import org.apache.parquet.filter2.predicate.Operators.In; import org.apache.parquet.filter2.predicate.Operators.LogicalNotUserDefined; import org.apache.parquet.filter2.predicate.Operators.Lt; import org.apache.parquet.filter2.predicate.Operators.LtEq; import org.apache.parquet.filter2.predicate.Operators.Not; import org.apache.parquet.filter2.predicate.Operators.NotEq; +import org.apache.parquet.filter2.predicate.Operators.NotIn; import org.apache.parquet.filter2.predicate.Operators.Or; +import org.apache.parquet.filter2.predicate.Operators.SetColumnFilterPredicate; import org.apache.parquet.filter2.predicate.Operators.UserDefined; import org.apache.parquet.filter2.predicate.UserDefinedPredicate; import org.apache.parquet.io.api.Binary; @@ -287,6 +294,69 @@ public > PrimitiveIterator.OfInt visit(NotEq notEq) { pageIndex -> nullCounts[pageIndex] > 0 || matchingIndexes.contains(pageIndex)); } + @Override + public > PrimitiveIterator.OfInt visit(In in) { + Set values = in.getValues(); + IntSet matchingIndexesForNull = new IntOpenHashSet(); // for null + Iterator it = values.iterator(); + while(it.hasNext()) { + T value = it.next(); + if (value == null) { + if (nullCounts == null) { + // Searching for nulls so if we don't have null related statistics we have to return all pages + return IndexIterator.all(getPageCount()); + } else { + for (int i = 0; i < nullCounts.length; i++) { + if (nullCounts[i] > 0) { + matchingIndexesForNull.add(i); + } + } + if (values.size() == 1) { + return IndexIterator.filter(getPageCount(), pageIndex -> matchingIndexesForNull.contains(pageIndex)); + } + } + } + } + + IntSet matchingIndexesLessThanMax = new IntOpenHashSet(); + IntSet matchingIndexesGreaterThanMin = new IntOpenHashSet(); + + MinMax minMax = new MinMax(comparator, values); + T min = minMax.getMin(); + T max = minMax.getMax(); + + // We don't want to iterate through each of the values in the IN set to compare, + // because the size of the IN set might be very large. Instead, we want to only + // compare the max and min value of the IN set to see if the page might contain the + // values in the IN set. + // If there might be values in a page that are <= the max value in the IN set, + // and >= the min value in the IN set, then the page might contain + // the values in the IN set. 
+ getBoundaryOrder().ltEq(createValueComparator(max)) + .forEachRemaining((int index) -> matchingIndexesLessThanMax.add(index)); + getBoundaryOrder().gtEq(createValueComparator(min)) + .forEachRemaining((int index) -> matchingIndexesGreaterThanMin.add(index)); + matchingIndexesLessThanMax.retainAll(matchingIndexesGreaterThanMin); + IntSet matchingIndex = matchingIndexesLessThanMax; + matchingIndex.addAll(matchingIndexesForNull); // add the matching null pages + return IndexIterator.filter(getPageCount(), pageIndex -> matchingIndex.contains(pageIndex)); + } + + @Override + public > PrimitiveIterator.OfInt visit(NotIn notIn) { + IntSet indexes = getMatchingIndexes(notIn); + return IndexIterator.filter(getPageCount(), pageIndex -> !indexes.contains(pageIndex)); + } + + private > IntSet getMatchingIndexes(SetColumnFilterPredicate in) { + IntSet matchingIndexes = new IntOpenHashSet(); + for (T value : in.getValues()) { + Eq eq = new Eq<>(in.getColumn(), value); + visit(eq).forEachRemaining((IntConsumer) matchingIndexes::add); + } + return matchingIndexes; + } + @Override public , U extends UserDefinedPredicate> PrimitiveIterator.OfInt visit( UserDefined udp) { diff --git a/parquet-column/src/main/java/org/apache/parquet/internal/filter2/columnindex/ColumnIndexFilter.java b/parquet-column/src/main/java/org/apache/parquet/internal/filter2/columnindex/ColumnIndexFilter.java index 6dec7741dd..6c27f98097 100644 --- a/parquet-column/src/main/java/org/apache/parquet/internal/filter2/columnindex/ColumnIndexFilter.java +++ b/parquet-column/src/main/java/org/apache/parquet/internal/filter2/columnindex/ColumnIndexFilter.java @@ -27,6 +27,7 @@ import org.apache.parquet.filter2.compat.FilterCompat.NoOpFilter; import org.apache.parquet.filter2.compat.FilterCompat.UnboundRecordFilterCompat; import org.apache.parquet.filter2.predicate.FilterPredicate.Visitor; +import org.apache.parquet.filter2.predicate.Operators; import org.apache.parquet.filter2.predicate.Operators.And; import org.apache.parquet.filter2.predicate.Operators.Column; import org.apache.parquet.filter2.predicate.Operators.Eq; @@ -146,6 +147,18 @@ public > RowRanges visit(GtEq gtEq) { return applyPredicate(gtEq.getColumn(), ci -> ci.visit(gtEq), RowRanges.EMPTY); } + @Override + public > RowRanges visit(Operators.In in) { + boolean isNull = in.getValues().contains(null); + return applyPredicate(in.getColumn(), ci -> ci.visit(in), isNull ? allRows() : RowRanges.EMPTY); + } + + @Override + public > RowRanges visit(Operators.NotIn notIn) { + boolean isNull = notIn.getValues().contains(null); + return applyPredicate(notIn.getColumn(), ci -> ci.visit(notIn), isNull ? 
RowRanges.EMPTY : allRows()); + } + @Override public , U extends UserDefinedPredicate> RowRanges visit(UserDefined udp) { return applyPredicate(udp.getColumn(), ci -> ci.visit(udp), diff --git a/parquet-column/src/test/java/org/apache/parquet/internal/column/columnindex/TestColumnIndexBuilder.java b/parquet-column/src/test/java/org/apache/parquet/internal/column/columnindex/TestColumnIndexBuilder.java index 5a3947c980..9c1d4dcedd 100644 --- a/parquet-column/src/test/java/org/apache/parquet/internal/column/columnindex/TestColumnIndexBuilder.java +++ b/parquet-column/src/test/java/org/apache/parquet/internal/column/columnindex/TestColumnIndexBuilder.java @@ -26,11 +26,13 @@ import static org.apache.parquet.filter2.predicate.FilterApi.floatColumn; import static org.apache.parquet.filter2.predicate.FilterApi.gt; import static org.apache.parquet.filter2.predicate.FilterApi.gtEq; +import static org.apache.parquet.filter2.predicate.FilterApi.in; import static org.apache.parquet.filter2.predicate.FilterApi.intColumn; import static org.apache.parquet.filter2.predicate.FilterApi.longColumn; import static org.apache.parquet.filter2.predicate.FilterApi.lt; import static org.apache.parquet.filter2.predicate.FilterApi.ltEq; import static org.apache.parquet.filter2.predicate.FilterApi.notEq; +import static org.apache.parquet.filter2.predicate.FilterApi.notIn; import static org.apache.parquet.filter2.predicate.FilterApi.userDefined; import static org.apache.parquet.filter2.predicate.LogicalInverter.invert; import static org.apache.parquet.schema.OriginalType.DECIMAL; @@ -55,7 +57,9 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Comparator; +import java.util.HashSet; import java.util.List; +import java.util.Set; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.statistics.Statistics; @@ -274,6 +278,13 @@ public void testBuildBinaryDecimal() { decimalBinary("87656273")); assertCorrectFiltering(columnIndex, eq(col, decimalBinary("0.0")), 1, 4); assertCorrectFiltering(columnIndex, eq(col, null), 0, 2, 3, 5, 6); + Set set1 = new HashSet<>(); + set1.add(Binary.fromString("0.0")); + assertCorrectFiltering(columnIndex, in(col, set1), 1, 4); + assertCorrectFiltering(columnIndex, notIn(col, set1), 0, 2, 3, 5, 6, 7); + set1.add(null); + assertCorrectFiltering(columnIndex, in(col, set1), 0, 1, 2, 3, 4, 5, 6); + assertCorrectFiltering(columnIndex, notIn(col, set1), 7); assertCorrectFiltering(columnIndex, notEq(col, decimalBinary("87656273")), 0, 1, 2, 3, 4, 5, 6); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 2, 4, 7); assertCorrectFiltering(columnIndex, gt(col, decimalBinary("2348978.45")), 1); @@ -319,6 +330,13 @@ public void testBuildBinaryDecimal() { null); assertCorrectFiltering(columnIndex, eq(col, decimalBinary("87656273")), 2, 4); assertCorrectFiltering(columnIndex, eq(col, null), 0, 3, 5, 6, 7); + Set set2 = new HashSet<>(); + set2.add(decimalBinary("87656273")); + assertCorrectFiltering(columnIndex, in(col, set2), 2, 4); + assertCorrectFiltering(columnIndex, notIn(col, set2), 0, 1, 3, 5, 6, 7); + set2.add(null); + assertCorrectFiltering(columnIndex, in(col, set2), 0, 2, 3, 4, 5, 6, 7); + assertCorrectFiltering(columnIndex, notIn(col, set2), 1); assertCorrectFiltering(columnIndex, notEq(col, decimalBinary("87656273")), 0, 1, 2, 3, 5, 6, 7); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 2, 4, 6); assertCorrectFiltering(columnIndex, gt(col, decimalBinary("87656273")), 6); @@ -364,6 +382,13 @@ public void testBuildBinaryDecimal() { 
decimalBinary("-9999293.23")); assertCorrectFiltering(columnIndex, eq(col, decimalBinary("1234567890.12")), 2, 4); assertCorrectFiltering(columnIndex, eq(col, null), 0, 1, 2, 3, 6); + Set set3 = new HashSet<>(); + set3.add(decimalBinary("1234567890.12")); + assertCorrectFiltering(columnIndex, in(col, set3), 2, 4); + assertCorrectFiltering(columnIndex, notIn(col, set3), 0, 1, 3, 5, 6, 7); + set3.add(null); + assertCorrectFiltering(columnIndex, in(col, set3), 0, 1, 2, 3, 4, 6); + assertCorrectFiltering(columnIndex, notIn(col, set3), 5, 7); assertCorrectFiltering(columnIndex, notEq(col, decimalBinary("0.0")), 0, 1, 2, 3, 4, 5, 6, 7); assertCorrectFiltering(columnIndex, notEq(col, null), 2, 4, 5, 7); assertCorrectFiltering(columnIndex, gt(col, decimalBinary("1234567890.12"))); @@ -417,6 +442,13 @@ public void testBuildBinaryUtf8() { null); assertCorrectFiltering(columnIndex, eq(col, stringBinary("Marvin")), 1, 4, 5); assertCorrectFiltering(columnIndex, eq(col, null), 0, 1, 2, 3, 5, 7); + Set set1 = new HashSet<>(); + set1.add(stringBinary("Marvin")); + assertCorrectFiltering(columnIndex, in(col, set1), 1, 4, 5); + assertCorrectFiltering(columnIndex, notIn(col, set1), 0, 2, 3, 6, 7); + set1.add(null); + assertCorrectFiltering(columnIndex, in(col, set1), 0, 1, 2, 3, 4, 5, 7); + assertCorrectFiltering(columnIndex, notIn(col, set1), 6); assertCorrectFiltering(columnIndex, notEq(col, stringBinary("Beeblebrox")), 0, 1, 2, 3, 4, 5, 7); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 4, 5, 6); assertCorrectFiltering(columnIndex, gt(col, stringBinary("Prefect")), 1, 5); @@ -462,6 +494,13 @@ public void testBuildBinaryUtf8() { null); assertCorrectFiltering(columnIndex, eq(col, stringBinary("Jeltz")), 3, 4); assertCorrectFiltering(columnIndex, eq(col, null), 0, 1, 2, 4, 5, 7); + Set set2 = new HashSet<>(); + set2.add( stringBinary("Jeltz")); + assertCorrectFiltering(columnIndex, in(col, set2), 3, 4); + assertCorrectFiltering(columnIndex, notIn(col, set2), 0, 1, 2, 5, 6, 7); + set2.add(null); + assertCorrectFiltering(columnIndex, in(col, set2), 0, 1, 2, 3, 4, 5, 7); + assertCorrectFiltering(columnIndex, notIn(col, set2), 6); assertCorrectFiltering(columnIndex, notEq(col, stringBinary("Slartibartfast")), 0, 1, 2, 3, 4, 5, 7); assertCorrectFiltering(columnIndex, notEq(col, null), 0, 3, 4, 6); assertCorrectFiltering(columnIndex, gt(col, stringBinary("Marvin")), 4, 6); @@ -507,6 +546,13 @@ public void testBuildBinaryUtf8() { stringBinary("Beeblebrox")); assertCorrectFiltering(columnIndex, eq(col, stringBinary("Marvin")), 3); assertCorrectFiltering(columnIndex, eq(col, null), 0, 2, 3, 5, 6, 7); + Set set3 = new HashSet<>(); + set3.add(stringBinary("Marvin")); + assertCorrectFiltering(columnIndex, in(col, set3), 3); + assertCorrectFiltering(columnIndex, notIn(col, set3), 0, 1, 2, 4, 5, 6, 7); + set3.add(null); + assertCorrectFiltering(columnIndex, in(col, set3), 0, 2, 3, 5, 6, 7); + assertCorrectFiltering(columnIndex, notIn(col, set3), 1, 4); assertCorrectFiltering(columnIndex, notEq(col, stringBinary("Dent")), 0, 1, 2, 3, 5, 6, 7); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 3, 4, 7); assertCorrectFiltering(columnIndex, gt(col, stringBinary("Prefect")), 1); @@ -615,6 +661,13 @@ public void testFilterWithoutNullCounts() { BinaryColumn col = binaryColumn("test_col"); assertCorrectFiltering(columnIndex, eq(col, stringBinary("Dent")), 2, 3); assertCorrectFiltering(columnIndex, eq(col, null), 0, 1, 2, 3, 4, 5, 6, 7); + Set set = new HashSet<>(); + set.add(stringBinary("Dent")); + 
assertCorrectFiltering(columnIndex, in(col, set), 2, 3); + assertCorrectFiltering(columnIndex, notIn(col, set), 0, 1, 4, 5, 6, 7); + set.add(null); + assertCorrectFiltering(columnIndex, in(col, set), 0, 1, 2, 3, 4, 5, 6, 7); + assertCorrectFiltering(columnIndex, notIn(col, set), new int[0]); assertCorrectFiltering(columnIndex, notEq(col, stringBinary("Dent")), 0, 1, 2, 3, 4, 5, 6, 7); assertCorrectFiltering(columnIndex, notEq(col, null), 2, 3, 5, 7); assertCorrectFiltering(columnIndex, userDefined(col, BinaryDecimalIsNullOrZeroUdp.class), 0, 1, 2, 3, 4, 5, 6, 7); @@ -646,6 +699,13 @@ public void testBuildBoolean() { assertCorrectValues(columnIndex.getMinValues(), false, false, true, null, false); assertCorrectFiltering(columnIndex, eq(col, true), 0, 1, 2); assertCorrectFiltering(columnIndex, eq(col, null), 1, 2, 3); + Set set1 = new HashSet<>(); + set1.add(true); + assertCorrectFiltering(columnIndex, in(col, set1), 0, 1, 2); + assertCorrectFiltering(columnIndex, notIn(col, set1), 3, 4); + set1.add(null); + assertCorrectFiltering(columnIndex, in(col, set1), 0, 1, 2, 3); + assertCorrectFiltering(columnIndex, notIn(col, set1), 4); assertCorrectFiltering(columnIndex, notEq(col, true), 0, 1, 2, 3, 4); assertCorrectFiltering(columnIndex, notEq(col, null), 0, 1, 2, 4); assertCorrectFiltering(columnIndex, userDefined(col, BooleanIsTrueOrNull.class), 0, 1, 2, 3); @@ -670,6 +730,13 @@ public void testBuildBoolean() { assertCorrectValues(columnIndex.getMinValues(), null, false, null, null, false, false, null); assertCorrectFiltering(columnIndex, eq(col, true), 4, 5); assertCorrectFiltering(columnIndex, eq(col, null), 0, 2, 3, 4, 5, 6); + Set set2 = new HashSet<>(); + set2.add(true); + assertCorrectFiltering(columnIndex, in(col, set2), 4, 5); + assertCorrectFiltering(columnIndex, notIn(col, set2), 0, 1, 2, 3, 6); + set2.add(null); + assertCorrectFiltering(columnIndex, in(col, set2), 0, 2, 3, 4, 5, 6); + assertCorrectFiltering(columnIndex, notIn(col, set2), 1); assertCorrectFiltering(columnIndex, notEq(col, true), 0, 1, 2, 3, 4, 5, 6); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 4, 5); assertCorrectFiltering(columnIndex, userDefined(col, BooleanIsTrueOrNull.class), 0, 2, 3, 4, 5, 6); @@ -694,6 +761,13 @@ public void testBuildBoolean() { assertCorrectValues(columnIndex.getMinValues(), null, true, null, null, false, false, null); assertCorrectFiltering(columnIndex, eq(col, true), 1, 4); assertCorrectFiltering(columnIndex, eq(col, null), 0, 2, 3, 4, 5, 6); + Set set3 = new HashSet<>(); + set3.add(true); + assertCorrectFiltering(columnIndex, in(col, set3), 1, 4); + assertCorrectFiltering(columnIndex, notIn(col, set3), 0, 2, 3, 5, 6); + set3.add(null); + assertCorrectFiltering(columnIndex, in(col, set3), 0, 1, 2, 3, 4, 5, 6); + assertCorrectFiltering(columnIndex, notIn(col, set3), new int[0]); assertCorrectFiltering(columnIndex, notEq(col, true), 0, 2, 3, 4, 5, 6); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 4, 5); assertCorrectFiltering(columnIndex, userDefined(col, BooleanIsTrueOrNull.class), 0, 1, 2, 3, 4, 5, 6); @@ -741,6 +815,14 @@ public void testBuildDouble() { assertCorrectValues(columnIndex.getMinValues(), -4.2, -11.7, 2.2, null, 1.9, -21.0); assertCorrectFiltering(columnIndex, eq(col, 0.0), 1, 5); assertCorrectFiltering(columnIndex, eq(col, null), 1, 2, 3); + Set set1 = new HashSet<>(); + set1.add(0.0); + set1.add(-4.2); + assertCorrectFiltering(columnIndex, in(col, set1), 0, 1, 5); + assertCorrectFiltering(columnIndex, notIn(col, set1), 2, 3, 4); + set1.add(null); + 
assertCorrectFiltering(columnIndex, in(col, set1), 0, 1, 2, 3, 5); + assertCorrectFiltering(columnIndex, notIn(col, set1), 4); assertCorrectFiltering(columnIndex, notEq(col, 2.2), 0, 1, 2, 3, 4, 5); assertCorrectFiltering(columnIndex, notEq(col, null), 0, 1, 2, 4, 5); assertCorrectFiltering(columnIndex, gt(col, 2.2), 1, 4, 5); @@ -771,6 +853,15 @@ public void testBuildDouble() { assertCorrectValues(columnIndex.getMinValues(), null, -532.3, -234.7, null, null, -234.6, null, 3.0, null); assertCorrectFiltering(columnIndex, eq(col, 0.0), 5); assertCorrectFiltering(columnIndex, eq(col, null), 0, 1, 2, 3, 4, 6, 8); + Set set2 = new HashSet<>(); + set2.add(0.0); + set2.add(3.5); + set2.add(-346.0); + assertCorrectFiltering(columnIndex, in(col, set2), 1, 2, 5, 7); + assertCorrectFiltering(columnIndex, notIn(col, set2), 0, 2, 3, 4, 6, 8); + set2.add(null); + assertCorrectFiltering(columnIndex, in(col, set2), 0, 1, 2, 3, 4, 5, 6, 7, 8); + assertCorrectFiltering(columnIndex, notIn(col, set2), new int[0]); assertCorrectFiltering(columnIndex, notEq(col, 0.0), 0, 1, 2, 3, 4, 5, 6, 7, 8); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 2, 5, 7); assertCorrectFiltering(columnIndex, gt(col, 2.99999), 7); @@ -801,6 +892,13 @@ public void testBuildDouble() { assertCorrectValues(columnIndex.getMinValues(), null, 345.2, null, 234.6, null, -2.99999, null, null, -42.83); assertCorrectFiltering(columnIndex, eq(col, 234.6), 3, 5); assertCorrectFiltering(columnIndex, eq(col, null), 0, 2, 3, 4, 6, 7); + Set set3 = new HashSet<>(); + set3.add(234.6); + assertCorrectFiltering(columnIndex, in(col, set3), 3, 5); + assertCorrectFiltering(columnIndex, notIn(col, set3), 0, 1, 2, 4, 6, 7, 8); + set3.add(null); + assertCorrectFiltering(columnIndex, in(col, set3), 0, 2, 3, 4, 5, 6, 7); + assertCorrectFiltering(columnIndex, notIn(col, set3), 1, 8); assertCorrectFiltering(columnIndex, notEq(col, 2.2), 0, 1, 2, 3, 4, 5, 6, 7, 8); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 3, 5, 8); assertCorrectFiltering(columnIndex, gt(col, 2.2), 1, 3, 5); @@ -871,6 +969,13 @@ public void testBuildFloat() { assertCorrectValues(columnIndex.getMinValues(), -4.2f, -11.7f, 2.2f, null, 1.9f, -21.0f); assertCorrectFiltering(columnIndex, eq(col, 0.0f), 1, 5); assertCorrectFiltering(columnIndex, eq(col, null), 1, 2, 3); + Set set1 = new HashSet<>(); + set1.add(0.0f); + assertCorrectFiltering(columnIndex, in(col, set1), 1, 5); + assertCorrectFiltering(columnIndex, notIn(col, set1), 0, 2, 3, 4); + set1.add(null); + assertCorrectFiltering(columnIndex, in(col, set1), 1, 2, 3, 5); + assertCorrectFiltering(columnIndex, notIn(col, set1), 0, 4); assertCorrectFiltering(columnIndex, notEq(col, 2.2f), 0, 1, 2, 3, 4, 5); assertCorrectFiltering(columnIndex, notEq(col, null), 0, 1, 2, 4, 5); assertCorrectFiltering(columnIndex, gt(col, 2.2f), 1, 4, 5); @@ -901,6 +1006,13 @@ public void testBuildFloat() { assertCorrectValues(columnIndex.getMinValues(), null, -532.3f, -300.6f, null, null, -234.6f, null, 3.0f, null); assertCorrectFiltering(columnIndex, eq(col, 0.0f), 5); assertCorrectFiltering(columnIndex, eq(col, null), 0, 1, 2, 3, 4, 6, 8); + Set set2 = new HashSet<>(); + set2.add(0.0f); + assertCorrectFiltering(columnIndex, in(col, set2), 5); + assertCorrectFiltering(columnIndex, notIn(col, set2), 0, 1, 2, 3, 4, 6, 7, 8); + set2.add(null); + assertCorrectFiltering(columnIndex, in(col, set2), 0, 1, 2, 3, 4, 5, 6, 8); + assertCorrectFiltering(columnIndex, notIn(col, set2), 7); assertCorrectFiltering(columnIndex, notEq(col, 2.2f), 0, 1, 2, 3, 4, 
5, 6, 7, 8); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 2, 5, 7); assertCorrectFiltering(columnIndex, gt(col, 2.2f), 5, 7); @@ -931,6 +1043,13 @@ public void testBuildFloat() { assertCorrectValues(columnIndex.getMinValues(), null, 345.2f, null, 234.6f, null, -2.99999f, null, null, -42.83f); assertCorrectFiltering(columnIndex, eq(col, 234.65f), 3); assertCorrectFiltering(columnIndex, eq(col, null), 0, 2, 3, 4, 6, 7); + Set set3 = new HashSet<>(); + set3.add(234.65f); + assertCorrectFiltering(columnIndex, in(col, set3), 3); + assertCorrectFiltering(columnIndex, notIn(col, set3), 0, 1, 2, 4, 5, 6, 7, 8); + set3.add(null); + assertCorrectFiltering(columnIndex, in(col, set3), 0, 2, 3, 4, 6, 7); + assertCorrectFiltering(columnIndex, notIn(col, set3), 1, 5, 8); assertCorrectFiltering(columnIndex, notEq(col, 2.2f), 0, 1, 2, 3, 4, 5, 6, 7, 8); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 3, 5, 8); assertCorrectFiltering(columnIndex, gt(col, 2.2f), 1, 3, 5); @@ -1001,6 +1120,13 @@ public void testBuildInt32() { assertCorrectValues(columnIndex.getMinValues(), -4, -11, 2, null, 1, -21); assertCorrectFiltering(columnIndex, eq(col, 2), 0, 1, 2, 4, 5); assertCorrectFiltering(columnIndex, eq(col, null), 1, 2, 3); + Set set1 = new HashSet<>(); + set1.add(2); + assertCorrectFiltering(columnIndex, in(col, set1), 0, 1, 2, 4, 5); + assertCorrectFiltering(columnIndex, notIn(col, set1), 3); + set1.add(null); + assertCorrectFiltering(columnIndex, in(col, set1), 0, 1, 2, 3, 4, 5); + assertCorrectFiltering(columnIndex, notIn(col, set1), new int[0]); assertCorrectFiltering(columnIndex, notEq(col, 2), 0, 1, 2, 3, 4, 5); assertCorrectFiltering(columnIndex, notEq(col, null), 0, 1, 2, 4, 5); assertCorrectFiltering(columnIndex, gt(col, 2), 0, 1, 5); @@ -1031,6 +1157,13 @@ public void testBuildInt32() { assertCorrectValues(columnIndex.getMinValues(), null, -532, -500, null, null, -42, null, 3, null); assertCorrectFiltering(columnIndex, eq(col, 2), 5); assertCorrectFiltering(columnIndex, eq(col, null), 0, 1, 2, 3, 4, 6, 8); + Set set2 = new HashSet<>(); + set2.add(2); + assertCorrectFiltering(columnIndex, in(col, set2), 5); + assertCorrectFiltering(columnIndex, notIn(col, set2), 0, 1, 2, 3, 4, 6, 7, 8); + set2.add(null); + assertCorrectFiltering(columnIndex, in(col, set2), 0, 1, 2, 3, 4, 5, 6, 8); + assertCorrectFiltering(columnIndex, notIn(col, set2), 7); assertCorrectFiltering(columnIndex, notEq(col, 2), 0, 1, 2, 3, 4, 5, 6, 7, 8); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 2, 5, 7); assertCorrectFiltering(columnIndex, gt(col, 2), 7); @@ -1062,6 +1195,13 @@ public void testBuildInt32() { assertCorrectValues(columnIndex.getMinValues(), null, 345, null, 42, null, -2, null, null, -42); assertCorrectFiltering(columnIndex, eq(col, 2), 5); assertCorrectFiltering(columnIndex, eq(col, null), 0, 2, 3, 4, 6, 7); + Set set3 = new HashSet<>(); + set3.add(2); + assertCorrectFiltering(columnIndex, in(col, set3), 5); + assertCorrectFiltering(columnIndex, notIn(col, set3), 0, 1, 2, 3, 4, 6, 7, 8); + set3.add(null); + assertCorrectFiltering(columnIndex, in(col, set3), 0, 2, 3, 4, 5, 6, 7); + assertCorrectFiltering(columnIndex, notIn(col, set3), 1, 8); assertCorrectFiltering(columnIndex, notEq(col, 2), 0, 1, 2, 3, 4, 5, 6, 7, 8); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 3, 5, 8); assertCorrectFiltering(columnIndex, gt(col, 2), 1, 3, 5); @@ -1114,6 +1254,13 @@ public void testBuildUInt8() { assertCorrectValues(columnIndex.getMinValues(), 4, 11, 2, null, 1, 0xEF); 
assertCorrectFiltering(columnIndex, eq(col, 2), 2, 4); assertCorrectFiltering(columnIndex, eq(col, null), 1, 2, 3); + Set set1 = new HashSet<>(); + set1.add(2); + assertCorrectFiltering(columnIndex, in(col, set1), 2, 4); + assertCorrectFiltering(columnIndex, notIn(col, set1), 0, 1, 3, 5); + set1.add(null); + assertCorrectFiltering(columnIndex, in(col, set1), 1, 2, 3, 4); + assertCorrectFiltering(columnIndex, notIn(col, set1), 0, 5); assertCorrectFiltering(columnIndex, notEq(col, 2), 0, 1, 2, 3, 4, 5); assertCorrectFiltering(columnIndex, notEq(col, null), 0, 1, 2, 4, 5); assertCorrectFiltering(columnIndex, gt(col, 2), 0, 1, 4, 5); @@ -1144,6 +1291,13 @@ public void testBuildUInt8() { assertCorrectValues(columnIndex.getMinValues(), null, 0, 0, null, null, 42, null, 0xEF, null); assertCorrectFiltering(columnIndex, eq(col, 2), 2); assertCorrectFiltering(columnIndex, eq(col, null), 0, 1, 2, 3, 4, 6, 8); + Set set2 = new HashSet<>(); + set2.add(2); + assertCorrectFiltering(columnIndex, in(col, set2), 2); + assertCorrectFiltering(columnIndex, notIn(col, set2), 0, 1, 3, 4, 5, 6, 7, 8); + set2.add(null); + assertCorrectFiltering(columnIndex, in(col, set2), 0, 1, 2, 3, 4, 6, 8); + assertCorrectFiltering(columnIndex, notIn(col, set2), 5, 7); assertCorrectFiltering(columnIndex, notEq(col, 2), 0, 1, 2, 3, 4, 5, 6, 7, 8); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 2, 5, 7); assertCorrectFiltering(columnIndex, gt(col, 0xEE), 7); @@ -1175,6 +1329,13 @@ public void testBuildUInt8() { assertCorrectValues(columnIndex.getMinValues(), null, 0xFF, null, 0xEA, null, 42, null, null, 0); assertCorrectFiltering(columnIndex, eq(col, 0xAB), 5); assertCorrectFiltering(columnIndex, eq(col, null), 0, 2, 3, 4, 6, 7); + Set set3 = new HashSet<>(); + set3.add(0xAB); + assertCorrectFiltering(columnIndex, in(col, set3), 5); + assertCorrectFiltering(columnIndex, notIn(col, set3), 0, 1, 2, 3, 4, 6, 7, 8); + set3.add(null); + assertCorrectFiltering(columnIndex, in(col, set3), 0, 2, 3, 4, 5, 6, 7); + assertCorrectFiltering(columnIndex, notIn(col, set3), 1, 8); assertCorrectFiltering(columnIndex, notEq(col, 0xFF), 0, 2, 3, 4, 5, 6, 7, 8); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 3, 5, 8); assertCorrectFiltering(columnIndex, gt(col, 0xFF)); @@ -1211,6 +1372,13 @@ public void testBuildInt64() { assertCorrectValues(columnIndex.getMinValues(), -4l, -11l, 2l, null, 1l, -21l); assertCorrectFiltering(columnIndex, eq(col, 0l), 0, 1, 5); assertCorrectFiltering(columnIndex, eq(col, null), 1, 2, 3); + Set set1 = new HashSet<>(); + set1.add(0l); + assertCorrectFiltering(columnIndex, in(col, set1), 0, 1, 5); + assertCorrectFiltering(columnIndex, notIn(col, set1), 2, 3, 4); + set1.add(null); + assertCorrectFiltering(columnIndex, in(col, set1), 0, 1, 2, 3, 5); + assertCorrectFiltering(columnIndex, notIn(col, set1), 4); assertCorrectFiltering(columnIndex, notEq(col, 0l), 0, 1, 2, 3, 4, 5); assertCorrectFiltering(columnIndex, notEq(col, null), 0, 1, 2, 4, 5); assertCorrectFiltering(columnIndex, gt(col, 2l), 0, 1, 5); @@ -1241,6 +1409,13 @@ public void testBuildInt64() { assertCorrectValues(columnIndex.getMinValues(), null, -532l, -234l, null, null, -42l, null, -3l, null); assertCorrectFiltering(columnIndex, eq(col, -42l), 2, 5); assertCorrectFiltering(columnIndex, eq(col, null), 0, 1, 2, 3, 4, 6, 8); + Set set2 = new HashSet<>(); + set2.add(-42l); + assertCorrectFiltering(columnIndex, in(col, set2), 2, 5); + assertCorrectFiltering(columnIndex, notIn(col, set2), 0, 1, 3, 4, 6, 7, 8); + set2.add(null); + 
assertCorrectFiltering(columnIndex, in(col, set2), 0, 1, 2, 3, 4, 5, 6, 8); + assertCorrectFiltering(columnIndex, notIn(col, set2), 7); assertCorrectFiltering(columnIndex, notEq(col, -42l), 0, 1, 2, 3, 4, 5, 6, 7, 8); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 2, 5, 7); assertCorrectFiltering(columnIndex, gt(col, 2l), 7); @@ -1272,6 +1447,13 @@ public void testBuildInt64() { assertCorrectValues(columnIndex.getMinValues(), null, 345l, null, 42l, null, -2l, null, null, -42l); assertCorrectFiltering(columnIndex, eq(col, 0l), 5); assertCorrectFiltering(columnIndex, eq(col, null), 0, 2, 3, 4, 6, 7); + Set set3 = new HashSet<>(); + set3.add(0l); + assertCorrectFiltering(columnIndex, in(col, set3), 5); + assertCorrectFiltering(columnIndex, notIn(col, set3), 0, 1, 2, 3, 4, 6, 7, 8); + set3.add(null); + assertCorrectFiltering(columnIndex, in(col, set3), 0, 2, 3, 4, 5, 6, 7); + assertCorrectFiltering(columnIndex, notIn(col, set3), 1, 8); assertCorrectFiltering(columnIndex, notEq(col, 0l), 0, 1, 2, 3, 4, 5, 6, 7, 8); assertCorrectFiltering(columnIndex, notEq(col, null), 1, 3, 5, 8); assertCorrectFiltering(columnIndex, gt(col, 2l), 1, 3, 5); diff --git a/parquet-column/src/test/java/org/apache/parquet/internal/filter2/columnindex/TestColumnIndexFilter.java b/parquet-column/src/test/java/org/apache/parquet/internal/filter2/columnindex/TestColumnIndexFilter.java index f37a343b40..47ea5fc5c1 100644 --- a/parquet-column/src/test/java/org/apache/parquet/internal/filter2/columnindex/TestColumnIndexFilter.java +++ b/parquet-column/src/test/java/org/apache/parquet/internal/filter2/columnindex/TestColumnIndexFilter.java @@ -26,11 +26,13 @@ import static org.apache.parquet.filter2.predicate.FilterApi.eq; import static org.apache.parquet.filter2.predicate.FilterApi.gt; import static org.apache.parquet.filter2.predicate.FilterApi.gtEq; +import static org.apache.parquet.filter2.predicate.FilterApi.in; import static org.apache.parquet.filter2.predicate.FilterApi.intColumn; import static org.apache.parquet.filter2.predicate.FilterApi.longColumn; import static org.apache.parquet.filter2.predicate.FilterApi.lt; import static org.apache.parquet.filter2.predicate.FilterApi.ltEq; import static org.apache.parquet.filter2.predicate.FilterApi.notEq; +import static org.apache.parquet.filter2.predicate.FilterApi.notIn; import static org.apache.parquet.filter2.predicate.FilterApi.or; import static org.apache.parquet.filter2.predicate.FilterApi.userDefined; import static org.apache.parquet.filter2.predicate.LogicalInverter.invert; @@ -68,6 +70,7 @@ import org.apache.parquet.internal.column.columnindex.TestColumnIndexBuilder.BinaryUtf8StartsWithB; import org.apache.parquet.internal.column.columnindex.TestColumnIndexBuilder.DoubleIsInteger; import org.apache.parquet.internal.column.columnindex.TestColumnIndexBuilder.IntegerIsDivisableWith3; +import org.apache.parquet.io.api.Binary; import org.apache.parquet.schema.PrimitiveType; import org.junit.Test; @@ -360,6 +363,58 @@ public void testFiltering() { calculateRowRanges(FilterCompat.get( userDefined(intColumn("column1"), AnyInt.class)), STORE, paths, TOTAL_ROW_COUNT), TOTAL_ROW_COUNT); + + Set set1 = new HashSet<>(); + set1.add(7); + assertRows(calculateRowRanges(FilterCompat.get(in(intColumn("column1"), set1)), STORE, paths, TOTAL_ROW_COUNT), + 7, 8, 9, 10, 11, 12, 13); + set1.add(1); + assertRows(calculateRowRanges(FilterCompat.get(in(intColumn("column1"), set1)), STORE, paths, TOTAL_ROW_COUNT), + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13); + 
assertRows(calculateRowRanges(FilterCompat.get(notIn(intColumn("column1"), set1)), STORE, paths, TOTAL_ROW_COUNT), + 1, 2, 3, 4, 5, 6, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29); + + Set set2 = new HashSet<>(); + set2.add(fromString("Zulu")); + set2.add(fromString("Alfa")); + assertRows(calculateRowRanges(FilterCompat.get(in(binaryColumn("column2"), set2)), STORE, paths, TOTAL_ROW_COUNT), + 0, 1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29); + assertRows(calculateRowRanges(FilterCompat.get(notIn(binaryColumn("column2"), set2)), STORE, paths, TOTAL_ROW_COUNT), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28); + + Set set3 = new HashSet<>(); + set3.add(2.03); + assertRows(calculateRowRanges(FilterCompat.get(in(doubleColumn("column3"), set3)), STORE, paths, TOTAL_ROW_COUNT), + 0, 1, 2, 3, 4, 5, 16, 17, 18, 19, 20, 21, 22); + assertRows(calculateRowRanges(FilterCompat.get(notIn(doubleColumn("column3"), set3)), STORE, paths, TOTAL_ROW_COUNT), + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 23, 24, 25, 26, 27, 28, 29); + set3.add(9.98); + assertRows(calculateRowRanges(FilterCompat.get(in(doubleColumn("column3"), set3)), STORE, paths, TOTAL_ROW_COUNT), + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25); + assertRows(calculateRowRanges(FilterCompat.get(notIn(doubleColumn("column3"), set3)), STORE, paths, TOTAL_ROW_COUNT), + 6, 7, 8, 9, 23, 24, 25, 26, 27, 28, 29); + set3.add(null); + assertRows(calculateRowRanges(FilterCompat.get(in(doubleColumn("column3"), set3)), STORE, paths, TOTAL_ROW_COUNT), + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29); + assertRows(calculateRowRanges(FilterCompat.get(notIn(doubleColumn("column3"), set3)), STORE, paths, TOTAL_ROW_COUNT), + 23, 24, 25); + + Set set4 = new HashSet<>(); + set4.add(null); + assertRows(calculateRowRanges(FilterCompat.get(in(booleanColumn("column4"), set4)), STORE, paths, TOTAL_ROW_COUNT), + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29); + // no column index, can't filter this row + assertRows(calculateRowRanges(FilterCompat.get(notIn(booleanColumn("column4"), set4)), STORE, paths, TOTAL_ROW_COUNT), + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29); + + Set set5 = new HashSet<>(); + set5.add(7); + set5.add(20); + assertRows(calculateRowRanges(FilterCompat.get(in(intColumn("column5"), set5)), STORE, paths, TOTAL_ROW_COUNT), + new long[0]); + assertRows(calculateRowRanges(FilterCompat.get(notIn(intColumn("column5"), set5)), STORE, paths, TOTAL_ROW_COUNT), + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29); + assertRows(calculateRowRanges(FilterCompat.get( and( and( diff --git a/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java b/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java index 3c1cf4866c..2827167cb2 100644 --- a/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java +++ b/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java @@ -67,15 +67,18 @@ public void run() throws IOException { add("package 
org.apache.parquet.filter2.recordlevel;\n" + "\n" + "import java.util.List;\n" + + "import java.util.Set;\n" + "\n" + "import org.apache.parquet.hadoop.metadata.ColumnPath;\n" + "import org.apache.parquet.filter2.predicate.Operators.Eq;\n" + "import org.apache.parquet.filter2.predicate.Operators.Gt;\n" + "import org.apache.parquet.filter2.predicate.Operators.GtEq;\n" + + "import org.apache.parquet.filter2.predicate.Operators.In;\n" + "import org.apache.parquet.filter2.predicate.Operators.LogicalNotUserDefined;\n" + "import org.apache.parquet.filter2.predicate.Operators.Lt;\n" + "import org.apache.parquet.filter2.predicate.Operators.LtEq;\n" + "import org.apache.parquet.filter2.predicate.Operators.NotEq;\n" + + "import org.apache.parquet.filter2.predicate.Operators.NotIn;\n" + "import org.apache.parquet.filter2.predicate.Operators.UserDefined;\n" + "import org.apache.parquet.filter2.predicate.UserDefinedPredicate;\n" + "import org.apache.parquet.filter2.recordlevel.IncrementallyUpdatedFilterPredicate.ValueInspector;\n" + @@ -106,6 +109,18 @@ public void run() throws IOException { } addVisitEnd(); + addVisitBegin("In"); + for (TypeInfo info : TYPES) { + addInNotInCase(info, true); + } + addVisitEnd(); + + addVisitBegin("NotIn"); + for (TypeInfo info : TYPES) { + addInNotInCase(info, false); + } + addVisitEnd(); + addVisitBegin("Lt"); for (TypeInfo info : TYPES) { addInequalityCase(info, "<"); @@ -233,6 +248,56 @@ private void addInequalityCase(TypeInfo info, String op) throws IOException { " }\n\n"); } + private void addInNotInCase(TypeInfo info, boolean isEq) throws IOException { + add(" if (clazz.equals(" + info.className + ".class)) {\n" + + " if (pred.getValues().contains(null)) {\n" + + " valueInspector = new ValueInspector() {\n" + + " @Override\n" + + " public void updateNull() {\n" + + " setResult(" + isEq + ");\n" + + " }\n" + + "\n" + + " @Override\n" + + " public void update(" + info.primitiveName + " value) {\n" + + " setResult(" + !isEq + ");\n" + + " }\n" + + " };\n" + + " } else {\n" + + " final Set<" + info.className + "> target = (Set<" + info.className + ">) pred.getValues();\n" + + " final PrimitiveComparator<" + info.className + "> comparator = getComparator(columnPath);\n" + + "\n" + + " valueInspector = new ValueInspector() {\n" + + " @Override\n" + + " public void updateNull() {\n" + + " setResult(" + !isEq +");\n" + + " }\n" + + "\n" + + " @Override\n" + + " public void update(" + info.primitiveName + " value) {\n" + + " boolean set = false;\n"); + + add(" for (" + info.primitiveName + " i : target) {\n"); + + add(" if(" + compareEquality("value", "i", isEq) + ") {\n"); + + add(" setResult(true);\n"); + + add(" set = true;\n"); + + add(" break;\n"); + + add(" }\n"); + + add(" }\n"); + add(" if (!set) setResult(false);\n"); + add(" }\n"); + + add(" };\n" + + " }\n" + + " }\n\n"); + } + + private void addUdpBegin() throws IOException { add(" ColumnPath columnPath = pred.getColumn().getColumnPath();\n" + " Class clazz = pred.getColumn().getColumnType();\n" + diff --git a/parquet-hadoop/src/main/java/org/apache/parquet/filter2/bloomfilterlevel/BloomFilterImpl.java b/parquet-hadoop/src/main/java/org/apache/parquet/filter2/bloomfilterlevel/BloomFilterImpl.java index d98416445f..d069836337 100644 --- a/parquet-hadoop/src/main/java/org/apache/parquet/filter2/bloomfilterlevel/BloomFilterImpl.java +++ b/parquet-hadoop/src/main/java/org/apache/parquet/filter2/bloomfilterlevel/BloomFilterImpl.java @@ -22,6 +22,8 @@ import java.util.HashMap; import java.util.List; import 
java.util.Map; +import java.util.Set; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -118,6 +120,41 @@ public > Boolean visit(Operators.GtEq gtEq) { return BLOCK_MIGHT_MATCH; } + @Override + public > Boolean visit(Operators.In in) { + Set values = in.getValues(); + + if (values.contains(null)) { + // the bloom filter bitset contains only non-null values so isn't helpful. this + // could check the column stats, but the StatisticsFilter is responsible + return BLOCK_MIGHT_MATCH; + } + + Operators.Column filterColumn = in.getColumn(); + ColumnChunkMetaData meta = getColumnChunk(filterColumn.getColumnPath()); + if (meta == null) { + // the column isn't in this file so all values are null, but the value + // must be non-null because of the above check. + return BLOCK_CANNOT_MATCH; + } + + BloomFilter bloomFilter = bloomFilterReader.readBloomFilter(meta); + if (bloomFilter != null) { + for (T value : values) { + if (bloomFilter.findHash(bloomFilter.hash(value))) { + return BLOCK_MIGHT_MATCH; + } + } + return BLOCK_CANNOT_MATCH; + } + return BLOCK_MIGHT_MATCH; + } + + @Override + public > Boolean visit(Operators.NotIn notIn) { + return BLOCK_MIGHT_MATCH; + } + @Override public Boolean visit(Operators.And and) { return and.getLeft().accept(this) || and.getRight().accept(this); diff --git a/parquet-hadoop/src/main/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilter.java b/parquet-hadoop/src/main/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilter.java index 2f69fa6684..c21212ac14 100644 --- a/parquet-hadoop/src/main/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilter.java +++ b/parquet-hadoop/src/main/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilter.java @@ -364,6 +364,108 @@ public > Boolean visit(GtEq gtEq) { return BLOCK_MIGHT_MATCH; } + @Override + public > Boolean visit(In in) { + Set values = in.getValues(); + + if (values.contains(null)) { + // the dictionary contains only non-null values so isn't helpful. this + // could check the column stats, but the StatisticsFilter is responsible + return BLOCK_MIGHT_MATCH; + } + + Column filterColumn = in.getColumn(); + ColumnChunkMetaData meta = getColumnChunk(filterColumn.getColumnPath()); + + if (meta == null) { + // the column isn't in this file so all values are null, but the value + // must be non-null because of the above check. + return BLOCK_CANNOT_MATCH; + } + + // if the chunk has non-dictionary pages, don't bother decoding the + // dictionary because the row group can't be eliminated. 
+ if (hasNonDictionaryPages(meta)) { + return BLOCK_MIGHT_MATCH; + } + + try { + Set dictSet = expandDictionary(meta); + if (dictSet != null) { + return drop(dictSet, values); + } + } catch (IOException e) { + LOG.warn("Failed to process dictionary for filter evaluation.", e); + } + return BLOCK_MIGHT_MATCH; // cannot drop the row group based on this dictionary + } + + private > Boolean drop(Set dictSet, Set values) { + // we need to find out the smaller set to iterate through + Set smallerSet; + Set biggerSet; + + if (values.size() < dictSet.size()) { + smallerSet = values; + biggerSet = dictSet; + } else { + smallerSet = dictSet; + biggerSet = values; + } + + for (T e : smallerSet) { + if (biggerSet.contains(e)) { + // value sets intersect so rows match + return BLOCK_MIGHT_MATCH; + } + } + return BLOCK_CANNOT_MATCH; + } + + @Override + public > Boolean visit(NotIn notIn) { + Set values = notIn.getValues(); + + Column filterColumn = notIn.getColumn(); + ColumnChunkMetaData meta = getColumnChunk(filterColumn.getColumnPath()); + + if (values.size() == 1 && values.contains(null) && meta == null) { + // the predicate value is null and all rows have a null value, so the + // predicate is always false (null != null) + return BLOCK_CANNOT_MATCH; + } + + if (values.contains(null)) { + // the dictionary contains only non-null values so isn't helpful. this + // could check the column stats, but the StatisticsFilter is responsible + return BLOCK_MIGHT_MATCH; + } + + if (meta == null) { + // the column isn't in this file so all values are null, but the value + // must be non-null because of the above check. + return BLOCK_MIGHT_MATCH; + } + + // if the chunk has non-dictionary pages, don't bother decoding the + // dictionary because the row group can't be eliminated. + if (hasNonDictionaryPages(meta)) { + return BLOCK_MIGHT_MATCH; + } + + try { + Set dictSet = expandDictionary(meta); + if (dictSet != null) { + if (dictSet.size() > values.size()) return BLOCK_MIGHT_MATCH; + // ROWS_CANNOT_MATCH if no values in the dictionary that are not also in the set + return values.containsAll(dictSet) ? 
BLOCK_CANNOT_MATCH : BLOCK_MIGHT_MATCH; + } + } catch (IOException e) { + LOG.warn("Failed to process dictionary for filter evaluation.", e); + } + return BLOCK_MIGHT_MATCH; + } + @Override public Boolean visit(And and) { return and.getLeft().accept(this) || and.getRight().accept(this); diff --git a/parquet-hadoop/src/main/java/org/apache/parquet/filter2/statisticslevel/StatisticsFilter.java b/parquet-hadoop/src/main/java/org/apache/parquet/filter2/statisticslevel/StatisticsFilter.java index 4db2eb9ef6..23609a93d5 100644 --- a/parquet-hadoop/src/main/java/org/apache/parquet/filter2/statisticslevel/StatisticsFilter.java +++ b/parquet-hadoop/src/main/java/org/apache/parquet/filter2/statisticslevel/StatisticsFilter.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import org.apache.parquet.column.statistics.Statistics; import org.apache.parquet.hadoop.metadata.ColumnPath; @@ -31,15 +32,18 @@ import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; +import org.apache.parquet.filter2.predicate.Operators.In; import org.apache.parquet.filter2.predicate.Operators.LogicalNotUserDefined; import org.apache.parquet.filter2.predicate.Operators.Lt; import org.apache.parquet.filter2.predicate.Operators.LtEq; import org.apache.parquet.filter2.predicate.Operators.Not; import org.apache.parquet.filter2.predicate.Operators.NotEq; +import org.apache.parquet.filter2.predicate.Operators.NotIn; import org.apache.parquet.filter2.predicate.Operators.Or; import org.apache.parquet.filter2.predicate.Operators.UserDefined; import org.apache.parquet.filter2.predicate.UserDefinedPredicate; import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +import org.apache.parquet.column.MinMax; /** * Applies a {@link org.apache.parquet.filter2.predicate.FilterPredicate} to statistics about a group of @@ -144,6 +148,71 @@ public > Boolean visit(Eq eq) { return stats.compareMinToValue(value) > 0 || stats.compareMaxToValue(value) < 0; } + @Override + @SuppressWarnings("unchecked") + public > Boolean visit(In in) { + Column filterColumn = in.getColumn(); + ColumnChunkMetaData meta = getColumnChunk(filterColumn.getColumnPath()); + + Set values = in.getValues(); + + if (meta == null) { + // the column isn't in this file so all values are null. + if (!values.contains(null)) { + // non-null is never null + return BLOCK_CANNOT_MATCH; + } + return BLOCK_MIGHT_MATCH; + } + + Statistics stats = meta.getStatistics(); + + if (stats.isEmpty()) { + // we have no statistics available, we cannot drop any chunks + return BLOCK_MIGHT_MATCH; + } + + if (isAllNulls(meta)) { + // we are looking for records where v in(someNonNull) + // and this is a column of all nulls, so drop it unless in set contains null. 
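Side note on the DictionaryFilter NotIn visitor above: once the dictionary is known to cover every value in the chunk, the row group can be dropped only when the notIn set covers the whole dictionary, because then no row can satisfy the predicate. A small standalone sketch of that subset rule with plain java.util sets (the helper name is hypothetical):

import java.util.Set;

final class NotInDropRule {
  // Sketch: true means the row group can be dropped for "x notIn (values)".
  static <T> boolean canDropForNotIn(Set<T> dictionary, Set<T> notInValues) {
    if (dictionary.size() > notInValues.size()) {
      return false; // at least one dictionary value cannot be in the notIn set
    }
    return notInValues.containsAll(dictionary); // drop only if fully covered
  }
}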
+ if (values.contains(null)) { + return BLOCK_MIGHT_MATCH; + } else { + return BLOCK_CANNOT_MATCH; + } + } + + if (!stats.hasNonNullValue()) { + // stats does not contain min/max values, so we cannot drop any chunks + return BLOCK_MIGHT_MATCH; + } + + if (stats.isNumNullsSet()) { + if (stats.getNumNulls() == 0) { + if (values.contains(null) && values.size() == 1) return BLOCK_CANNOT_MATCH; + } else { + if (values.contains(null)) return BLOCK_MIGHT_MATCH; + } + } + + MinMax minMax = new MinMax(meta.getPrimitiveType().comparator(), values); + T min = minMax.getMin(); + T max = minMax.getMax(); + + // drop only if every value in the set is smaller than the chunk's min or larger than its max + if (stats.compareMinToValue(max) <= 0 && + stats.compareMaxToValue(min) >= 0) { + return BLOCK_MIGHT_MATCH; + } else { + return BLOCK_CANNOT_MATCH; + } + } + + @Override + public > Boolean visit(NotIn notIn) { + return BLOCK_MIGHT_MATCH; + } + @Override @SuppressWarnings("unchecked") public > Boolean visit(NotEq notEq) { diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/TestFiltersWithMissingColumns.java b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/TestFiltersWithMissingColumns.java index 3282f27fe2..3d18e1c3ad 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/TestFiltersWithMissingColumns.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/TestFiltersWithMissingColumns.java @@ -37,6 +37,8 @@ import org.junit.rules.TemporaryFolder; import java.io.File; import java.io.IOException; +import java.util.HashSet; +import java.util.Set; import static org.apache.parquet.filter2.predicate.FilterApi.and; import static org.apache.parquet.filter2.predicate.FilterApi.binaryColumn; @@ -44,10 +46,12 @@ import static org.apache.parquet.filter2.predicate.FilterApi.eq; import static org.apache.parquet.filter2.predicate.FilterApi.gt; import static org.apache.parquet.filter2.predicate.FilterApi.gtEq; +import static org.apache.parquet.filter2.predicate.FilterApi.in; import static org.apache.parquet.filter2.predicate.FilterApi.longColumn; import static org.apache.parquet.filter2.predicate.FilterApi.lt; import static org.apache.parquet.filter2.predicate.FilterApi.ltEq; import static org.apache.parquet.filter2.predicate.FilterApi.notEq; +import static org.apache.parquet.filter2.predicate.FilterApi.notIn; import static org.apache.parquet.filter2.predicate.FilterApi.or; import static org.apache.parquet.io.api.Binary.fromString; import static org.apache.parquet.schema.OriginalType.UTF8; @@ -98,6 +102,12 @@ public void testNormalFilter() throws Exception { @Test public void testSimpleMissingColumnFilter() throws Exception { assertEquals(0, countFilteredRecords(path, lt(longColumn("missing"), 500L))); + Set values = new HashSet<>(); + values.add(1L); + values.add(2L); + values.add(5L); + assertEquals(0, countFilteredRecords(path, in(longColumn("missing"), values))); + assertEquals(1000, countFilteredRecords(path, notIn(longColumn("missing"), values))); } @Test diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/compat/TestRowGroupFilter.java b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/compat/TestRowGroupFilter.java index 14877abb6c..40527775bf 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/compat/TestRowGroupFilter.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/compat/TestRowGroupFilter.java @@ -21,6 +21,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Set; + import
java.util.HashSet; import org.junit.Test; @@ -32,8 +34,10 @@ import static org.junit.Assert.assertEquals; import static org.apache.parquet.filter2.predicate.FilterApi.eq; +import static org.apache.parquet.filter2.predicate.FilterApi.in; import static org.apache.parquet.filter2.predicate.FilterApi.intColumn; import static org.apache.parquet.filter2.predicate.FilterApi.notEq; +import static org.apache.parquet.filter2.predicate.FilterApi.notIn; import static org.apache.parquet.hadoop.TestInputFormat.makeBlockFromStats; public class TestRowGroupFilter { @@ -83,7 +87,29 @@ public void testApplyRowGroupFilters() { MessageType schema = MessageTypeParser.parseMessageType("message Document { optional int32 foo; }"); IntColumn foo = intColumn("foo"); - List filtered = RowGroupFilter.filterRowGroups(FilterCompat.get(eq(foo, 50)), blocks, schema); + Set set1 = new HashSet<>(); + set1.add(9); + set1.add(10); + set1.add(50); + List filtered = RowGroupFilter.filterRowGroups(FilterCompat.get(in(foo, set1)), blocks, schema); + assertEquals(Arrays.asList(b1, b2, b5), filtered); + filtered = RowGroupFilter.filterRowGroups(FilterCompat.get(notIn(foo, set1)), blocks, schema); + assertEquals(Arrays.asList(b1, b2, b3, b4, b5, b6), filtered); + + Set set2 = new HashSet<>(); + set2.add(null); + filtered = RowGroupFilter.filterRowGroups(FilterCompat.get(in(foo, set2)), blocks, schema); + assertEquals(Arrays.asList(b1, b3, b4, b5, b6), filtered); + filtered = RowGroupFilter.filterRowGroups(FilterCompat.get(notIn(foo, set2)), blocks, schema); + assertEquals(Arrays.asList(b1, b2, b3, b4, b5, b6), filtered); + + set2.add(8); + filtered = RowGroupFilter.filterRowGroups(FilterCompat.get(in(foo, set2)), blocks, schema); + assertEquals(Arrays.asList(b1, b2, b3, b4, b5, b6), filtered); + filtered = RowGroupFilter.filterRowGroups(FilterCompat.get(notIn(foo, set2)), blocks, schema); + assertEquals(Arrays.asList(b1, b2, b3, b4, b5, b6), filtered); + + filtered = RowGroupFilter.filterRowGroups(FilterCompat.get(eq(foo, 50)), blocks, schema); assertEquals(Arrays.asList(b1, b2, b5), filtered); filtered = RowGroupFilter.filterRowGroups(FilterCompat.get(notEq(foo, 50)), blocks, schema); diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilterTest.java b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilterTest.java index 1e243f89dc..65cefe46f7 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilterTest.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilterTest.java @@ -464,6 +464,122 @@ public void testGtEqDouble() throws Exception { canDrop(gtEq(d, Double.MIN_VALUE), ccmd, dictionaries)); } + @Test + public void testInBinary() throws Exception { + BinaryColumn b = binaryColumn("binary_field"); + + Set set1 = new HashSet<>(); + set1.add(Binary.fromString("F")); + set1.add(Binary.fromString("C")); + set1.add(Binary.fromString("h")); + set1.add(Binary.fromString("E")); + FilterPredicate predIn1 = in(b, set1); + FilterPredicate predNotIn1 = notIn(b, set1); + assertFalse("Should not drop block", canDrop(predIn1, ccmd, dictionaries)); + assertFalse("Should not drop block", canDrop(predNotIn1, ccmd, dictionaries)); + + Set set2 = new HashSet<>(); + for (int i = 0; i < 26; i++) { + set2.add(Binary.fromString(Character.toString((char) (i + 97)))); + } + set2.add(Binary.fromString("A")); + FilterPredicate predIn2 = in(b, set2); + FilterPredicate predNotIn2 = notIn(b, set2); 
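For context on how the new predicates are used end to end, here is a reader-side usage sketch; the file path, column name and GroupReadSupport-based reader are illustrative assumptions, while FilterApi.in and FilterCompat.get are the entry points exercised by the tests above:

import java.util.HashSet;
import java.util.Set;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.example.data.Group;
import org.apache.parquet.filter2.compat.FilterCompat;
import org.apache.parquet.filter2.predicate.FilterPredicate;
import org.apache.parquet.hadoop.ParquetReader;
import org.apache.parquet.hadoop.example.GroupReadSupport;
import org.apache.parquet.io.api.Binary;
import static org.apache.parquet.filter2.predicate.FilterApi.binaryColumn;
import static org.apache.parquet.filter2.predicate.FilterApi.in;

public class InPredicateExample {
  public static void main(String[] args) throws Exception {
    Set<Binary> names = new HashSet<>();
    names.add(Binary.fromString("miller"));
    names.add(Binary.fromString("anderson"));
    FilterPredicate pred = in(binaryColumn("name"), names);

    // Row groups and pages that provably contain none of the values are skipped;
    // the surviving records are then filtered one by one at record level.
    try (ParquetReader<Group> reader = ParquetReader
        .builder(new GroupReadSupport(), new Path("/tmp/users.parquet")) // hypothetical file
        .withFilter(FilterCompat.get(pred))
        .build()) {
      for (Group g = reader.read(); g != null; g = reader.read()) {
        System.out.println(g.getString("name", 0));
      }
    }
  }
}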
+ assertFalse("Should not drop block", canDrop(predIn2, ccmd, dictionaries)); + assertTrue("Should not drop block", canDrop(predNotIn2, ccmd, dictionaries)); + + Set set3 = new HashSet<>(); + set3.add(Binary.fromString("F")); + set3.add(Binary.fromString("C")); + set3.add(Binary.fromString("A")); + set3.add(Binary.fromString("E")); + FilterPredicate predIn3 = in(b, set3); + FilterPredicate predNotIn3 = notIn(b, set3); + assertTrue("Should drop block", canDrop(predIn3, ccmd, dictionaries)); + assertFalse("Should not drop block", canDrop(predNotIn3, ccmd, dictionaries)); + + Set set4 = new HashSet<>(); + set4.add(null); + FilterPredicate predIn4 = in(b, set4); + FilterPredicate predNotIn4 = notIn(b, set4); + assertFalse("Should not drop block for null", canDrop(predIn4, ccmd, dictionaries)); + assertFalse("Should not drop block for null", canDrop(predNotIn4, ccmd, dictionaries)); + } + + @Test + public void testInFixed() throws Exception { + BinaryColumn b = binaryColumn("fixed_field"); + + // Only V2 supports dictionary encoding for FIXED_LEN_BYTE_ARRAY values + if (version == PARQUET_2_0) { + Set set1 = new HashSet<>(); + set1.add(toBinary("-2", 17)); + set1.add(toBinary("-22", 17)); + set1.add(toBinary("12345", 17)); + FilterPredicate predIn1 = in(b, set1); + FilterPredicate predNotIn1 = notIn(b, set1); + assertTrue("Should drop block for in (-2, -22, 12345)", + canDrop(predIn1, ccmd, dictionaries)); + assertFalse("Should not drop block for notIn (-2, -22, 12345)", + canDrop(predNotIn1, ccmd, dictionaries)); + + Set set2 = new HashSet<>(); + set2.add(toBinary("-1", 17)); + set2.add(toBinary("0", 17)); + set2.add(toBinary("12345", 17)); + assertFalse("Should not drop block for in (-1, 0, 12345)", + canDrop(in(b, set2), ccmd, dictionaries)); + assertFalse("Should not drop block for in (-1, 0, 12345)", + canDrop(notIn(b, set2), ccmd, dictionaries)); + } + + Set set3 = new HashSet<>(); + set3.add(null); + FilterPredicate predIn3 = in(b, set3); + FilterPredicate predNotIn3 = notIn(b, set3); + assertFalse("Should not drop block for null", + canDrop(predIn3, ccmd, dictionaries)); + assertFalse("Should not drop block for null", + canDrop(predNotIn3, ccmd, dictionaries)); + } + + @Test + public void testInInt96() throws Exception { + // INT96 ordering is undefined => no filtering shall be done + BinaryColumn b = binaryColumn("int96_field"); + + Set set1 = new HashSet<>(); + set1.add(toBinary("-2", 12)); + set1.add(toBinary("-0", 12)); + set1.add(toBinary("12345", 12)); + FilterPredicate predIn1 = in(b, set1); + FilterPredicate predNotIn1 = notIn(b, set1); + assertFalse("Should not drop block for in (-2, -0, 12345)", + canDrop(predIn1, ccmd, dictionaries)); + assertFalse("Should not drop block for notIn (-2, -0, 12345)", + canDrop(predNotIn1, ccmd, dictionaries)); + + Set set2 = new HashSet<>(); + set2.add(toBinary("-2", 17)); + set2.add(toBinary("12345", 17)); + set2.add(toBinary("-789", 17)); + FilterPredicate predIn2 = in(b, set2); + FilterPredicate predNotIn2 = notIn(b, set2); + assertFalse("Should not drop block for in (-2, 12345, -789)", + canDrop(predIn2, ccmd, dictionaries)); + assertFalse("Should not drop block for notIn (-2, 12345, -789)", + canDrop(predNotIn2, ccmd, dictionaries)); + + Set set3 = new HashSet<>(); + set3.add(null); + FilterPredicate predIn3 = in(b, set3); + FilterPredicate predNotIn3 = notIn(b, set3); + assertFalse("Should not drop block for null", + canDrop(predIn3, ccmd, dictionaries)); + assertFalse("Should not drop block for null", + canDrop(predNotIn3, ccmd, 
dictionaries)); + } + @Test public void testAnd() throws Exception { BinaryColumn col = binaryColumn("binary_field"); diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java index 5a7d02f19f..4c3538c3d5 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java @@ -50,6 +50,7 @@ import static org.apache.parquet.filter2.predicate.FilterApi.longColumn; import static org.apache.parquet.filter2.predicate.FilterApi.eq; import static org.apache.parquet.filter2.predicate.FilterApi.gt; +import static org.apache.parquet.filter2.predicate.FilterApi.in; import static org.apache.parquet.filter2.predicate.FilterApi.not; import static org.apache.parquet.filter2.predicate.FilterApi.notEq; import static org.apache.parquet.filter2.predicate.FilterApi.or; @@ -146,6 +147,39 @@ public void testAllFilter() throws Exception { assertEquals(new ArrayList(), found); } + @Test + public void testInFilter() throws Exception { + BinaryColumn name = binaryColumn("name"); + + HashSet nameSet = new HashSet<>(); + nameSet.add(Binary.fromString("thing2")); + nameSet.add(Binary.fromString("thing1")); + for (int i = 100; i < 200; i++) { + nameSet.add(Binary.fromString("p" + i)); + } + FilterPredicate pred = in(name, nameSet); + List found = PhoneBookWriter.readFile(phonebookFile, FilterCompat.get(pred)); + + List expectedNames = new ArrayList<>(); + expectedNames.add("thing1"); + expectedNames.add("thing2"); + for (int i = 100; i < 200; i++) { + expectedNames.add("p" + i); + } + expectedNames.add("dummy1"); + expectedNames.add("dummy2"); + expectedNames.add("dummy3"); + + // validate that all the values returned by the reader fulfills the filter and there are no values left out, + // i.e. "thing1", "thing2" and from "p100" to "p199" and nothing else. 
+ assertEquals(expectedNames.get(0), ((Group)(found.get(0))).getString("name", 0)); + assertEquals(expectedNames.get(1), ((Group)(found.get(1))).getString("name", 0)); + for (int i = 2; i < 102; i++) { + assertEquals(expectedNames.get(i), ((Group)(found.get(i))).getString("name", 0)); + } + assert(found.size() == 102); + } + @Test public void testNameNotNull() throws Exception { BinaryColumn name = binaryColumn("name"); @@ -232,7 +266,7 @@ public void testUserDefinedByInstance() throws Exception { LongColumn name = longColumn("id"); final HashSet h = new HashSet(); - h.add(20L); + h.add(20L); h.add(27L); h.add(28L); diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/statisticslevel/TestStatisticsFilter.java b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/statisticslevel/TestStatisticsFilter.java index 97dd1695ae..e9682e6170 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/statisticslevel/TestStatisticsFilter.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/statisticslevel/TestStatisticsFilter.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.HashSet; import java.util.List; +import java.util.Set; import org.junit.Test; @@ -37,6 +38,7 @@ import org.apache.parquet.filter2.predicate.UserDefinedPredicate; import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; import org.apache.parquet.hadoop.metadata.CompressionCodecName; +import org.apache.parquet.io.api.Binary; import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; import org.apache.parquet.schema.Types; @@ -51,11 +53,13 @@ import static org.apache.parquet.filter2.predicate.FilterApi.eq; import static org.apache.parquet.filter2.predicate.FilterApi.gt; import static org.apache.parquet.filter2.predicate.FilterApi.gtEq; +import static org.apache.parquet.filter2.predicate.FilterApi.in; import static org.apache.parquet.filter2.predicate.FilterApi.intColumn; import static org.apache.parquet.filter2.predicate.FilterApi.lt; import static org.apache.parquet.filter2.predicate.FilterApi.ltEq; import static org.apache.parquet.filter2.predicate.FilterApi.not; import static org.apache.parquet.filter2.predicate.FilterApi.notEq; +import static org.apache.parquet.filter2.predicate.FilterApi.notIn; import static org.apache.parquet.filter2.predicate.FilterApi.or; import static org.apache.parquet.filter2.predicate.FilterApi.userDefined; import static org.apache.parquet.filter2.statisticslevel.StatisticsFilter.canDrop; @@ -278,6 +282,93 @@ public void testGtEq() { assertFalse(canDrop(gtEq(doubleColumn, 0.1), missingMinMaxColumnMetas)); } + @Test + public void testInNotIn() { + Set values1 = new HashSet<>(); + values1.add(10); + values1.add(12); + values1.add(15); + values1.add(17); + values1.add(19); + assertFalse(canDrop(in(intColumn, values1), columnMetas)); + assertFalse(canDrop(notIn(intColumn, values1), columnMetas)); + + Set values2 = new HashSet<>(); + values2.add(109); + values2.add(2); + values2.add(5); + values2.add(117); + values2.add(101); + assertFalse(canDrop(in(intColumn, values2), columnMetas)); + assertFalse(canDrop(notIn(intColumn, values2), columnMetas)); + + Set values3 = new HashSet<>(); + values3.add(1); + values3.add(2); + values3.add(5); + values3.add(7); + values3.add(10); + assertFalse(canDrop(in(intColumn, values3), columnMetas)); + assertFalse(canDrop(notIn(intColumn, values3), columnMetas)); + + Set values4 = new HashSet<>(); + values4.add(50); + values4.add(60); + assertFalse(canDrop(in(intColumn, values4), missingMinMaxColumnMetas)); 
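The statistics-level cases above reduce to a range-overlap check: take the min and max of the IN set (via the new MinMax helper) and drop the chunk only when that range lies entirely outside the chunk's [min, max] statistics. A standalone sketch of the rule with plain integers (hypothetical helper, not the StatisticsFilter API; null handling is resolved before this point, as in the visitor):

import java.util.Collections;
import java.util.Set;

final class StatsInRule {
  // Sketch: true means the chunk can be dropped for "x in (values)", i.e. every
  // candidate value is below the chunk's minimum or above its maximum.
  static boolean canDropForIn(int statsMin, int statsMax, Set<Integer> values) {
    int valuesMin = Collections.min(values);
    int valuesMax = Collections.max(values);
    return valuesMax < statsMin || valuesMin > statsMax;
  }
}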
+ assertFalse(canDrop(notIn(intColumn, values4), missingMinMaxColumnMetas)); + + Set values5 = new HashSet<>(); + values5.add(1.0); + values5.add(2.0); + values5.add(95.0); + values5.add(107.0); + values5.add(99.0); + assertFalse(canDrop(in(doubleColumn, values5), columnMetas)); + assertFalse(canDrop(notIn(doubleColumn, values5), columnMetas)); + + Set values6 = new HashSet<>(); + values6.add(Binary.fromString("test1")); + values6.add(Binary.fromString("test2")); + assertTrue(canDrop(in(missingColumn, values6), columnMetas)); + assertFalse(canDrop(notIn(missingColumn, values6), columnMetas)); + + Set values7 = new HashSet<>(); + values7.add(null); + assertFalse(canDrop(in(intColumn, values7), nullColumnMetas)); + assertFalse(canDrop(notIn(intColumn, values7), nullColumnMetas)); + + Set values8 = new HashSet<>(); + values8.add(null); + assertFalse(canDrop(in(missingColumn, values8), columnMetas)); + assertFalse(canDrop(notIn(missingColumn, values8), columnMetas)); + + IntStatistics statsNoNulls = new IntStatistics(); + statsNoNulls.setMinMax(10, 100); + statsNoNulls.setNumNulls(0); + + IntStatistics statsSomeNulls = new IntStatistics(); + statsSomeNulls.setMinMax(10, 100); + statsSomeNulls.setNumNulls(3); + + Set values9 = new HashSet<>(); + values9.add(null); + assertTrue(canDrop(in(intColumn, values9), Arrays.asList( + getIntColumnMeta(statsNoNulls, 177L), + getDoubleColumnMeta(doubleStats, 177L)))); + + assertFalse(canDrop(notIn(intColumn, values9), Arrays.asList( + getIntColumnMeta(statsNoNulls, 177L), + getDoubleColumnMeta(doubleStats, 177L)))); + + assertFalse(canDrop(in(intColumn, values9), Arrays.asList( + getIntColumnMeta(statsSomeNulls, 177L), + getDoubleColumnMeta(doubleStats, 177L)))); + + assertFalse(canDrop(notIn(intColumn, values9), Arrays.asList( + getIntColumnMeta(statsSomeNulls, 177L), + getDoubleColumnMeta(doubleStats, 177L)))); + } + @Test public void testAnd() { FilterPredicate yes = eq(intColumn, 9); diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestBloomFiltering.java b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestBloomFiltering.java index 4ebe15aecf..b07fccddde 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestBloomFiltering.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestBloomFiltering.java @@ -48,10 +48,12 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -311,6 +313,30 @@ record -> record.getId() == 1234L, assertCorrectFiltering( record -> "miller".equals(record.getName()), eq(binaryColumn("name"), Binary.fromString("miller"))); + + Set values1 = new HashSet<>(); + values1.add(Binary.fromString("miller")); + values1.add(Binary.fromString("anderson")); + + assertCorrectFiltering( + record -> "miller".equals(record.getName()) || "anderson".equals(record.getName()), + in(binaryColumn("name"), values1)); + + Set values2 = new HashSet<>(); + values2.add(Binary.fromString("miller")); + values2.add(Binary.fromString("alien")); + + assertCorrectFiltering( + record -> "miller".equals(record.getName()), + in(binaryColumn("name"), values2)); + + Set values3 = new HashSet<>(); + values3.add(Binary.fromString("alien")); + values3.add(Binary.fromString("predator")); + + assertCorrectFiltering( + record -> 
"dummy".equals(record.getName()), + in(binaryColumn("name"), values3)); } @Test diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnIndexFiltering.java b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnIndexFiltering.java index a66533d0f9..5e181059f0 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnIndexFiltering.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnIndexFiltering.java @@ -25,6 +25,7 @@ import static org.apache.parquet.filter2.predicate.FilterApi.doubleColumn; import static org.apache.parquet.filter2.predicate.FilterApi.eq; import static org.apache.parquet.filter2.predicate.FilterApi.gtEq; +import static org.apache.parquet.filter2.predicate.FilterApi.in; import static org.apache.parquet.filter2.predicate.FilterApi.longColumn; import static org.apache.parquet.filter2.predicate.FilterApi.lt; import static org.apache.parquet.filter2.predicate.FilterApi.ltEq; @@ -53,10 +54,12 @@ import java.util.Collections; import java.util.Comparator; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -377,12 +380,48 @@ public void testSimpleFiltering() throws IOException { assertCorrectFiltering( record -> record.getId() == 1234, eq(longColumn("id"), 1234l)); + + Set idSet = new HashSet<>(); + idSet.add(1234l); + idSet.add(5678l); + idSet.add(1357l); + idSet.add(111l); + idSet.add(6666l); + idSet.add(2l); + idSet.add(2468l); + + assertCorrectFiltering( + record -> (record.getId() == 1234 || record.getId() == 5678 || record.getId() == 1357 || + record.getId() == 111 || record.getId() == 6666 || record.getId() == 2 || record.getId() == 2468), + in(longColumn("id"), idSet) + ); + assertCorrectFiltering( record -> "miller".equals(record.getName()), eq(binaryColumn("name"), Binary.fromString("miller"))); + + Set nameSet = new HashSet<>(); + nameSet.add(Binary.fromString("anderson")); + nameSet.add(Binary.fromString("miller")); + nameSet.add(Binary.fromString("thomas")); + nameSet.add(Binary.fromString("williams")); + + assertCorrectFiltering( + record -> ("anderson".equals(record.getName()) || "miller".equals(record.getName()) || + "thomas".equals(record.getName()) || "williams".equals(record.getName())), + in(binaryColumn("name"), nameSet) + ); + assertCorrectFiltering( record -> record.getName() == null, eq(binaryColumn("name"), null)); + + Set nullSet = new HashSet<>(); + nullSet.add(null); + + assertCorrectFiltering( + record -> record.getName() == null, + in(binaryColumn("name"), nullSet)); } @Test