From 91ce4a5cc16d2199ace3aae6ef82a4cdb165d44f Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Fri, 15 Oct 2021 15:52:59 +0800 Subject: [PATCH 1/2] Improve byte array sort perf by unify getPrefix function of UTF8String and ByteArray --- .../apache/spark/unsafe/types/ByteArray.java | 36 ++++++++++++++---- .../apache/spark/unsafe/types/UTF8String.java | 38 +------------------ 2 files changed, 30 insertions(+), 44 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 4383ee1533c2..d5cba5d71b25 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -17,6 +17,7 @@ package org.apache.spark.unsafe.types; +import java.nio.ByteOrder; import java.util.Arrays; import com.google.common.primitives.Ints; @@ -26,6 +27,8 @@ public final class ByteArray { public static final byte[] EMPTY_BYTE = new byte[0]; + private static final boolean IS_LITTLE_ENDIAN = + ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; /** * Writes the content of a byte array into a memory address, identified by an object and an @@ -42,15 +45,34 @@ public static void writeToMemory(byte[] src, Object target, long targetOffset) { public static long getPrefix(byte[] bytes) { if (bytes == null) { return 0L; + } + return getPrefix(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length); + } + + protected static long getPrefix(Object base, long offset, int numBytes) { + // Since JVMs are either 4-byte aligned or 8-byte aligned, we check the size of the bytes. + // If size is 0, just return 0. + // If size is between 1 and 4 (inclusive), assume data is 4-byte aligned under the hood and + // use a getInt to fetch the prefix. + // If size is greater than 4, assume we have at least 8 bytes of data to fetch. + // After getting the data, we use a mask to mask out data that is not part of the bytes. + final long p; + final long mask; + if (numBytes >= 8) { + p = Platform.getLong(base, offset); + mask = 0; + } else if (numBytes > 4) { + p = Platform.getLong(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else if (numBytes > 0) { + long pRaw = Platform.getInt(base, offset); + p = IS_LITTLE_ENDIAN ? pRaw : (pRaw << 32); + mask = (1L << (8 - numBytes) * 8) - 1; } else { - final int minLen = Math.min(bytes.length, 8); - long p = 0; - for (int i = 0; i < minLen; ++i) { - p |= ((long) Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i) & 0xff) - << (56 - 8 * i); - } - return p; + p = 0; + mask = 0; } + return (IS_LITTLE_ENDIAN ? java.lang.Long.reverseBytes(p) : p) & ~mask; } public static byte[] subStringSQL(byte[] bytes, int pos, int len) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 9fdbd507a799..6c3adf2c798c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -246,43 +246,7 @@ public int numChars() { * Returns a 64-bit integer that can be used as the prefix used in sorting. */ public long getPrefix() { - // Since JVMs are either 4-byte aligned or 8-byte aligned, we check the size of the string. - // If size is 0, just return 0. - // If size is between 0 and 4 (inclusive), assume data is 4-byte aligned under the hood and - // use a getInt to fetch the prefix. - // If size is greater than 4, assume we have at least 8 bytes of data to fetch. - // After getting the data, we use a mask to mask out data that is not part of the string. - long p; - long mask = 0; - if (IS_LITTLE_ENDIAN) { - if (numBytes >= 8) { - p = Platform.getLong(base, offset); - } else if (numBytes > 4) { - p = Platform.getLong(base, offset); - mask = (1L << (8 - numBytes) * 8) - 1; - } else if (numBytes > 0) { - p = (long) Platform.getInt(base, offset); - mask = (1L << (8 - numBytes) * 8) - 1; - } else { - p = 0; - } - p = java.lang.Long.reverseBytes(p); - } else { - // byteOrder == ByteOrder.BIG_ENDIAN - if (numBytes >= 8) { - p = Platform.getLong(base, offset); - } else if (numBytes > 4) { - p = Platform.getLong(base, offset); - mask = (1L << (8 - numBytes) * 8) - 1; - } else if (numBytes > 0) { - p = ((long) Platform.getInt(base, offset)) << 32; - mask = (1L << (8 - numBytes) * 8) - 1; - } else { - p = 0; - } - } - p &= ~mask; - return p; + return ByteArray.getPrefix(base, offset, numBytes); } /** From 9fcc06b46d43a8014259140c37f8d725a0354d28 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Sat, 16 Oct 2021 18:28:36 +0800 Subject: [PATCH 2/2] address comment --- .../apache/spark/unsafe/types/ByteArray.java | 2 +- .../spark/unsafe/array/ByteArraySuite.java | 51 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index d5cba5d71b25..39442c3dd2aa 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -49,7 +49,7 @@ public static long getPrefix(byte[] bytes) { return getPrefix(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length); } - protected static long getPrefix(Object base, long offset, int numBytes) { + static long getPrefix(Object base, long offset, int numBytes) { // Since JVMs are either 4-byte aligned or 8-byte aligned, we check the size of the bytes. // If size is 0, just return 0. // If size is between 1 and 4 (inclusive), assume data is 4-byte aligned under the hood and diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java new file mode 100644 index 000000000000..703610dfde44 --- /dev/null +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/ByteArraySuite.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.array; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.ByteArray; +import org.junit.Assert; +import org.junit.Test; + +public class ByteArraySuite { + private long getPrefixByByte(byte[] bytes) { + final int minLen = Math.min(bytes.length, 8); + long p = 0; + for (int i = 0; i < minLen; ++i) { + p |= ((long) Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i) & 0xff) + << (56 - 8 * i); + } + return p; + } + + @Test + public void testGetPrefix() { + for (int i = 0; i <= 9; i++) { + byte[] bytes = new byte[i]; + int prefix = i - 1; + while (prefix >= 0) { + bytes[prefix] = (byte) prefix; + prefix -= 1; + } + + long result = ByteArray.getPrefix(bytes); + long expected = getPrefixByByte(bytes); + Assert.assertEquals(result, expected); + } + } +}