Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,22 @@
import com.google.common.primitives.UnsignedLongs;

import org.apache.spark.annotation.Private;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.util.Utils;

import java.nio.ByteOrder;

import static org.apache.spark.unsafe.PlatformDependent.*;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are not used here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, they were: they're needed for BYTE_ARRAY_OFFSET.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This broke the build, so I'm hotfixing now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, just missed that, thanks!


@Private
public class PrefixComparators {
private PrefixComparators() {}

public static final StringPrefixComparator STRING = new StringPrefixComparator();
public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc();
public static final BinaryPrefixComparator BINARY = new BinaryPrefixComparator();
public static final BinaryPrefixComparatorDesc BINARY_DESC = new BinaryPrefixComparatorDesc();
public static final LongPrefixComparator LONG = new LongPrefixComparator();
public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc();
public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
Expand All @@ -52,6 +59,38 @@ public int compare(long bPrefix, long aPrefix) {
}
}

public static final class BinaryPrefixComparator extends PrefixComparator {
  /**
   * Orders two binary prefixes by comparing them as unsigned 64-bit values.
   */
  @Override
  public int compare(long aPrefix, long bPrefix) {
    return UnsignedLongs.compare(aPrefix, bPrefix);
  }

  /**
   * Packs up to the first eight bytes of {@code bytes} into a long, most significant byte first.
   * Each byte is biased by 128 before being placed, so a signed byte in [-128, 127] contributes
   * a value in [0, 255]. Positions beyond the end of a short array are left as zero bits.
   *
   * @param bytes the binary value to summarize; may be null
   * @return the packed prefix, or 0 when {@code bytes} is null
   */
  public static long computePrefix(byte[] bytes) {
    if (bytes == null) {
      return 0L;
    }
    // TODO: If a wrapper for BinaryType is created (SPARK-8786), the code below should move
    // into that wrapper class.
    final int prefixLen = Math.min(bytes.length, 8);
    long prefix = 0;
    // Byte 0 lands in bits 63..56, byte 1 in bits 55..48, and so on.
    int shift = 56;
    for (int idx = 0; idx < prefixLen; idx++) {
      final long biased = 128L + PlatformDependent.UNSAFE.getByte(bytes, BYTE_ARRAY_OFFSET + idx);
      prefix |= biased << shift;
      shift -= 8;
    }
    return prefix;
  }
}

public static final class BinaryPrefixComparatorDesc extends PrefixComparator {
  /**
   * Compares two binary prefixes as unsigned 64-bit values, in reverse order.
   *
   * The first parameter is named {@code bPrefix} and the second {@code aPrefix}; passing them
   * to the unsigned comparison as (aPrefix, bPrefix) swaps the operands and so inverts the
   * result relative to the call order.
   */
  @Override
  public int compare(long bPrefix, long aPrefix) {
    final int reversed = UnsignedLongs.compare(aPrefix, bPrefix);
    return reversed;
  }
}

public static final class LongPrefixComparator extends PrefixComparator {
@Override
public int compare(long a, long b) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,44 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
}

test("Binary prefix comparator") {

  // Reference ordering for the full values: compare signed bytes position by position over the
  // common length; if no position differs, the shorter array orders first.
  def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
    // The iterators are lazy, so the scan stops at the first differing position. This avoids
    // a `return` inside a closure, which is a nonlocal return implemented by throwing.
    x.iterator.zip(y.iterator)
      .map { case (a, b) => a.compare(b) }
      .find(_ != 0)
      .getOrElse(x.length - y.length)
  }

  // Checks that the prefix comparison never contradicts the full comparison: the prefixes may
  // tie, but a non-zero prefix result must have the same sign as the full result.
  def testPrefixComparison(x: Array[Byte], y: Array[Byte]): Unit = {
    val s1Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(x)
    val s2Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(y)
    val prefixComparisonResult =
      PrefixComparators.BINARY.compare(s1Prefix, s2Prefix)
    val fullComparisonResult = compareBinary(x, y)
    assert(
      (prefixComparisonResult == 0) ||
      (prefixComparisonResult < 0 && fullComparisonResult < 0) ||
      (prefixComparisonResult > 0 && fullComparisonResult > 0))

    // The descending comparator must order the same prefixes the opposite way.
    val descComparisonResult =
      PrefixComparators.BINARY_DESC.compare(s1Prefix, s2Prefix)
    assert(math.signum(descComparisonResult) == -math.signum(prefixComparisonResult))
  }

  // scalastyle:off
  val regressionTests = Table(
    ("s1", "s2"),
    ("abc", "世界"),
    ("你好", "世界"),
    ("你好123", "你好122")
  )
  // scalastyle:on

  // Fixed cases first, then randomly generated string pairs.
  forAll (regressionTests) { (s1: String, s2: String) =>
    testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8"))
  }
  forAll { (s1: String, s2: String) =>
    testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8"))
  }
}

test("double prefix comparator handles NaNs properly") {
val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator

abstract sealed class SortDirection
Expand Down Expand Up @@ -63,6 +64,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val childCode = child.child.gen(ctx)
val input = childCode.primitive
val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName
val DoublePrefixCmp = classOf[DoublePrefixComparator].getName

val (nullValue: Long, prefixCode: String) = child.child.dataType match {
Expand All @@ -76,6 +78,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
(DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
s"$DoublePrefixCmp.computePrefix((double)$input)")
case StringType => (0L, s"$input.getPrefix()")
case BinaryType => (0L, s"$BinaryPrefixCmp.computePrefix($input)")
case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
val prefix = if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
s"$input.toUnscaledLong()"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ object SortPrefixUtils {
sortOrder.dataType match {
case StringType =>
if (sortOrder.isAscending) PrefixComparators.STRING else PrefixComparators.STRING_DESC
case BinaryType =>
if (sortOrder.isAscending) PrefixComparators.BINARY else PrefixComparators.BINARY_DESC
case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType =>
if (sortOrder.isAscending) PrefixComparators.LONG else PrefixComparators.LONG_DESC
case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
Expand Down