Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,15 @@ object DeserializerBuildHelper {
returnNullable = false)
}

/**
 * Builds the deserializer expression for a `java.nio.ByteBuffer` by statically invoking
 * `java.nio.ByteBuffer.wrap` with the value produced by `path` as its single argument.
 *
 * @param path expression whose value is passed to `wrap`
 * @param returnNullable forwarded to `StaticInvoke` as its `returnNullable` flag
 * @return a `StaticInvoke` expression typed as `ObjectType(classOf[java.nio.ByteBuffer])`
 */
def createDeserializerForJavaByteBuffer(path: Expression, returnNullable: Boolean): Expression = {
  StaticInvoke(
    classOf[java.nio.ByteBuffer],
    ObjectType(classOf[java.nio.ByteBuffer]),
    "wrap",
    path :: Nil,
    // Honor the caller's flag instead of silently hard-coding `false`.
    returnNullable = returnNullable)
}

/**
* When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
* and lose the required data type, which may lead to a runtime error if the real type doesn't
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

package org.apache.spark.sql.catalyst

import java.beans.{Introspector, PropertyDescriptor}
import java.lang.{Iterable => JIterable}
import java.lang.reflect.Method
import java.lang.reflect.Type
import java.util.{Iterator => JIterator, List => JList, Map => JMap}

Expand Down Expand Up @@ -106,6 +106,7 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
case c: Class[_] if c == classOf[java.nio.ByteBuffer] => (BinaryType, true)

case _ if typeToken.isArray =>
val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
Expand All @@ -131,28 +132,44 @@ object JavaTypeInference {
s"of class $other")
}

// TODO: we should only collect properties that have getter and setter. However, some tests
// pass in scala case class as java bean class which doesn't have getter and setter.
val properties = getJavaBeanReadableProperties(other)
val fields = properties.map { property =>
val returnType = typeToken.method(property.getReadMethod).getReturnType
val (dataType, nullable) = inferDataType(returnType, seenTypeSet + other)
new StructField(property.getName, dataType, nullable)
val fields = getObjectProperties(other).map {
case (propertyName, getterMethod, setterMethod) =>
val (dataType, nullable) = inferDataType(
TypeToken.of(getterMethod.getGenericReturnType),
seenTypeSet + other)
new StructField(propertyName, dataType, nullable)
}
(new StructType(fields), true)
}
}

/**
 * Lists the bean properties of `beanClass` that expose a read method, leaving out the
 * `class` and `declaringClass` pseudo-properties.
 *
 * @param beanClass the class to introspect via `java.beans.Introspector`
 * @return the descriptors of every remaining property whose read method is non-null
 */
def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
  val excludedNames = Set("class", "declaringClass")
  val descriptors = Introspector.getBeanInfo(beanClass).getPropertyDescriptors
  descriptors.filter { descriptor =>
    !excludedNames.contains(descriptor.getName) && descriptor.getReadMethod != null
  }
}

