@@ -28,7 +28,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor
import org.apache.spark.sql.execution.vectorized.WritableColumnVector
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}

/**
* An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
@@ -111,6 +111,10 @@ private[columnar] class IntervalColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[CalendarInterval](buffer, CALENDAR_INTERVAL)
with NullableColumnAccessor

private[columnar] class VariantColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[VariantVal](buffer, VARIANT)
with NullableColumnAccessor

private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType))

@@ -131,6 +131,9 @@ class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BI
private[columnar]
class IntervalColumnBuilder extends ComplexColumnBuilder(new IntervalColumnStats, CALENDAR_INTERVAL)

private[columnar]
class VariantColumnBuilder extends ComplexColumnBuilder(new VariantColumnStats, VARIANT)

private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType)
extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType))

@@ -189,6 +192,7 @@ private[columnar] object ColumnBuilder {
case s: StringType => new StringColumnBuilder(s)
case BinaryType => new BinaryColumnBuilder
case CalendarIntervalType => new IntervalColumnBuilder
case VariantType => new VariantColumnBuilder
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
new CompactDecimalColumnBuilder(dt)
case dt: DecimalType => new DecimalColumnBuilder(dt)
@@ -297,6 +297,21 @@ private[columnar] final class BinaryColumnStats extends ColumnStats {
Array[Any](null, null, nullCount, count, sizeInBytes)
}

private[columnar] final class VariantColumnStats extends ColumnStats {
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val size = VARIANT.actualSize(row, ordinal)
sizeInBytes += size
count += 1
} else {
gatherNullStats()
}
}

override def collectedStatistics: Array[Any] =
Array[Any](null, null, nullCount, count, sizeInBytes)
}

private[columnar] final class IntervalColumnStats extends ColumnStats {
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
@@ -24,11 +24,11 @@ import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalBinaryType, PhysicalBooleanType, PhysicalByteType, PhysicalCalendarIntervalType, PhysicalDataType, PhysicalDecimalType, PhysicalDoubleType, PhysicalFloatType, PhysicalIntegerType, PhysicalLongType, PhysicalMapType, PhysicalNullType, PhysicalShortType, PhysicalStringType, PhysicalStructType}
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}


/**
@@ -815,6 +815,45 @@ private[columnar] object CALENDAR_INTERVAL extends ColumnType[CalendarInterval]
}
}

/**
* Used to append/extract Java VariantVals into/from the underlying [[ByteBuffer]] of a column.
*
* Variants are encoded in `append` as:
* | value size | metadata size | value binary | metadata binary |
* and are only expected to be decoded in `extract`.
*/
private[columnar] object VARIANT
extends ColumnType[VariantVal] with DirectCopyColumnType[VariantVal] {
override def dataType: PhysicalDataType = PhysicalVariantType

/** Chosen to match the default size set in `VariantType`. */
override def defaultSize: Int = 2048

override def getField(row: InternalRow, ordinal: Int): VariantVal = row.getVariant(ordinal)

override def setField(row: InternalRow, ordinal: Int, value: VariantVal): Unit =
row.update(ordinal, value)

override def append(v: VariantVal, buffer: ByteBuffer): Unit = {
val valueSize = v.getValue().length
val metadataSize = v.getMetadata().length
ByteBufferHelper.putInt(buffer, valueSize)
ByteBufferHelper.putInt(buffer, metadataSize)
ByteBufferHelper.copyMemory(ByteBuffer.wrap(v.getValue()), buffer, valueSize)
ByteBufferHelper.copyMemory(ByteBuffer.wrap(v.getMetadata()), buffer, metadataSize)
}

override def extract(buffer: ByteBuffer): VariantVal = {
val valueSize = ByteBufferHelper.getInt(buffer)
val metadataSize = ByteBufferHelper.getInt(buffer)
val valueBuffer = ByteBuffer.allocate(valueSize)
ByteBufferHelper.copyMemory(buffer, valueBuffer, valueSize)
val metadataBuffer = ByteBuffer.allocate(metadataSize)
ByteBufferHelper.copyMemory(buffer, metadataBuffer, metadataSize)
new VariantVal(valueBuffer.array(), metadataBuffer.array())
}
}
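
As a side note, the layout described in the comment above can be mirrored with plain `java.nio` calls. The sketch below is only an illustration and not part of the change; it assumes the spark-unsafe classes are on the classpath, and the payload bytes are arbitrary placeholders rather than a valid variant encoding.

// Round-trip sketch of the buffer layout used by VARIANT.append/extract:
// | value size (int) | metadata size (int) | value bytes | metadata bytes |
import java.nio.ByteBuffer
import org.apache.spark.unsafe.types.VariantVal

val v = new VariantVal(Array[Byte](1, 2, 3), Array[Byte](4, 5))

// Write the two length prefixes followed by the raw binaries.
val buf = ByteBuffer.allocate(8 + v.getValue.length + v.getMetadata.length)
buf.putInt(v.getValue.length)
buf.putInt(v.getMetadata.length)
buf.put(v.getValue)
buf.put(v.getMetadata)
buf.flip()

// Read everything back in the same order the bytes were written.
val valueSize = buf.getInt()
val metadataSize = buf.getInt()
val value = new Array[Byte](valueSize)
buf.get(value)
val metadata = new Array[Byte](metadataSize)
buf.get(metadata)
val roundTripped = new VariantVal(value, metadata)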

private[columnar] object ColumnType {
@tailrec
def apply(dataType: DataType): ColumnType[_] = {
@@ -89,6 +89,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
case _: StringType => classOf[StringColumnAccessor].getName
case BinaryType => classOf[BinaryColumnAccessor].getName
case CalendarIntervalType => classOf[IntervalColumnAccessor].getName
case VariantType => classOf[VariantColumnAccessor].getName
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
classOf[CompactDecimalColumnAccessor].getName
case dt: DecimalType => classOf[DecimalColumnAccessor].getName
@@ -101,7 +102,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
val createCode = dt match {
case t if CodeGenerator.isPrimitiveType(dt) =>
s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
case NullType | BinaryType | CalendarIntervalType =>
case NullType | BinaryType | CalendarIntervalType | VariantType =>
s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
case other =>
s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder),
72 changes: 72 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -639,6 +639,78 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
}
}

test("variant in a cached row-based df") {
val query = """select
parse_json(format_string('{\"a\": %s}', id)) v,
cast(null as variant) as null_v,
case when id % 2 = 0 then parse_json(cast(id as string)) else null end as some_null
from range(0, 10)"""
val df = spark.sql(query)
df.cache()

val expected = spark.sql(query)
checkAnswer(df, expected.collect())
}

test("struct of variant in a cached row-based df") {
val query = """select named_struct(
'v', parse_json(format_string('{\"a\": %s}', id)),
'null_v', cast(null as variant),
'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
) v
from range(0, 10)"""
val df = spark.sql(query)
df.cache()

val expected = spark.sql(query)
checkAnswer(df, expected.collect())
}

test("array of variant in a cached row-based df") {
val query = """select array(
parse_json(cast(id as string)),
parse_json(format_string('{\"a\": %s}', id)),
null,
case when id % 2 = 0 then parse_json(cast(id as string)) else null end) v
from range(0, 10)"""
val df = spark.sql(query)
df.cache()

val expected = spark.sql(query)
checkAnswer(df, expected.collect())
}

test("map of variant in a cached row-based df") {
val query = """select map(
'v', parse_json(format_string('{\"a\": %s}', id)),
'null_v', cast(null as variant),
'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
) v
from range(0, 10)"""
val df = spark.sql(query)
df.cache()

val expected = spark.sql(query)
checkAnswer(df, expected.collect())
}

test("variant in a cached column-based df") {
withTable("t") {
val query = """select named_struct(
'v', parse_json(format_string('{\"a\": %s}', id)),
'null_v', cast(null as variant),
'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
) v
from range(0, 10)"""
spark.sql(query).write.format("parquet").mode("overwrite").saveAsTable("t")
val df = spark.sql("select * from t")
df.cache()

val expected = spark.sql(query)
checkAnswer(df, expected.collect())
}
}

test("variant_get size") {
val largeKey = "x" * 1000
val df = Seq(s"""{ "$largeKey": {"a" : 1 },