diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
index 9652a48e5270..2074649cc986 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor
 import org.apache.spark.sql.execution.vectorized.WritableColumnVector
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
 
 /**
  * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
@@ -111,6 +111,10 @@ private[columnar] class IntervalColumnAccessor(buffer: ByteBuffer)
   extends BasicColumnAccessor[CalendarInterval](buffer, CALENDAR_INTERVAL)
   with NullableColumnAccessor
 
+private[columnar] class VariantColumnAccessor(buffer: ByteBuffer)
+  extends BasicColumnAccessor[VariantVal](buffer, VARIANT)
+  with NullableColumnAccessor
+
 private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
   extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
index 9fafdb794841..b65ef12f12d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
@@ -131,6 +131,9 @@ class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BI
 private[columnar]
 class IntervalColumnBuilder extends ComplexColumnBuilder(new IntervalColumnStats, CALENDAR_INTERVAL)
 
+private[columnar]
+class VariantColumnBuilder extends ComplexColumnBuilder(new VariantColumnStats, VARIANT)
+
 private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType)
   extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType))
@@ -189,6 +192,7 @@ private[columnar] object ColumnBuilder {
     case s: StringType => new StringColumnBuilder(s)
     case BinaryType => new BinaryColumnBuilder
     case CalendarIntervalType => new IntervalColumnBuilder
+    case VariantType => new VariantColumnBuilder
     case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
       new CompactDecimalColumnBuilder(dt)
     case dt: DecimalType => new DecimalColumnBuilder(dt)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
index 45f489cb13c2..4e4b3667fa24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
@@ -297,6 +297,21 @@ private[columnar] final class BinaryColumnStats extends ColumnStats {
     Array[Any](null, null, nullCount, count, sizeInBytes)
 }
 
+private[columnar] final class VariantColumnStats extends ColumnStats {
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    if (!row.isNullAt(ordinal)) {
+      val size = VARIANT.actualSize(row, ordinal)
+      sizeInBytes += size
+      count += 1
+    } else {
+      gatherNullStats()
+    }
+  }
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](null, null, nullCount, count, sizeInBytes)
+}
+
 private[columnar] final class IntervalColumnStats extends ColumnStats {
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
index b8e63294f3cd..5cc3a3d83d4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
@@ -24,11 +24,11 @@ import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalBinaryType, PhysicalBooleanType, PhysicalByteType, PhysicalCalendarIntervalType, PhysicalDataType, PhysicalDecimalType, PhysicalDoubleType, PhysicalFloatType, PhysicalIntegerType, PhysicalLongType, PhysicalMapType, PhysicalNullType, PhysicalShortType, PhysicalStringType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.errors.ExecutionErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
 
 
 /**
@@ -815,6 +815,45 @@ private[columnar] object CALENDAR_INTERVAL extends ColumnType[CalendarInterval]
   }
 }
 
+/**
+ * Used to append/extract Java VariantVals into/from the underlying [[ByteBuffer]] of a column.
+ *
+ * Variants are encoded in `append` as:
+ * | value size | metadata size | value binary | metadata binary |
+ * and are only expected to be decoded in `extract`.
+ */
+private[columnar] object VARIANT
+  extends ColumnType[VariantVal] with DirectCopyColumnType[VariantVal] {
+  override def dataType: PhysicalDataType = PhysicalVariantType
+
+  /** Chosen to match the default size set in `VariantType`. */
+  override def defaultSize: Int = 2048
+
+  override def getField(row: InternalRow, ordinal: Int): VariantVal = row.getVariant(ordinal)
+
+  override def setField(row: InternalRow, ordinal: Int, value: VariantVal): Unit =
+    row.update(ordinal, value)
+
+  override def append(v: VariantVal, buffer: ByteBuffer): Unit = {
+    val valueSize = v.getValue().length
+    val metadataSize = v.getMetadata().length
+    ByteBufferHelper.putInt(buffer, valueSize)
+    ByteBufferHelper.putInt(buffer, metadataSize)
+    ByteBufferHelper.copyMemory(ByteBuffer.wrap(v.getValue()), buffer, valueSize)
+    ByteBufferHelper.copyMemory(ByteBuffer.wrap(v.getMetadata()), buffer, metadataSize)
+  }
+
+  override def extract(buffer: ByteBuffer): VariantVal = {
+    val valueSize = ByteBufferHelper.getInt(buffer)
+    val metadataSize = ByteBufferHelper.getInt(buffer)
+    val valueBuffer = ByteBuffer.allocate(valueSize)
+    ByteBufferHelper.copyMemory(buffer, valueBuffer, valueSize)
+    val metadataBuffer = ByteBuffer.allocate(metadataSize)
+    ByteBufferHelper.copyMemory(buffer, metadataBuffer, metadataSize)
+    new VariantVal(valueBuffer.array(), metadataBuffer.array())
+  }
+}
+
 private[columnar] object ColumnType {
   @tailrec
   def apply(dataType: DataType): ColumnType[_] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index 75416b878914..d07ebeb843bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -89,6 +89,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
       case _: StringType => classOf[StringColumnAccessor].getName
       case BinaryType => classOf[BinaryColumnAccessor].getName
       case CalendarIntervalType => classOf[IntervalColumnAccessor].getName
+      case VariantType => classOf[VariantColumnAccessor].getName
       case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
         classOf[CompactDecimalColumnAccessor].getName
       case dt: DecimalType => classOf[DecimalColumnAccessor].getName
@@ -101,7 +102,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
       val createCode = dt match {
         case t if CodeGenerator.isPrimitiveType(dt) =>
           s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
-        case NullType | BinaryType | CalendarIntervalType =>
+        case NullType | BinaryType | CalendarIntervalType | VariantType =>
           s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
         case other =>
           s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index de1e4330c564..ce2643f9e239 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -639,6 +639,78 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
     }
   }
 
