
Commit 8cb710b

Address comments.

1 parent ed4f4c9 commit 8cb710b

File tree

6 files changed: +30 −51 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
Lines changed: 3 additions & 16 deletions

@@ -24,7 +24,7 @@ import scala.util.Properties
 import org.apache.commons.lang3.reflect.ConstructorUtils

 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
@@ -426,33 +426,20 @@ object ScalaReflection extends ScalaReflection {
   *  * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
   *  * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
   */
-  def serializerForType(tpe: `Type`,
-      cls: RuntimeClass): Expression = ScalaReflection.cleanUpReflectionObjects {
+  def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects {
     val clsName = getClassNameFromType(tpe)
     val walkedTypePath = s"""- root class: "$clsName"""" :: Nil

     // The input object to `ExpressionEncoder` is located at first column of an row.
     val inputObject = BoundReference(0, dataTypeFor(tpe),
-      nullable = !cls.isPrimitive)
+      nullable = !tpe.typeSymbol.asClass.isPrimitive)

     serializerFor(inputObject, tpe, walkedTypePath)
   }

   /**
    * Returns an expression for serializing the value of an input expression into Spark SQL
    * internal representation.
-   *
-   * The expression generated by this method will be used by `ExpressionEncoder` as serializer
-   * to convert a JVM object to Spark SQL representation.
-   *
-   * The returned serializer generally converts a JVM object to corresponding Spark SQL
-   * representation. For example, `Seq[_]` is converted to a Spark SQL array, `Product` is
-   * converted to a Spark SQL struct.
-   *
-   * If input object is not of ObjectType, it means that the input object is already in a form
-   * of Spark's internal representation. We simply return the input object.
-   *
-   * For unsupported types, an `UnsupportedOperationException` will be thrown.
    */
   private def serializerFor(
       inputObject: Expression,

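The dropped `RuntimeClass` parameter is redundant because primitiveness, and hence nullability, can be recovered from the `Type` alone. A minimal standalone sketch of that check (plain scala-reflect, not Spark code; `NullabilityFromType` and `nullableFor` are illustrative names):

import scala.reflect.runtime.universe._

object NullabilityFromType {
  // Analogue of the new nullability check in serializerForType: primitive value
  // classes (Int, Long, Boolean, ...) can never be null, reference types can.
  def nullableFor(tpe: Type): Boolean = !tpe.typeSymbol.asClass.isPrimitive

  def main(args: Array[String]): Unit = {
    println(nullableFor(typeOf[Int]))    // false: Int is a primitive value class
    println(nullableFor(typeOf[String])) // true: String is a reference type
  }
}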
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
Lines changed: 10 additions & 8 deletions

@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.encoders

 import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{`Type`, typeTag, RuntimeClass, TypeTag}
+import scala.reflect.runtime.universe.{typeTag, TypeTag}

 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
@@ -48,7 +48,6 @@ object ExpressionEncoder {
   def apply[T : TypeTag](): ExpressionEncoder[T] = {
     val mirror = ScalaReflection.mirror
     val tpe = typeTag[T].in(mirror).tpe
-    val cls = mirror.runtimeClass(tpe)

     if (ScalaReflection.optionOfProductType(tpe)) {
       throw new UnsupportedOperationException(
@@ -59,7 +58,8 @@ object ExpressionEncoder {
         "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`")
     }

-    val serializer = ScalaReflection.serializerForType(tpe, cls)
+    val cls = mirror.runtimeClass(tpe)
+    val serializer = ScalaReflection.serializerForType(tpe)
     val deserializer = ScalaReflection.deserializerForType(tpe)

     new ExpressionEncoder[T](
@@ -204,10 +204,9 @@ case class ExpressionEncoder[T](
    *  2. For other cases, we create a struct to wrap the `serializer`.
    */
   val serializer: Seq[NamedExpression] = {
-    val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType]
     val clsName = Utils.getSimpleName(clsTag.runtimeClass)

-    if (serializedAsStruct) {
+    if (isSerializedAsStruct) {
       val nullSafeSerializer = objSerializer.transformUp {
         case r: BoundReference =>
           // For input object of Product type, we can't encode it to row if it's null, as Spark SQL
@@ -236,9 +235,7 @@ case class ExpressionEncoder[T](
    * `GetColumnByOrdinal` with corresponding ordinal.
    */
   val deserializer: Expression = {
-    val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType]
-
-    if (serializedAsStruct) {
+    if (isSerializedAsStruct) {
       // We serialized this kind of objects to root-level row. The input of general deserializer
       // is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to
       // transform attributes accessors.
@@ -264,6 +261,11 @@ case class ExpressionEncoder[T](
       StructField(s.name, s.dataType, s.nullable)
     })

+  /**
+   * Returns true if the type `T` is serialized as a struct.
+   */
+  def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType]
+
   // serializer expressions are used to encode an object to a row, while the object is usually an
   // intermediate value produced inside an operator, not from the output of the child operator. This
   // is quite different from normal expressions, and `AttributeReference` doesn't work here

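The new `isSerializedAsStruct` helper replaces the repeated `objSerializer.dataType.isInstanceOf[StructType]` checks in the callers below: it is true when the object serializer produces a top-level struct (e.g. `Product` types) and false when it produces a single flat column (e.g. primitives). A hedged sketch of what callers observe, assuming only spark-catalyst on the classpath and a top-level `Point` case class defined for illustration:

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class Point(x: Int, y: Int)

object IsSerializedAsStructDemo {
  def main(args: Array[String]): Unit = {
    // A Product type is serialized as a top-level struct of its fields ...
    println(ExpressionEncoder[Point]().isSerializedAsStruct) // true
    // ... while a primitive-like type is serialized as one flat column.
    println(ExpressionEncoder[Int]().isSerializedAsStruct)   // false
  }
}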
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
Lines changed: 10 additions & 16 deletions

@@ -262,8 +262,7 @@ class ScalaReflectionSuite extends SparkFunSuite {

   test("SPARK-15062: Get correct serializer for List[_]") {
     val list = List(1, 2, 3)
-    val serializer = serializerForType(ScalaReflection.localTypeOf[List[Int]],
-      classOf[List[Int]])
+    val serializer = serializerForType(ScalaReflection.localTypeOf[List[Int]])
     assert(serializer.isInstanceOf[NewInstance])
     assert(serializer.asInstanceOf[NewInstance]
       .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
@@ -276,42 +275,37 @@ class ScalaReflectionSuite extends SparkFunSuite {

   test("serialize and deserialize arbitrary sequence types") {
     import scala.collection.immutable.Queue
-    val queueSerializer = serializerForType(ScalaReflection.localTypeOf[Queue[Int]],
-      classOf[Queue[Int]])
+    val queueSerializer = serializerForType(ScalaReflection.localTypeOf[Queue[Int]])
     assert(queueSerializer.dataType ==
       ArrayType(IntegerType, containsNull = false))
     val queueDeserializer = deserializerForType(ScalaReflection.localTypeOf[Queue[Int]])
     assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]]))

     import scala.collection.mutable.ArrayBuffer
-    val arrayBufferSerializer = serializerForType(ScalaReflection.localTypeOf[ArrayBuffer[Int]],
-      classOf[ArrayBuffer[Int]])
+    val arrayBufferSerializer = serializerForType(ScalaReflection.localTypeOf[ArrayBuffer[Int]])
     assert(arrayBufferSerializer.dataType ==
       ArrayType(IntegerType, containsNull = false))
     val arrayBufferDeserializer = deserializerForType(ScalaReflection.localTypeOf[ArrayBuffer[Int]])
     assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
   }

   test("serialize and deserialize arbitrary map types") {
-    val mapSerializer = serializerForType(ScalaReflection.localTypeOf[Map[Int, Int]],
-      classOf[Map[Int, Int]])
+    val mapSerializer = serializerForType(ScalaReflection.localTypeOf[Map[Int, Int]])
     assert(mapSerializer.dataType ==
       MapType(IntegerType, IntegerType, valueContainsNull = false))
     val mapDeserializer = deserializerForType(ScalaReflection.localTypeOf[Map[Int, Int]])
     assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]]))

     import scala.collection.immutable.HashMap
-    val hashMapSerializer = serializerForType(ScalaReflection.localTypeOf[HashMap[Int, Int]],
-      classOf[HashMap[Int, Int]])
+    val hashMapSerializer = serializerForType(ScalaReflection.localTypeOf[HashMap[Int, Int]])
     assert(hashMapSerializer.dataType ==
       MapType(IntegerType, IntegerType, valueContainsNull = false))
     val hashMapDeserializer = deserializerForType(ScalaReflection.localTypeOf[HashMap[Int, Int]])
     assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]]))

     import scala.collection.mutable.{LinkedHashMap => LHMap}
     val linkedHashMapSerializer = serializerForType(
-      ScalaReflection.localTypeOf[LHMap[Long, String]],
-      classOf[LHMap[Long, String]])
+      ScalaReflection.localTypeOf[LHMap[Long, String]])
     assert(linkedHashMapSerializer.dataType ==
       MapType(LongType, StringType, valueContainsNull = true))
     val linkedHashMapDeserializer = deserializerForType(
@@ -320,10 +314,10 @@ class ScalaReflectionSuite extends SparkFunSuite {
   }

   test("SPARK-22442: Generate correct field names for special characters") {
-    val serializer = serializerForType(ScalaReflection.localTypeOf[SpecialCharAsFieldData],
-      classOf[SpecialCharAsFieldData]).collect {
-      case If(_, _, s: CreateNamedStruct) => s
-    }.head
+    val serializer = serializerForType(ScalaReflection.localTypeOf[SpecialCharAsFieldData])
+      .collect {
+        case If(_, _, s: CreateNamedStruct) => s
+      }.head
     val deserializer = deserializerForType(ScalaReflection.localTypeOf[SpecialCharAsFieldData])
     assert(serializer.dataType(0).name == "field.1")
     assert(serializer.dataType(1).name == "field 2")

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Lines changed: 5 additions & 5 deletions

@@ -1087,7 +1087,7 @@ class Dataset[T] private[sql](
     // Note that we do this before joining them, to enable the join operator to return null for one
     // side, in cases like outer-join.
     val left = {
-      val combined = if (!this.exprEnc.objSerializer.dataType.isInstanceOf[StructType]) {
+      val combined = if (!this.exprEnc.isSerializedAsStruct) {
         assert(joined.left.output.length == 1)
         Alias(joined.left.output.head, "_1")()
       } else {
@@ -1097,7 +1097,7 @@ class Dataset[T] private[sql](
     }

     val right = {
-      val combined = if (!other.exprEnc.objSerializer.dataType.isInstanceOf[StructType]) {
+      val combined = if (!other.exprEnc.isSerializedAsStruct) {
         assert(joined.right.output.length == 1)
         Alias(joined.right.output.head, "_2")()
       } else {
@@ -1110,14 +1110,14 @@ class Dataset[T] private[sql](
     // combine the outputs of each join side.
     val conditionExpr = joined.condition.get transformUp {
       case a: Attribute if joined.left.outputSet.contains(a) =>
-        if (!this.exprEnc.objSerializer.dataType.isInstanceOf[StructType]) {
+        if (!this.exprEnc.isSerializedAsStruct) {
           left.output.head
         } else {
           val index = joined.left.output.indexWhere(_.exprId == a.exprId)
           GetStructField(left.output.head, index)
         }
       case a: Attribute if joined.right.outputSet.contains(a) =>
-        if (!other.exprEnc.objSerializer.dataType.isInstanceOf[StructType]) {
+        if (!other.exprEnc.isSerializedAsStruct) {
           right.output.head
         } else {
           val index = joined.right.output.indexWhere(_.exprId == a.exprId)
@@ -1390,7 +1390,7 @@ class Dataset[T] private[sql](
     implicit val encoder = c1.encoder
     val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan)

-    if (!encoder.objSerializer.dataType.isInstanceOf[StructType]) {
+    if (!encoder.isSerializedAsStruct) {
       new Dataset[U1](sparkSession, project, encoder)
     } else {
       // Flattens inner fields of U1

sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
Lines changed: 1 addition & 3 deletions

@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.expressions.ReduceAggregator
 import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}
-import org.apache.spark.sql.types.StructType

 /**
  * :: Experimental ::
@@ -458,8 +457,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
     val encoders = columns.map(_.encoder)
     val namedColumns =
       columns.map(_.withInputType(vExprEnc, dataAttributes).named)
-
-    val keyColumn = if (!kExprEnc.objSerializer.dataType.isInstanceOf[StructType]) {
+    val keyColumn = if (!kExprEnc.isSerializedAsStruct) {
       assert(groupingAttributes.length == 1)
       groupingAttributes.head
     } else {

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
Lines changed: 1 addition & 3 deletions

@@ -32,7 +32,6 @@ import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils

 object TypedAggregateExpression {
-
   def apply[BUF : Encoder, OUT : Encoder](
       aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = {
     val bufferEncoder = encoderFor[BUF]
@@ -46,8 +45,7 @@ object TypedAggregateExpression {
     // serialization.
     val isSimpleBuffer = {
       bufferSerializer.head match {
-        case Alias(_: BoundReference, _)
-          if !bufferEncoder.objSerializer.dataType.isInstanceOf[StructType] => true
+        case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true
         case _ => false
       }
     }
