1717
1818package org .apache .spark .sql .catalyst
1919
20- import java .beans .Introspector
20+ import java .beans .{ PropertyDescriptor , Introspector }
2121import java .lang .{Iterable => JIterable }
22- import java .util .{Iterator => JIterator , Map => JMap }
22+ import java .util .{Iterator => JIterator , Map => JMap , List => JList }
2323
2424import scala .language .existentials
2525
2626import com .google .common .reflect .TypeToken
27+
2728import org .apache .spark .sql .types ._
29+ import org .apache .spark .sql .catalyst .expressions ._
30+ import org .apache .spark .sql .catalyst .analysis .{UnresolvedAttribute , UnresolvedExtractValue }
31+ import org .apache .spark .sql .catalyst .util .{GenericArrayData , ArrayBasedMapData , DateTimeUtils }
32+ import org .apache .spark .unsafe .types .UTF8String
33+
2834
2935/**
3036 * Type-inference utilities for POJOs and Java collections.
@@ -33,13 +39,14 @@ object JavaTypeInference {
3339
3440 private val iterableType = TypeToken .of(classOf [JIterable [_]])
3541 private val mapType = TypeToken .of(classOf [JMap [_, _]])
42+ private val listType = TypeToken .of(classOf [JList [_]])
3643 private val iteratorReturnType = classOf [JIterable [_]].getMethod(" iterator" ).getGenericReturnType
3744 private val nextReturnType = classOf [JIterator [_]].getMethod(" next" ).getGenericReturnType
3845 private val keySetReturnType = classOf [JMap [_, _]].getMethod(" keySet" ).getGenericReturnType
3946 private val valuesReturnType = classOf [JMap [_, _]].getMethod(" values" ).getGenericReturnType
4047
4148 /**
42- * Infers the corresponding SQL data type of a JavaClean class.
49+ * Infers the corresponding SQL data type of a JavaBean class.
4350 * @param beanClass Java type
4451 * @return (SQL data type, nullable)
4552 */
@@ -58,6 +65,8 @@ object JavaTypeInference {
5865 (c.getAnnotation(classOf [SQLUserDefinedType ]).udt().newInstance(), true )
5966
6067 case c : Class [_] if c == classOf [java.lang.String ] => (StringType , true )
68+ case c : Class [_] if c == classOf [Array [Byte ]] => (BinaryType , true )
69+
6170 case c : Class [_] if c == java.lang.Short .TYPE => (ShortType , false )
6271 case c : Class [_] if c == java.lang.Integer .TYPE => (IntegerType , false )
6372 case c : Class [_] if c == java.lang.Long .TYPE => (LongType , false )
@@ -87,15 +96,14 @@ object JavaTypeInference {
8796 (ArrayType (dataType, nullable), true )
8897
8998 case _ if mapType.isAssignableFrom(typeToken) =>
90- val typeToken2 = typeToken.asInstanceOf [TypeToken [_ <: JMap [_, _]]]
91- val mapSupertype = typeToken2.getSupertype(classOf [JMap [_, _]])
92- val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
93- val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
99+ val (keyType, valueType) = mapKeyValueType(typeToken)
94100 val (keyDataType, _) = inferDataType(keyType)
95101 val (valueDataType, nullable) = inferDataType(valueType)
96102 (MapType (keyDataType, valueDataType, nullable), true )
97103
98104 case _ =>
105+ // TODO: we should only collect properties that have getter and setter. However, some tests
106+ // pass in scala case class as java bean class which doesn't have getter and setter.
99107 val beanInfo = Introspector .getBeanInfo(typeToken.getRawType)
100108 val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == " class" )
101109 val fields = properties.map { property =>
@@ -107,11 +115,294 @@ object JavaTypeInference {
107115 }
108116 }
109117
118+ private def getJavaBeanProperties (beanClass : Class [_]): Array [PropertyDescriptor ] = {
119+ val beanInfo = Introspector .getBeanInfo(beanClass)
120+ beanInfo.getPropertyDescriptors
121+ .filter(p => p.getReadMethod != null && p.getWriteMethod != null )
122+ }
123+
110124 private def elementType (typeToken : TypeToken [_]): TypeToken [_] = {
111125 val typeToken2 = typeToken.asInstanceOf [TypeToken [_ <: JIterable [_]]]
112- val iterableSupertype = typeToken2.getSupertype(classOf [JIterable [_]])
113- val iteratorType = iterableSupertype.resolveType(iteratorReturnType)
114- val itemType = iteratorType.resolveType(nextReturnType)
115- itemType
126+ val iterableSuperType = typeToken2.getSupertype(classOf [JIterable [_]])
127+ val iteratorType = iterableSuperType.resolveType(iteratorReturnType)
128+ iteratorType.resolveType(nextReturnType)
129+ }
130+
131+ private def mapKeyValueType (typeToken : TypeToken [_]): (TypeToken [_], TypeToken [_]) = {
132+ val typeToken2 = typeToken.asInstanceOf [TypeToken [_ <: JMap [_, _]]]
133+ val mapSuperType = typeToken2.getSupertype(classOf [JMap [_, _]])
134+ val keyType = elementType(mapSuperType.resolveType(keySetReturnType))
135+ val valueType = elementType(mapSuperType.resolveType(valuesReturnType))
136+ keyType -> valueType
137+ }
138+
139+ /**
140+ * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping
141+ * to a native type, an ObjectType is returned.
142+ *
143+ * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type
144+ * system. As a result, ObjectType will be returned for things like boxed Integers.
145+ */
146+ private def inferExternalType (cls : Class [_]): DataType = cls match {
147+ case c if c == java.lang.Boolean .TYPE => BooleanType
148+ case c if c == java.lang.Byte .TYPE => ByteType
149+ case c if c == java.lang.Short .TYPE => ShortType
150+ case c if c == java.lang.Integer .TYPE => IntegerType
151+ case c if c == java.lang.Long .TYPE => LongType
152+ case c if c == java.lang.Float .TYPE => FloatType
153+ case c if c == java.lang.Double .TYPE => DoubleType
154+ case c if c == classOf [Array [Byte ]] => BinaryType
155+ case _ => ObjectType (cls)
156+ }
157+
158+ /**
159+ * Returns an expression that can be used to construct an object of java bean `T` given an input
160+ * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
161+ * of the same name as the constructor arguments. Nested classes will have their fields accessed
162+ * using UnresolvedExtractValue.
163+ */
164+ def constructorFor (beanClass : Class [_]): Expression = {
165+ constructorFor(TypeToken .of(beanClass), None )
166+ }
167+
168+ private def constructorFor (typeToken : TypeToken [_], path : Option [Expression ]): Expression = {
169+ /** Returns the current path with a sub-field extracted. */
170+ def addToPath (part : String ): Expression = path
171+ .map(p => UnresolvedExtractValue (p, expressions.Literal (part)))
172+ .getOrElse(UnresolvedAttribute (part))
173+
174+ /** Returns the current path or `BoundReference`. */
175+ def getPath : Expression = path.getOrElse(BoundReference (0 , inferDataType(typeToken)._1, true ))
176+
177+ typeToken.getRawType match {
178+ case c if ! inferExternalType(c).isInstanceOf [ObjectType ] => getPath
179+
180+ case c if c == classOf [java.lang.Short ] =>
181+ NewInstance (c, getPath :: Nil , propagateNull = true , ObjectType (c))
182+ case c if c == classOf [java.lang.Integer ] =>
183+ NewInstance (c, getPath :: Nil , propagateNull = true , ObjectType (c))
184+ case c if c == classOf [java.lang.Long ] =>
185+ NewInstance (c, getPath :: Nil , propagateNull = true , ObjectType (c))
186+ case c if c == classOf [java.lang.Double ] =>
187+ NewInstance (c, getPath :: Nil , propagateNull = true , ObjectType (c))
188+ case c if c == classOf [java.lang.Byte ] =>
189+ NewInstance (c, getPath :: Nil , propagateNull = true , ObjectType (c))
190+ case c if c == classOf [java.lang.Float ] =>
191+ NewInstance (c, getPath :: Nil , propagateNull = true , ObjectType (c))
192+ case c if c == classOf [java.lang.Boolean ] =>
193+ NewInstance (c, getPath :: Nil , propagateNull = true , ObjectType (c))
194+
195+ case c if c == classOf [java.sql.Date ] =>
196+ StaticInvoke (
197+ DateTimeUtils ,
198+ ObjectType (c),
199+ " toJavaDate" ,
200+ getPath :: Nil ,
201+ propagateNull = true )
202+
203+ case c if c == classOf [java.sql.Timestamp ] =>
204+ StaticInvoke (
205+ DateTimeUtils ,
206+ ObjectType (c),
207+ " toJavaTimestamp" ,
208+ getPath :: Nil ,
209+ propagateNull = true )
210+
211+ case c if c == classOf [java.lang.String ] =>
212+ Invoke (getPath, " toString" , ObjectType (classOf [String ]))
213+
214+ case c if c == classOf [java.math.BigDecimal ] =>
215+ Invoke (getPath, " toJavaBigDecimal" , ObjectType (classOf [java.math.BigDecimal ]))
216+
217+ case c if c.isArray =>
218+ val elementType = c.getComponentType
219+ val primitiveMethod = elementType match {
220+ case c if c == java.lang.Boolean .TYPE => Some (" toBooleanArray" )
221+ case c if c == java.lang.Byte .TYPE => Some (" toByteArray" )
222+ case c if c == java.lang.Short .TYPE => Some (" toShortArray" )
223+ case c if c == java.lang.Integer .TYPE => Some (" toIntArray" )
224+ case c if c == java.lang.Long .TYPE => Some (" toLongArray" )
225+ case c if c == java.lang.Float .TYPE => Some (" toFloatArray" )
226+ case c if c == java.lang.Double .TYPE => Some (" toDoubleArray" )
227+ case _ => None
228+ }
229+
230+ primitiveMethod.map { method =>
231+ Invoke (getPath, method, ObjectType (c))
232+ }.getOrElse {
233+ Invoke (
234+ MapObjects (
235+ p => constructorFor(typeToken.getComponentType, Some (p)),
236+ getPath,
237+ inferDataType(elementType)._1),
238+ " array" ,
239+ ObjectType (c))
240+ }
241+
242+ case c if listType.isAssignableFrom(typeToken) =>
243+ val et = elementType(typeToken)
244+ val array =
245+ Invoke (
246+ MapObjects (
247+ p => constructorFor(et, Some (p)),
248+ getPath,
249+ inferDataType(et)._1),
250+ " array" ,
251+ ObjectType (classOf [Array [Any ]]))
252+
253+ StaticInvoke (classOf [java.util.Arrays ], ObjectType (c), " asList" , array :: Nil )
254+
255+ case _ if mapType.isAssignableFrom(typeToken) =>
256+ val (keyType, valueType) = mapKeyValueType(typeToken)
257+ val keyDataType = inferDataType(keyType)._1
258+ val valueDataType = inferDataType(valueType)._1
259+
260+ val keyData =
261+ Invoke (
262+ MapObjects (
263+ p => constructorFor(keyType, Some (p)),
264+ Invoke (getPath, " keyArray" , ArrayType (keyDataType)),
265+ keyDataType),
266+ " array" ,
267+ ObjectType (classOf [Array [Any ]]))
268+
269+ val valueData =
270+ Invoke (
271+ MapObjects (
272+ p => constructorFor(valueType, Some (p)),
273+ Invoke (getPath, " valueArray" , ArrayType (valueDataType)),
274+ valueDataType),
275+ " array" ,
276+ ObjectType (classOf [Array [Any ]]))
277+
278+ StaticInvoke (
279+ ArrayBasedMapData ,
280+ ObjectType (classOf [JMap [_, _]]),
281+ " toJavaMap" ,
282+ keyData :: valueData :: Nil )
283+
284+ case other =>
285+ val properties = getJavaBeanProperties(other)
286+ assert(properties.length > 0 )
287+
288+ val setters = properties.map { p =>
289+ val fieldName = p.getName
290+ val fieldType = typeToken.method(p.getReadMethod).getReturnType
291+ p.getWriteMethod.getName -> constructorFor(fieldType, Some (addToPath(fieldName)))
292+ }.toMap
293+
294+ val newInstance = NewInstance (other, Nil , propagateNull = false , ObjectType (other))
295+ val result = InitializeJavaBean (newInstance, setters)
296+
297+ if (path.nonEmpty) {
298+ expressions.If (
299+ IsNull (getPath),
300+ expressions.Literal .create(null , ObjectType (other)),
301+ result
302+ )
303+ } else {
304+ result
305+ }
306+ }
307+ }
308+
309+ /**
310+ * Returns expressions for extracting all the fields from the given type.
311+ */
312+ def extractorsFor (beanClass : Class [_]): CreateNamedStruct = {
313+ val inputObject = BoundReference (0 , ObjectType (beanClass), nullable = true )
314+ extractorFor(inputObject, TypeToken .of(beanClass)).asInstanceOf [CreateNamedStruct ]
315+ }
316+
317+ private def extractorFor (inputObject : Expression , typeToken : TypeToken [_]): Expression = {
318+
319+ def toCatalystArray (input : Expression , elementType : TypeToken [_]): Expression = {
320+ val (dataType, nullable) = inferDataType(elementType)
321+ if (ScalaReflection .isNativeType(dataType)) {
322+ NewInstance (
323+ classOf [GenericArrayData ],
324+ input :: Nil ,
325+ dataType = ArrayType (dataType, nullable))
326+ } else {
327+ MapObjects (extractorFor(_, elementType), input, ObjectType (elementType.getRawType))
328+ }
329+ }
330+
331+ if (! inputObject.dataType.isInstanceOf [ObjectType ]) {
332+ inputObject
333+ } else {
334+ typeToken.getRawType match {
335+ case c if c == classOf [String ] =>
336+ StaticInvoke (
337+ classOf [UTF8String ],
338+ StringType ,
339+ " fromString" ,
340+ inputObject :: Nil )
341+
342+ case c if c == classOf [java.sql.Timestamp ] =>
343+ StaticInvoke (
344+ DateTimeUtils ,
345+ TimestampType ,
346+ " fromJavaTimestamp" ,
347+ inputObject :: Nil )
348+
349+ case c if c == classOf [java.sql.Date ] =>
350+ StaticInvoke (
351+ DateTimeUtils ,
352+ DateType ,
353+ " fromJavaDate" ,
354+ inputObject :: Nil )
355+
356+ case c if c == classOf [java.math.BigDecimal ] =>
357+ StaticInvoke (
358+ Decimal ,
359+ DecimalType .SYSTEM_DEFAULT ,
360+ " apply" ,
361+ inputObject :: Nil )
362+
363+ case c if c == classOf [java.lang.Boolean ] =>
364+ Invoke (inputObject, " booleanValue" , BooleanType )
365+ case c if c == classOf [java.lang.Byte ] =>
366+ Invoke (inputObject, " byteValue" , ByteType )
367+ case c if c == classOf [java.lang.Short ] =>
368+ Invoke (inputObject, " shortValue" , ShortType )
369+ case c if c == classOf [java.lang.Integer ] =>
370+ Invoke (inputObject, " intValue" , IntegerType )
371+ case c if c == classOf [java.lang.Long ] =>
372+ Invoke (inputObject, " longValue" , LongType )
373+ case c if c == classOf [java.lang.Float ] =>
374+ Invoke (inputObject, " floatValue" , FloatType )
375+ case c if c == classOf [java.lang.Double ] =>
376+ Invoke (inputObject, " doubleValue" , DoubleType )
377+
378+ case _ if typeToken.isArray =>
379+ toCatalystArray(inputObject, typeToken.getComponentType)
380+
381+ case _ if listType.isAssignableFrom(typeToken) =>
382+ toCatalystArray(inputObject, elementType(typeToken))
383+
384+ case _ if mapType.isAssignableFrom(typeToken) =>
385+ // TODO: for java map, if we get the keys and values by `keySet` and `values`, we can
386+ // not guarantee they have same iteration order(which is different from scala map).
387+ // A possible solution is creating a new `MapObjects` that can iterate a map directly.
388+ throw new UnsupportedOperationException (" map type is not supported currently" )
389+
390+ case other =>
391+ val properties = getJavaBeanProperties(other)
392+ if (properties.length > 0 ) {
393+ CreateNamedStruct (properties.flatMap { p =>
394+ val fieldName = p.getName
395+ val fieldType = typeToken.method(p.getReadMethod).getReturnType
396+ val fieldValue = Invoke (
397+ inputObject,
398+ p.getReadMethod.getName,
399+ inferExternalType(fieldType.getRawType))
400+ expressions.Literal (fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
401+ })
402+ } else {
403+ throw new UnsupportedOperationException (s " no encoder found for ${other.getName}" )
404+ }
405+ }
406+ }
116407 }
117408}
0 commit comments