From 6e6989b2153ce798f779d00a4b54b2c8e7c2a70a Mon Sep 17 00:00:00 2001 From: James Duong Date: Fri, 10 Nov 2023 07:16:53 -0800 Subject: [PATCH] GH-38662: [Java] Add comparators Add comparators for: - FixedSizeBinaryVector - LargeListVector - FixedSizeListVector - NullVector --- .../sort/DefaultVectorComparators.java | 140 ++++++++++++++++-- .../sort/TestDefaultVectorComparator.java | 132 +++++++++++++++++ 2 files changed, 259 insertions(+), 13 deletions(-) diff --git a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java index 4f9c8b7d71bab..588876aa99059 100644 --- a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java +++ b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java @@ -32,11 +32,13 @@ import org.apache.arrow.vector.Decimal256Vector; import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.DurationVector; +import org.apache.arrow.vector.FixedSizeBinaryVector; import org.apache.arrow.vector.Float4Vector; import org.apache.arrow.vector.Float8Vector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.IntervalDayVector; import org.apache.arrow.vector.IntervalMonthDayNanoVector; +import org.apache.arrow.vector.NullVector; import org.apache.arrow.vector.SmallIntVector; import org.apache.arrow.vector.TimeMicroVector; import org.apache.arrow.vector.TimeMilliVector; @@ -50,7 +52,9 @@ import org.apache.arrow.vector.UInt8Vector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VariableWidthVector; -import org.apache.arrow.vector.complex.BaseRepeatedValueVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.RepeatedValueVector; +import org.apache.arrow.vector.holders.NullableFixedSizeBinaryHolder; /** * Default comparator implementations for 
different types of vectors. @@ -111,13 +115,21 @@ public static VectorValueComparator createDefaultComp return (VectorValueComparator) new TimeSecComparator(); } else if (vector instanceof TimeStampVector) { return (VectorValueComparator) new TimeStampComparator(); + } else if (vector instanceof FixedSizeBinaryVector) { + return (VectorValueComparator) new FixedSizeBinaryComparator(); } } else if (vector instanceof VariableWidthVector) { return (VectorValueComparator) new VariableWidthComparator(); - } else if (vector instanceof BaseRepeatedValueVector) { + } else if (vector instanceof RepeatedValueVector) { VectorValueComparator innerComparator = - createDefaultComparator(((BaseRepeatedValueVector) vector).getDataVector()); + createDefaultComparator(((RepeatedValueVector) vector).getDataVector()); return new RepeatedValueComparator(innerComparator); + } else if (vector instanceof FixedSizeListVector) { + VectorValueComparator innerComparator = + createDefaultComparator(((FixedSizeListVector) vector).getDataVector()); + return new FixedSizeListComparator(innerComparator); + } else if (vector instanceof NullVector) { + return (VectorValueComparator) new NullComparator(); } throw new IllegalArgumentException("No default comparator for " + vector.getClass().getCanonicalName()); @@ -674,6 +686,61 @@ public VectorValueComparator createNew() { } } + /** + * Default comparator for {@link org.apache.arrow.vector.FixedSizeBinaryVector}. + * The comparison is in lexicographic order, with null comes first. 
+ */ + public static class FixedSizeBinaryComparator extends VectorValueComparator { + + @Override + public int compare(int index1, int index2) { + NullableFixedSizeBinaryHolder holder1 = new NullableFixedSizeBinaryHolder(); + NullableFixedSizeBinaryHolder holder2 = new NullableFixedSizeBinaryHolder(); + vector1.get(index1, holder1); + vector2.get(index2, holder2); + + return ByteFunctionHelpers.compare( + holder1.buffer, 0, holder1.byteWidth, holder2.buffer, 0, holder2.byteWidth); + } + + @Override + public int compareNotNull(int index1, int index2) { + NullableFixedSizeBinaryHolder holder1 = new NullableFixedSizeBinaryHolder(); + NullableFixedSizeBinaryHolder holder2 = new NullableFixedSizeBinaryHolder(); + vector1.get(index1, holder1); + vector2.get(index2, holder2); + + return ByteFunctionHelpers.compare( + holder1.buffer, 0, holder1.byteWidth, holder2.buffer, 0, holder2.byteWidth); + } + + @Override + public VectorValueComparator createNew() { + return new FixedSizeBinaryComparator(); + } + } + + /** + * Default comparator for {@link org.apache.arrow.vector.NullVector}. + */ + public static class NullComparator extends VectorValueComparator { + @Override + public int compare(int index1, int index2) { + // Values are always equal (and are always null). + return 0; + } + + @Override + public int compareNotNull(int index1, int index2) { + throw new AssertionError("Cannot compare non-null values in a NullVector."); + } + + @Override + public VectorValueComparator createNew() { + return new NullComparator(); + } + } + /** * Default comparator for {@link org.apache.arrow.vector.VariableWidthVector}. * The comparison is in lexicographic order, with null comes first. @@ -705,14 +772,14 @@ public VectorValueComparator createNew() { } /** - * Default comparator for {@link BaseRepeatedValueVector}. + * Default comparator for {@link RepeatedValueVector}. * It works by comparing the underlying vector in a lexicographic order. * @param inner vector type. 
*/ public static class RepeatedValueComparator - extends VectorValueComparator { + extends VectorValueComparator { - private VectorValueComparator innerComparator; + private final VectorValueComparator innerComparator; public RepeatedValueComparator(VectorValueComparator innerComparator) { this.innerComparator = innerComparator; @@ -720,16 +787,16 @@ public RepeatedValueComparator(VectorValueComparator innerComparator) { @Override public int compareNotNull(int index1, int index2) { - int startIdx1 = vector1.getOffsetBuffer().getInt(index1 * OFFSET_WIDTH); - int startIdx2 = vector2.getOffsetBuffer().getInt(index2 * OFFSET_WIDTH); + int startIdx1 = vector1.getOffsetBuffer().getInt((long) index1 * OFFSET_WIDTH); + int startIdx2 = vector2.getOffsetBuffer().getInt((long) index2 * OFFSET_WIDTH); - int endIdx1 = vector1.getOffsetBuffer().getInt((index1 + 1) * OFFSET_WIDTH); - int endIdx2 = vector2.getOffsetBuffer().getInt((index2 + 1) * OFFSET_WIDTH); + int endIdx1 = vector1.getOffsetBuffer().getInt((long) (index1 + 1) * OFFSET_WIDTH); + int endIdx2 = vector2.getOffsetBuffer().getInt((long) (index2 + 1) * OFFSET_WIDTH); int length1 = endIdx1 - startIdx1; int length2 = endIdx2 - startIdx2; - int length = length1 < length2 ? 
length1 : length2; + int length = Math.min(length1, length2); for (int i = 0; i < length; i++) { int result = innerComparator.compare(startIdx1 + i, startIdx2 + i); @@ -741,13 +808,60 @@ public int compareNotNull(int index1, int index2) { } @Override - public VectorValueComparator createNew() { + public VectorValueComparator createNew() { VectorValueComparator newInnerComparator = innerComparator.createNew(); return new RepeatedValueComparator<>(newInnerComparator); } @Override - public void attachVectors(BaseRepeatedValueVector vector1, BaseRepeatedValueVector vector2) { + public void attachVectors(RepeatedValueVector vector1, RepeatedValueVector vector2) { + this.vector1 = vector1; + this.vector2 = vector2; + + innerComparator.attachVectors((T) vector1.getDataVector(), (T) vector2.getDataVector()); + } + } + + /** + * Default comparator for {@link FixedSizeListVector}. + * It works by comparing the underlying vector in a lexicographic order. + * @param inner vector type. + */ + public static class FixedSizeListComparator + extends VectorValueComparator { + + private final VectorValueComparator innerComparator; + + public FixedSizeListComparator(VectorValueComparator innerComparator) { + this.innerComparator = innerComparator; + } + + @Override + public int compareNotNull(int index1, int index2) { + int length1 = vector1.getListSize(); + int length2 = vector2.getListSize(); + + int length = Math.min(length1, length2); + int startIdx1 = vector1.getElementStartIndex(index1); + int startIdx2 = vector2.getElementStartIndex(index2); + + for (int i = 0; i < length; i++) { + int result = innerComparator.compare(startIdx1 + i, startIdx2 + i); + if (result != 0) { + return result; + } + } + return length1 - length2; + } + + @Override + public VectorValueComparator createNew() { + VectorValueComparator newInnerComparator = innerComparator.createNew(); + return new FixedSizeListComparator<>(newInnerComparator); + } + + @Override + public void 
attachVectors(FixedSizeListVector vector1, FixedSizeListVector vector2) { this.vector1 = vector1; this.vector2 = vector2; diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java index bdae85110aa62..43c634b7647fb 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java @@ -31,12 +31,14 @@ import org.apache.arrow.vector.Decimal256Vector; import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.DurationVector; +import org.apache.arrow.vector.FixedSizeBinaryVector; import org.apache.arrow.vector.Float4Vector; import org.apache.arrow.vector.Float8Vector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.IntervalDayVector; import org.apache.arrow.vector.LargeVarBinaryVector; import org.apache.arrow.vector.LargeVarCharVector; +import org.apache.arrow.vector.NullVector; import org.apache.arrow.vector.SmallIntVector; import org.apache.arrow.vector.TimeMicroVector; import org.apache.arrow.vector.TimeMilliVector; @@ -52,6 +54,8 @@ import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.testing.ValueVectorDataPopulator; import org.apache.arrow.vector.types.TimeUnit; @@ -158,6 +162,61 @@ public void testCopiedComparatorForLists() { } } + private FixedSizeListVector createFixedSizeListVector(int count) { + FixedSizeListVector listVector = FixedSizeListVector.empty("list vector", count, allocator); + Types.MinorType type = Types.MinorType.INT; + 
listVector.addOrGetVector(FieldType.nullable(type.getType())); + listVector.allocateNew(); + + IntVector dataVector = (IntVector) listVector.getDataVector(); + + for (int i = 0; i < count; i++) { + dataVector.set(i, i); + } + dataVector.setValueCount(count); + + listVector.setNotNull(0); + listVector.setValueCount(1); + + return listVector; + } + + @Test + public void testCompareFixedSizeLists() { + try (FixedSizeListVector listVector1 = createFixedSizeListVector(10); + FixedSizeListVector listVector2 = createFixedSizeListVector(11)) { + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(listVector1); + comparator.attachVectors(listVector1, listVector2); + + // prefix is smaller + assertTrue(comparator.compare(0, 0) < 0); + } + + try (FixedSizeListVector listVector1 = createFixedSizeListVector(11); + FixedSizeListVector listVector2 = createFixedSizeListVector(11)) { + ((IntVector) listVector2.getDataVector()).set(10, 110); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(listVector1); + comparator.attachVectors(listVector1, listVector2); + + // breaking tie by the last element + assertTrue(comparator.compare(0, 0) < 0); + } + + try (FixedSizeListVector listVector1 = createFixedSizeListVector(10); + FixedSizeListVector listVector2 = createFixedSizeListVector(10)) { + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(listVector1); + comparator.attachVectors(listVector1, listVector2); + + // list vector elements equal + assertTrue(comparator.compare(0, 0) == 0); + } + } + @Test public void testCompareUInt1() { try (UInt1Vector vec = new UInt1Vector("", allocator)) { @@ -845,6 +904,65 @@ public void testCompareTimeStamp() { } } + @Test + public void testCompareFixedSizeBinary() { + try (FixedSizeBinaryVector vector1 = new FixedSizeBinaryVector("test1", allocator, 2); + FixedSizeBinaryVector vector2 = new FixedSizeBinaryVector("test1", allocator, 3)) { + 
vector1.allocateNew(); + vector2.allocateNew(); + vector1.set(0, new byte[] {1, 1}); + vector2.set(0, new byte[] {1, 1, 0}); + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vector1); + comparator.attachVectors(vector1, vector2); + + // prefix is smaller + assertTrue(comparator.compare(0, 0) < 0); + } + + try (FixedSizeBinaryVector vector1 = new FixedSizeBinaryVector("test1", allocator, 3); + FixedSizeBinaryVector vector2 = new FixedSizeBinaryVector("test1", allocator, 3)) { + vector1.allocateNew(); + vector2.allocateNew(); + vector1.set(0, new byte[] {1, 1, 0}); + vector2.set(0, new byte[] {1, 1, 1}); + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vector1); + comparator.attachVectors(vector1, vector2); + + // breaking tie by the last element + assertTrue(comparator.compare(0, 0) < 0); + } + + try (FixedSizeBinaryVector vector1 = new FixedSizeBinaryVector("test1", allocator, 3); + FixedSizeBinaryVector vector2 = new FixedSizeBinaryVector("test1", allocator, 3)) { + vector1.allocateNew(); + vector2.allocateNew(); + vector1.set(0, new byte[] {1, 1, 1}); + vector2.set(0, new byte[] {1, 1, 1}); + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vector1); + comparator.attachVectors(vector1, vector2); + + // fixed-size binary values equal + assertTrue(comparator.compare(0, 0) == 0); + } + } + + @Test + public void testCompareNull() { + try (NullVector vec = new NullVector("test", + FieldType.notNullable(new ArrowType.Int(32, false)))) { + vec.setValueCount(2); + + VectorValueComparator comparator = + DefaultVectorComparators.createDefaultComparator(vec); + comparator.attachVector(vec); + assertEquals(DefaultVectorComparators.NullComparator.class, comparator.getClass()); + assertEquals(0, comparator.compare(0, 1)); + } + } + @Test public void testCheckNullsOnCompareIsFalseForNonNullableVector() { try (IntVector vec = new IntVector("not nullable", @@ -937,4 
+1055,18 @@ private static void verifyVariableWidthComparatorReturne VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); assertEquals(DefaultVectorComparators.VariableWidthComparator.class, comparator.getClass()); } + + @Test + public void testRepeatedDefaultComparators() { + final FieldType type = FieldType.nullable(Types.MinorType.INT.getType()); + try (final LargeListVector vector = new LargeListVector("list", allocator, type, null)) { + vector.addOrGetVector(FieldType.nullable(type.getType())); + verifyRepeatedComparatorReturned(vector); + } + } + + private static void verifyRepeatedComparatorReturned(V vec) { + VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + assertEquals(DefaultVectorComparators.RepeatedValueComparator.class, comparator.getClass()); + } }