diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 3ecc137c8cd7f..601e024aafb90 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -104,6 +104,11 @@ object JavaTypeInference {
       case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
       case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
 
+      case c if ExpressionEncoderUtils.hasEncoderForClass(c) =>
+        // User-defined type. Use the user-defined Encoder to get the schema for this type.
+        val dataType = ExpressionEncoderUtils.getEncoderForClass(c).schemaFor(c)
+        (dataType, true)
+
       case _ if typeToken.isArray =>
         val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
         (ArrayType(dataType, nullable), true)
@@ -277,6 +282,10 @@ object JavaTypeInference {
           inferDataType(et)._1,
           customCollectionCls = Some(c))
 
+      case c if ExpressionEncoderUtils.hasEncoderForClass(c) =>
+        // User-defined type. Use the user-defined Encoder to get the deserializer for this type.
+        ExpressionEncoderUtils.getEncoderForClass(c).deserializerFor(path, c)
+
       case _ if mapType.isAssignableFrom(typeToken) =>
         val (keyType, valueType) = mapKeyValueType(typeToken)
         val keyDataType = inferDataType(keyType)._1
@@ -422,6 +431,10 @@ object JavaTypeInference {
       case c if c == classOf[java.lang.Double] =>
         Invoke(inputObject, "doubleValue", DoubleType)
 
+      case c if ExpressionEncoderUtils.hasEncoderForClass(c) =>
+        // User-defined type. Use the user-defined Encoder to get the serializer for this type.
+        ExpressionEncoderUtils.getEncoderForClass(c).serializerFor(inputObject, c)
+
       case _ if typeToken.isArray =>
         toCatalystArray(inputObject, typeToken.getComponentType)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index f9acc208b715e..dd29241e9e070 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -416,6 +416,11 @@ object ScalaReflection extends ScalaReflection {
         } else {
           newInstance
         }
+
+      case t if ExpressionEncoderUtils.hasEncoderForClass(getClassFromType(t)) =>
+        // User-defined type. Use the user-defined Encoder to get the deserializer for this type.
+        ExpressionEncoderUtils.getEncoderForClass(getClassFromType(t))
+          .deserializerFor(path, getClassFromType(tpe))
     }
   }
 
@@ -643,6 +648,11 @@ object ScalaReflection extends ScalaReflection {
         val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
         expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
 
+      case t if ExpressionEncoderUtils.hasEncoderForClass(getClassFromType(t)) =>
+        // User-defined type. Use the user-defined Encoder to get the serializer for this type.
+        ExpressionEncoderUtils.getEncoderForClass(getClassFromType(t))
+          .serializerFor(inputObject, getClassFromType(tpe))
+
       case other =>
         throw new UnsupportedOperationException(
           s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
@@ -780,6 +790,13 @@ object ScalaReflection extends ScalaReflection {
           val Schema(dataType, nullable) = schemaFor(fieldType)
           StructField(fieldName, dataType, nullable)
         }), nullable = true)
+
+    case t if ExpressionEncoderUtils.hasEncoderForClass(getClassFromType(t)) =>
+      // User-defined type. Use the user-defined Encoder to get the schema for this type.
+      val dataType = ExpressionEncoderUtils.getEncoderForClass(getClassFromType(t))
+        .schemaFor(getClassFromType(tpe))
+      Schema(dataType, nullable = true)
+
     case other =>
       throw new UnsupportedOperationException(s"Schema for type $other is not supported")
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderUtils.scala
new file mode 100644
index 0000000000000..e3e2ab47ed087
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderUtils.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import scala.collection.mutable.{Map => MutableMap}
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.util.Utils
+
+/*
+ * Trait for a user-defined Encoder that can be used within an ExpressionEncoder,
+ * e.g. an Avro field (AvroEncoder) inside a Scala case class (ExpressionEncoder) in a Dataset.
+ *
+ * Users configure the mapping from a user-defined type to its Encoder (which extends this
+ * trait) through the conf prefix spark.expressionencoder, e.g.
+ *   spark.expressionencoder.org.apache.avro.specific.SpecificRecord
+ *     = com.databricks.spark.avro.AvroEncoder$
+ * With this setting, a Dataset of a case class can have a SpecificRecord-typed field and
+ * use AvroEncoder to serialize and deserialize that field within the case class.
+ *
+ * An Encoder class extending this trait needs a default no-arg constructor; an Encoder
+ * singleton object extending this trait needs the class-name suffix $ in the conf.
+ */
+trait EncoderWithinExpressionEncoder {
+  /**
+   * Get the SQL data type for the given class T.
+   * Required by ExpressionEncoder's schemaFor method to obtain the SQL schema of a
+   * user-defined field.
+   *
+   * @param inputClass class of T
+   * @return the Spark SQL schema of T as a DataType
+   */
+  def schemaFor[T](inputClass: Class[T]): DataType
+
+  /**
+   * Get the serializer used to serialize an object of type T into an internal row.
+   * Required by ExpressionEncoder's serializerFor method to obtain the serializer for a
+   * user-defined field.
+   *
+   * @param inputObject input object of type T as an Expression
+   * @param inputClass class of T
+   * @return the serializer as an Expression
+   */
+  def serializerFor[T](inputObject: Expression, inputClass: Class[T]): Expression
+
+  /**
+   * Get the deserializer used to deserialize an internal row into an object of type T.
+   * Required by ExpressionEncoder's deserializerFor method to obtain the deserializer for a
+   * user-defined field.
+   *
+   * @param path input path as an Expression
+   * @param inputClass class of T
+   * @return the deserializer as an Expression
+   */
+  def deserializerFor[T](path: Option[Expression], inputClass: Class[T]): Expression
+}
+
+/*
+ * Utility object for using a user-defined type and its Encoder within an ExpressionEncoder.
+ *
+ * The user-defined Encoder must extend the EncoderWithinExpressionEncoder trait.
+ *
+ * An Encoder class extending the trait needs a default no-arg constructor; a singleton
+ * Encoder object extending the trait needs the class-name suffix $ in the conf.
+ */
+object ExpressionEncoderUtils {
+
+  /**
+   * The Spark conf, taken from SparkEnv when available.
+   */
+  lazy val conf: SparkConf = {
+    if (SparkEnv.get != null) {
+      SparkEnv.get.conf
+    } else {
+      // Fall back if SparkEnv is not initialized, e.g. in unit tests
+      new SparkConf
+    }
+  }
+
+  /**
+   * The user-defined types and their Encoder classes that can be used inside an
+   * ExpressionEncoder, read from conf entries prefixed with spark.expressionencoder.
+   */
+  private lazy val typeClassToEncoderClass: Array[(Class[_], Class[_])] =
+    conf
+      .getAllWithPrefix("spark.expressionencoder.")
+      .filter { case (k, v) => Utils.classIsLoadable(k) && Utils.classIsLoadable(v) }
+      .map { case (k, v) => (Utils.classForName(k), Utils.classForName(v)) }
+
+  /**
+   * Encoder instance cache for the user-defined types,
+   * to avoid instantiating an Encoder more than once.
+   */
+  private lazy val encoderCache: MutableMap[Class[_], EncoderWithinExpressionEncoder] =
+    MutableMap[Class[_], EncoderWithinExpressionEncoder]()
+
+  /**
+   * Check whether the given user-defined type has an Encoder configured.
+   */
+  def hasEncoderForClass(clz: Class[_]): Boolean =
+    encoderCache.contains(clz) ||
+      typeClassToEncoderClass.exists(_._1.isAssignableFrom(clz))
+
+  /**
+   * Return the Encoder for a user-defined type.
+   */
+  def getEncoderForClass(clz: Class[_]): EncoderWithinExpressionEncoder = {
+    encoderCache.getOrElseUpdate(clz, findEncoderForClass(clz))
+  }
+
+  /**
+   * Instantiate and return the Encoder instance for the user-defined type.
+   * Asserts that exactly one Encoder is configured for the type and that the Encoder
+   * class implements the EncoderWithinExpressionEncoder trait.
+   */
+  private def findEncoderForClass(clz: Class[_]): EncoderWithinExpressionEncoder = {
+    val encoders = typeClassToEncoderClass.filter(_._1.isAssignableFrom(clz))
+      .map(_._2)
+    assert(encoders.size == 1,
+      s"Expected exactly one encoder under spark.expressionencoder " +
+        s"for class ${clz.getName}, but found ${encoders.size}")
+
+    val encoder = encoders.head
+    assert(classOf[EncoderWithinExpressionEncoder].isAssignableFrom(encoder),
+      s"$encoder does not extend trait EncoderWithinExpressionEncoder")
+
+    // If the encoder is a singleton object (class name ends with $), return the singleton
+    if (encoder.getName.endsWith("$")) {
+      encoder.getField("MODULE$").get(null)
+        .asInstanceOf[EncoderWithinExpressionEncoder]
+    } else {
+      // Otherwise the encoder must be a class with a no-arg constructor
+      encoder.newInstance.asInstanceOf[EncoderWithinExpressionEncoder]
+    }
+  }
+}
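
For illustration, here is a minimal sketch of an encoder plugged in through this trait. `Point`, `PointEncoder`, and the `com.example` package are hypothetical and not part of this patch; the Catalyst expressions used (`Invoke`, `CreateNamedStruct`, `NewInstance`, `UnresolvedExtractValue`) are the same building blocks `ScalaReflection` itself uses, but a production encoder such as `AvroEncoder` would be considerably more involved:

```scala
package com.example

import org.apache.spark.sql.catalyst.EncoderWithinExpressionEncoder
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance}
import org.apache.spark.sql.types._

// Hypothetical user-defined type with no built-in encoder support.
class Point(val x: Double, val y: Double)

// Singleton encoder; registered with the class-name suffix $:
//   spark.expressionencoder.com.example.Point = com.example.PointEncoder$
object PointEncoder extends EncoderWithinExpressionEncoder {

  // A Point is represented in Spark SQL as a struct of two doubles.
  override def schemaFor[T](inputClass: Class[T]): DataType =
    StructType(Seq(StructField("x", DoubleType), StructField("y", DoubleType)))

  // Serialize by invoking the accessors on the in-memory object and
  // packing the results into a named struct (null handling omitted).
  override def serializerFor[T](inputObject: Expression, inputClass: Class[T]): Expression =
    CreateNamedStruct(Seq(
      Literal("x"), Invoke(inputObject, "x", DoubleType),
      Literal("y"), Invoke(inputObject, "y", DoubleType)))

  // Deserialize by extracting the struct fields at `path` and passing
  // them to Point's constructor. Falling back to ordinal 0 when no path
  // is given is an assumption about how the caller binds the row.
  override def deserializerFor[T](path: Option[Expression], inputClass: Class[T]): Expression = {
    val basePath = path.getOrElse(BoundReference(0, schemaFor(inputClass), nullable = true))
    NewInstance(
      classOf[Point],
      Seq(
        UnresolvedExtractValue(basePath, Literal("x")),
        UnresolvedExtractValue(basePath, Literal("y"))),
      ObjectType(classOf[Point]))
  }
}
```

With `spark.expressionencoder.com.example.Point=com.example.PointEncoder$` set, the new cases in `ScalaReflection` and `JavaTypeInference` above would route any `Point`-typed field of a case class through `PointEncoder` when building the enclosing `ExpressionEncoder`.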