+  test("variant in a cached row-based df") {
+    val query = """select
+      parse_json(format_string('{\"a\": %s}', id)) v,
+      cast(null as variant) as null_v,
+      case when id % 2 = 0 then parse_json(cast(id as string)) else null end as some_null
+      from range(0, 10)"""
+    val df = spark.sql(query)
+    df.cache()
+
+    val expected = spark.sql(query)
+    checkAnswer(df, expected.collect())
+  }
+
+  test("struct of variant in a cached row-based df") {
+    val query = """select named_struct(
+      'v', parse_json(format_string('{\"a\": %s}', id)),
+      'null_v', cast(null as variant),
+      'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
+    ) v
+    from range(0, 10)"""
+    val df = spark.sql(query)
+    df.cache()
+
+    val expected = spark.sql(query)
+    checkAnswer(df, expected.collect())
+  }
+
+  test("array of variant in a cached row-based df") {
+    val query = """select array(
+      parse_json(cast(id as string)),
+      parse_json(format_string('{\"a\": %s}', id)),
+      null,
+      case when id % 2 = 0 then parse_json(cast(id as string)) else null end) v
+    from range(0, 10)"""
+    val df = spark.sql(query)
+    df.cache()
+
+    val expected = spark.sql(query)
+    checkAnswer(df, expected.collect())
+  }
+
+  test("map of variant in a cached row-based df") {
+    val query = """select map(
+      'v', parse_json(format_string('{\"a\": %s}', id)),
+      'null_v', cast(null as variant),
+      'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
+    ) v
+    from range(0, 10)"""
+    val df = spark.sql(query)
+    df.cache()
+
+    val expected = spark.sql(query)
+    checkAnswer(df, expected.collect())
+  }
+
+  test("variant in a cached column-based df") {
+    withTable("t") {
+      val query = """select named_struct(
+        'v', parse_json(format_string('{\"a\": %s}', id)),
+        'null_v', cast(null as variant),
+        'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
+      ) v
+      from range(0, 10)"""
+      spark.sql(query).write.format("parquet").mode("overwrite").saveAsTable("t")
+      val df = spark.sql("select * from t")
+      df.cache()
+
+      val expected = spark.sql(query)
+      checkAnswer(df, expected.collect())
+    }
+  }
+
   test("variant_get size") {
     val largeKey = "x" * 1000
     val df = Seq(s"""{ "$largeKey": {"a" : 1 },