private def getJavaBeanReadableAndWritableProperties(
beanClass: Class[_]): Array[PropertyDescriptor] = {
getJavaBeanReadableProperties(beanClass).filter(_.getWriteMethod != null)
/**
 * Finds the getter/setter pairs that define the properties of `beanClass`.
 *
 * A property is a public zero-argument getter paired with a public single-argument setter
 * that returns `void`, where the getter's return type equals the setter's parameter type and
 * either:
 *  - both methods have the same name, in which case the property takes that name; or
 *  - they are named `getXxx` / `setXxx`, in which case the property is named `xxx`
 *    (decapitalized with `java.beans.Introspector.decapitalize`, matching the names that
 *    bean introspection reports).
 *
 * @param beanClass the class whose public methods are inspected
 * @return an array of `(propertyName, getterMethod, setterMethod)` triples
 */
def getObjectProperties(beanClass: Class[_]): Array[(String, Method, Method)] = {
  // Derives the property name for a getter/setter name pair, or None if they do not match.
  def propertyName(getterName: String, setterName: String): Option[String] = {
    if (getterName == setterName) {
      Some(getterName)
    } else if (getterName.startsWith("get") && setterName.startsWith("set") &&
        getterName.substring(3) == setterName.substring(3)) {
      Some(Introspector.decapitalize(getterName.substring(3)))
    } else {
      None
    }
  }

  // Look the methods up once instead of once per candidate getter.
  val methods = beanClass.getMethods
  val getters = methods.filter(_.getParameterCount == 0)
  val setters = methods.filter { method =>
    method.getReturnType == Void.TYPE && method.getParameterCount == 1
  }

  for {
    getter <- getters
    setter <- setters
    if getter.getReturnType == setter.getParameterTypes.head
    name <- propertyName(getter.getName, setter.getName).toList
  } yield (name, getter, setter)
}

private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
Expand Down Expand Up @@ -317,24 +334,26 @@ object JavaTypeInference {
keyData :: valueData :: Nil,
returnNullable = false)

case other if other == classOf[java.nio.ByteBuffer] =>
createDeserializerForJavaByteBuffer(path, returnNullable = false)

case other if other.isEnum =>
createDeserializerForTypesSupportValueOf(
createDeserializerForString(path, returnNullable = false),
other)

case other =>
val properties = getJavaBeanReadableAndWritableProperties(other)
val setters = properties.map { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val (dataType, nullable) = inferDataType(fieldType)
val newTypePath = walkedTypePath.recordField(fieldType.getType.getTypeName, fieldName)
val setter = expressionWithNullSafety(
deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePath),
newTypePath),
nullable = nullable,
newTypePath)
p.getWriteMethod.getName -> setter
val setters = getObjectProperties(other).map {
case (fieldName, getterMethod, setterMethod) =>
val fieldType = TypeToken.of(getterMethod.getGenericReturnType)
val (dataType, nullable) = inferDataType(fieldType)
val newTypePath = walkedTypePath.recordField(fieldType.getType.getTypeName, fieldName)
val setter = expressionWithNullSafety(
deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePath),
newTypePath),
nullable = nullable,
newTypePath)
setterMethod.getName -> setter
}.toMap

val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false)
Expand Down Expand Up @@ -401,6 +420,9 @@ object JavaTypeInference {
case c if c == classOf[java.lang.Float] => createSerializerForFloat(inputObject)
case c if c == classOf[java.lang.Double] => createSerializerForDouble(inputObject)

case c if c == classOf[java.nio.ByteBuffer] =>
createSerializerForJavaByteBuffer(inputObject)

case _ if typeToken.isArray =>
toCatalystArray(inputObject, typeToken.getComponentType)

Expand All @@ -427,13 +449,12 @@ object JavaTypeInference {
Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))

case other =>
val properties = getJavaBeanReadableAndWritableProperties(other)
val fields = properties.map { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val fields = getObjectProperties(other).map {
case (fieldName, getterMethod, setterMethod) =>
val fieldType = TypeToken.of(getterMethod.getGenericReturnType)
val fieldValue = Invoke(
inputObject,
p.getReadMethod.getName,
getterMethod.getName,
inferExternalType(fieldType.getRawType))
(fieldName, serializerFor(fieldValue, fieldType))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ object SerializerBuildHelper {
createSerializerForJavaBigInteger(inputObject)
}

/**
 * Builds the serializer expression for a `java.nio.ByteBuffer`: an `Invoke` of the buffer's
 * `array` method, typed as `BinaryType`.
 *
 * NOTE(review): `ByteBuffer.array()` returns the whole backing array regardless of the
 * buffer's position/limit, and throws for read-only or non-array-backed buffers -- confirm
 * that callers only pass heap buffers that span their entire backing array.
 *
 * @param inputObject expression producing the `ByteBuffer` to serialize
 * @return the `Invoke` expression yielding the buffer's backing byte array
 */
def createSerializerForJavaByteBuffer(inputObject: Expression): Expression = {
  val backingArrayAccessor = "array"
  Invoke(inputObject, backingArrayAccessor, BinaryType)
}

def createSerializerForPrimitiveArray(
inputObject: Expression,
dataType: DataType): Expression = {
Expand Down
Loading