Skip to content

Commit 388213f

Browse files
committed
Fix
1 parent df07fc4 commit 388213f

File tree

3 files changed

+33
-19
lines changed

3 files changed

+33
-19
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.sql.catalyst
1919

20-
import org.apache.spark.SparkException
2120
import org.apache.spark.sql.catalyst.expressions._
2221
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
2322
import org.apache.spark.sql.types._
@@ -172,11 +171,8 @@ object InternalRow {
172171
case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double])
173172
case DecimalType.Fixed(precision, _) =>
174173
(input, v) => input.setDecimal(ordinal, v.asInstanceOf[Decimal], precision)
175-
case CalendarIntervalType | BinaryType | _: ArrayType | StringType | _: StructType |
176-
_: MapType | _: ObjectType =>
177-
(input, v) => input.update(ordinal, v)
178174
case udt: UserDefinedType[_] => getWriter(ordinal, udt.sqlType)
179175
case NullType => (input, _) => input.setNullAt(ordinal)
180-
case _ => throw new SparkException(s"Unsupported data type $dt")
176+
case _ => (input, v) => input.update(ordinal, v)
181177
}
182178
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,11 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable
5050

5151
override def target(row: InternalRow): MutableProjection = {
5252
// If `mutableRow` is `UnsafeRow`, `MutableProjection` accepts fixed-length types only
53-
assert(!row.isInstanceOf[UnsafeRow] ||
54-
validExprs.forall { case (e, _) => UnsafeRow.isFixedLength(e.dataType) })
53+
require(!row.isInstanceOf[UnsafeRow] ||
54+
validExprs.forall { case (e, _) => UnsafeRow.isFixedLength(e.dataType) },
55+
"MutableProjection cannot use UnsafeRow for output data types: " +
56+
validExprs.map(_._1.dataType).filterNot(UnsafeRow.isFixedLength)
57+
.map(_.catalogString).mkString(", "))
5558
mutableRow = row
5659
this
5760
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,36 +20,41 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.Row
2222
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
23+
import org.apache.spark.sql.internal.SQLConf
2324
import org.apache.spark.sql.types._
2425
import org.apache.spark.unsafe.types.CalendarInterval
2526

2627
class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
2728

28-
private def createMutableProjection(dataTypes: Array[DataType]): MutableProjection = {
29+
val fixedLengthTypes = Array[DataType](
30+
BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
31+
DateType, TimestampType)
32+
33+
val variableLengthTypes = Array(
34+
StringType, DecimalType.defaultConcreteType, CalendarIntervalType, BinaryType,
35+
ArrayType(StringType), MapType(IntegerType, StringType),
36+
StructType.fromDDL("a INT, b STRING"), ObjectType(classOf[java.lang.Integer]))
37+
38+
def createMutableProjection(dataTypes: Array[DataType]): MutableProjection = {
2939
MutableProjection.create(dataTypes.zipWithIndex.map(x => BoundReference(x._2, x._1, true)))
3040
}
3141

3242
testBothCodegenAndInterpreted("fixed-length types") {
33-
val fixedLengthTypes = Array[DataType](
34-
BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
35-
DateType, TimestampType)
43+
val inputRow = InternalRow.fromSeq(Seq(true, 3.toByte, 15.toShort, -83, 129L, 1.0f, 5.0, 1, 2L))
3644
val proj = createMutableProjection(fixedLengthTypes)
37-
val inputRow = InternalRow.fromSeq(
38-
Seq(false, 3.toByte, 15.toShort, -83, 129L, 1.0f, 5.0, 100, 200L))
3945
assert(proj(inputRow) === inputRow)
46+
}
4047

41-
// Use UnsafeRow as buffer
48+
testBothCodegenAndInterpreted("unsafe buffer") {
49+
val inputRow = InternalRow.fromSeq(Seq(false, 1.toByte, 9.toShort, -18, 53L, 3.2f, 7.8, 4, 9L))
4250
val numBytes = UnsafeRow.calculateBitSetWidthInBytes(fixedLengthTypes.length)
4351
val unsafeBuffer = UnsafeRow.createFromByteArray(numBytes, fixedLengthTypes.length)
52+
val proj = createMutableProjection(fixedLengthTypes)
4453
val projUnsafeRow = proj.target(unsafeBuffer)(inputRow)
45-
assert(FromUnsafeProjection(fixedLengthTypes)(projUnsafeRow) === inputRow)
54+
assert(FromUnsafeProjection.apply(fixedLengthTypes)(projUnsafeRow) === inputRow)
4655
}
4756

4857
testBothCodegenAndInterpreted("variable-length types") {
49-
val variableLengthTypes = Array(
50-
StringType, DecimalType.defaultConcreteType, CalendarIntervalType, BinaryType,
51-
ArrayType(StringType), MapType(IntegerType, StringType),
52-
StructType.fromDDL("a INT, b STRING"), ObjectType(classOf[java.lang.Integer]))
5358
val proj = createMutableProjection(variableLengthTypes)
5459
val scalaValues = Seq("abc", BigDecimal(10), CalendarInterval.fromString("interval 1 day"),
5560
Array[Byte](1, 2), Array("123", "456"), Map(1 -> "a", 2 -> "b"), Row(1, "a"),
@@ -63,4 +68,14 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
6368
assert(toScala(projRow.get(index, dataType)) === toScala(inputRow.get(index, dataType)))
6469
}
6570
}
71+
72+
test("unsupported types for unsafe buffer") {
73+
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString) {
74+
val proj = createMutableProjection(Array(StringType))
75+
val errMsg = intercept[IllegalArgumentException] {
76+
proj.target(new UnsafeRow(1))
77+
}.getMessage
78+
assert(errMsg.contains("MutableProjection cannot use UnsafeRow for output data types:"))
79+
}
80+
}
6681
}

0 commit comments

Comments
 (0)