@@ -20,36 +20,41 @@ package org.apache.spark.sql.catalyst.expressions
2020import org .apache .spark .SparkFunSuite
2121import org .apache .spark .sql .Row
2222import org .apache .spark .sql .catalyst .{CatalystTypeConverters , InternalRow }
23+ import org .apache .spark .sql .internal .SQLConf
2324import org .apache .spark .sql .types ._
2425import org .apache .spark .unsafe .types .CalendarInterval
2526
2627class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
2728
28- private def createMutableProjection (dataTypes : Array [DataType ]): MutableProjection = {
29+ val fixedLengthTypes = Array [DataType ](
30+ BooleanType , ByteType , ShortType , IntegerType , LongType , FloatType , DoubleType ,
31+ DateType , TimestampType )
32+
33+ val variableLengthTypes = Array (
34+ StringType , DecimalType .defaultConcreteType, CalendarIntervalType , BinaryType ,
35+ ArrayType (StringType ), MapType (IntegerType , StringType ),
36+ StructType .fromDDL(" a INT, b STRING" ), ObjectType (classOf [java.lang.Integer ]))
37+
38+ def createMutableProjection (dataTypes : Array [DataType ]): MutableProjection = {
2939 MutableProjection .create(dataTypes.zipWithIndex.map(x => BoundReference (x._2, x._1, true )))
3040 }
3141
3242 testBothCodegenAndInterpreted(" fixed-length types" ) {
33- val fixedLengthTypes = Array [DataType ](
34- BooleanType , ByteType , ShortType , IntegerType , LongType , FloatType , DoubleType ,
35- DateType , TimestampType )
43+ val inputRow = InternalRow .fromSeq(Seq (true , 3 .toByte, 15 .toShort, - 83 , 129L , 1.0f , 5.0 , 1 , 2L ))
3644 val proj = createMutableProjection(fixedLengthTypes)
37- val inputRow = InternalRow .fromSeq(
38- Seq (false , 3 .toByte, 15 .toShort, - 83 , 129L , 1.0f , 5.0 , 100 , 200L ))
3945 assert(proj(inputRow) === inputRow)
46+ }
4047
41- // Use UnsafeRow as buffer
48+ testBothCodegenAndInterpreted(" unsafe buffer" ) {
49+ val inputRow = InternalRow .fromSeq(Seq (false , 1 .toByte, 9 .toShort, - 18 , 53L , 3.2f , 7.8 , 4 , 9L ))
4250 val numBytes = UnsafeRow .calculateBitSetWidthInBytes(fixedLengthTypes.length)
4351 val unsafeBuffer = UnsafeRow .createFromByteArray(numBytes, fixedLengthTypes.length)
52+ val proj = createMutableProjection(fixedLengthTypes)
4453 val projUnsafeRow = proj.target(unsafeBuffer)(inputRow)
45- assert(FromUnsafeProjection (fixedLengthTypes)(projUnsafeRow) === inputRow)
54+ assert(FromUnsafeProjection .apply (fixedLengthTypes)(projUnsafeRow) === inputRow)
4655 }
4756
4857 testBothCodegenAndInterpreted(" variable-length types" ) {
49- val variableLengthTypes = Array (
50- StringType , DecimalType .defaultConcreteType, CalendarIntervalType , BinaryType ,
51- ArrayType (StringType ), MapType (IntegerType , StringType ),
52- StructType .fromDDL(" a INT, b STRING" ), ObjectType (classOf [java.lang.Integer ]))
5358 val proj = createMutableProjection(variableLengthTypes)
5459 val scalaValues = Seq (" abc" , BigDecimal (10 ), CalendarInterval .fromString(" interval 1 day" ),
5560 Array [Byte ](1 , 2 ), Array (" 123" , " 456" ), Map (1 -> " a" , 2 -> " b" ), Row (1 , " a" ),
@@ -63,4 +68,14 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
6368 assert(toScala(projRow.get(index, dataType)) === toScala(inputRow.get(index, dataType)))
6469 }
6570 }
71+
72+ test(" unsupported types for unsafe buffer" ) {
73+ withSQLConf(SQLConf .CODEGEN_FACTORY_MODE .key -> CodegenObjectFactoryMode .NO_CODEGEN .toString) {
74+ val proj = createMutableProjection(Array (StringType ))
75+ val errMsg = intercept[IllegalArgumentException ] {
76+ proj.target(new UnsafeRow (1 ))
77+ }.getMessage
78+ assert(errMsg.contains(" MutableProjection cannot use UnsafeRow for output data types:" ))
79+ }
80+ }
6681}
0 commit comments