diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index c7651daffe36..ac72528f554f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -21,9 +21,12 @@ import java.io._
 import java.util.Properties
 import javax.annotation.Nullable
 
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter}
 import org.apache.hadoop.hive.serde.serdeConstants
 import org.apache.hadoop.hive.serde2.AbstractSerDe
 import org.apache.hadoop.hive.serde2.objectinspector._
@@ -34,10 +37,9 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.hive.HiveShim._
 import org.apache.spark.sql.hive.{HiveContext, HiveInspectors}
 import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils}
+import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}
 import org.apache.spark.{Logging, TaskContext}
 
 /**
@@ -58,6 +60,8 @@ case class ScriptTransformation(
 
   override def otherCopyArgs: Seq[HiveContext] = sc :: Nil
 
+  private val _broadcastedHiveConf = new SerializableConfiguration(sc.hiveconf)
+
   protected override def doExecute(): RDD[InternalRow] = {
     def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
       val cmd = List("/bin/bash", "-c", script)
@@ -67,6 +71,7 @@ case class ScriptTransformation(
       val inputStream = proc.getInputStream
       val outputStream = proc.getOutputStream
       val errorStream = proc.getErrorStream
+      val localHconf = _broadcastedHiveConf.value
 
       // In order to avoid deadlocks, we need to consume the error output of the child process.
       // To avoid issues caused by large error output, we use a circular buffer to limit the amount
@@ -96,19 +101,23 @@ case class ScriptTransformation(
         outputStream,
         proc,
         stderrBuffer,
-        TaskContext.get()
+        TaskContext.get(),
+        localHconf
       )
 
       // This nullability is a performance optimization in order to avoid an Option.foreach() call
       // inside of a loop
-      @Nullable val (outputSerde, outputSoi) = {
-        ioschema.initOutputSerDe(output).getOrElse((null, null))
+      @Nullable val (outputSerde, outputSoi, tableProperties) = {
+        ioschema.initOutputSerDe(output).getOrElse((null, null, null))
       }
 
       val reader = new BufferedReader(new InputStreamReader(inputStream))
       val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
         var curLine: String = null
-        val scriptOutputStream = new DataInputStream(inputStream)
+        val scriptOutputReader: RecordReader = ioschema.getRecordReader(localHconf)
+
+        scriptOutputReader.initialize(
+          new DataInputStream(inputStream), localHconf, tableProperties)
         var scriptOutputWritable: Writable = null
         val reusedWritableObject: Writable = if (null != outputSerde) {
           outputSerde.getSerializedClass().newInstance
@@ -134,15 +143,13 @@ case class ScriptTransformation(
             }
           } else if (scriptOutputWritable == null) {
             scriptOutputWritable = reusedWritableObject
-            try {
-              scriptOutputWritable.readFields(scriptOutputStream)
+            if (scriptOutputReader.next(scriptOutputWritable) <= 0) {
+              if (writerThread.exception.isDefined) {
+                throw writerThread.exception.get
+              }
+              false
+            } else {
               true
-            } catch {
-              case _: EOFException =>
-                if (writerThread.exception.isDefined) {
-                  throw writerThread.exception.get
-                }
-                false
             }
           } else {
             true
@@ -210,7 +217,8 @@ private class ScriptTransformationWriterThread(
     outputStream: OutputStream,
     proc: Process,
     stderrBuffer: CircularBuffer,
-    taskContext: TaskContext
+    taskContext: TaskContext,
+    localHconf: Configuration
   ) extends Thread("Thread-ScriptTransformation-Feed") with Logging {
 
   setDaemon(true)
@@ -222,9 +230,11 @@ private class ScriptTransformationWriterThread(
 
   override def run(): Unit = Utils.logUncaughtExceptions {
     TaskContext.setTaskContext(taskContext)
 
     val dataOutputStream = new DataOutputStream(outputStream)
+    val scriptInWriter: RecordWriter = ioschema.getRecordWriter(localHconf)
+    scriptInWriter.initialize(dataOutputStream, localHconf)
 
     // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so
     // let's use a variable to record whether the `finally` block was hit due to an exception
     var threwException: Boolean = true
@@ -250,7 +260,7 @@ private class ScriptTransformationWriterThread(
         } else {
           val writable = inputSerde.serialize(
             row.asInstanceOf[GenericInternalRow].values, inputSoi)
-          prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream)
+          scriptInWriter.write(writable)
         }
       }
       outputStream.close()
@@ -300,6 +310,17 @@ case class HiveScriptIOSchema (
   val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
   val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
 
+  def getRecordReader(conf: Configuration): RecordReader = {
+    // TODO: also support reader classes specified via the SQL RECORDREADER clause
+    val readerName = HiveConf.getVar(conf, HiveConf.ConfVars.HIVESCRIPTRECORDREADER)
+    Utils.classForName(readerName).newInstance.asInstanceOf[RecordReader]
+  }
+
+  def getRecordWriter(conf: Configuration): RecordWriter = {
+    // TODO: also support writer classes specified via the SQL RECORDWRITER clause
+    val writerName = HiveConf.getVar(conf, HiveConf.ConfVars.HIVESCRIPTRECORDWRITER)
+    Utils.classForName(writerName).newInstance.asInstanceOf[RecordWriter]
+  }
+
   def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = {
     inputSerdeClass.map { serdeClass =>
@@ -313,12 +334,14 @@ case class HiveScriptIOSchema (
     }
   }
 
-  def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = {
+  def initOutputSerDe(
+      output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector, Properties)] = {
     outputSerdeClass.map { serdeClass =>
       val (columns, columnTypes) = parseAttrs(output)
       val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps)
       val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector]
-      (serde, structObjectInspector)
+      (serde, structObjectInspector,
+        createTableProperties(serdeClass, columns, columnTypes, outputSerdeProps))
     }
   }
@@ -328,23 +351,28 @@ case class HiveScriptIOSchema (
     (columns, columnTypes)
   }
 
-  private def initSerDe(
+  private def createTableProperties(
       serdeClassName: String,
       columns: Seq[String],
      columnTypes: Seq[DataType],
-      serdeProps: Seq[(String, String)]): AbstractSerDe = {
-
-    val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe]
-
+      serdeProps: Seq[(String, String)]): Properties = {
     val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",")
-
     var propsMap = serdeProps.toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(","))
     propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
-
+    propsMap = propsMap + (serdeConstants.FIELD_DELIM -> "\t")
     val properties = new Properties()
     properties.putAll(propsMap.asJava)
-    serde.initialize(null, properties)
+    properties
+  }
+
+  private def initSerDe(
+      serdeClassName: String,
+      columns: Seq[String],
+      columnTypes: Seq[DataType],
+      serdeProps: Seq[(String, String)]): AbstractSerDe = {
+    val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe]
+    serde.initialize(null, createTableProperties(serdeClassName, columns, columnTypes, serdeProps))
 
     serde
   }
 }
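(Not part of the patch.) The reader side above relies on Hive's RecordReader contract: getRecordReader instantiates whichever class hive.script.recordreader names, by default org.apache.hadoop.hive.ql.exec.TextRecordReader, and a non-positive return from next() marks end of stream, which replaces the old EOFException handling. A minimal, self-contained sketch of that contract (object name is illustrative only):

    import java.io.{ByteArrayInputStream, DataInputStream}
    import java.util.Properties

    import org.apache.hadoop.hive.conf.HiveConf
    import org.apache.hadoop.hive.ql.exec.TextRecordReader

    object RecordReaderContractSketch {
      def main(args: Array[String]): Unit = {
        // TextRecordReader is Hive's default for hive.script.recordreader, i.e.
        // the class getRecordReader resolves when nothing else is configured.
        val reader = new TextRecordReader()
        val stdin = new DataInputStream(
          new ByteArrayInputStream("a\t1\nb\t2\n".getBytes("UTF-8")))
        reader.initialize(stdin, new HiveConf(), new Properties())

        // createRow() hands back a reusable Writable, mirroring reusedWritableObject.
        val row = reader.createRow()
        // next() returns the number of bytes read; a non-positive value marks end
        // of stream, which is what the patched iterator checks instead of catching
        // EOFException from readFields().
        while (reader.next(row) > 0) {
          println(row) // "a\t1", then "b\t2"
        }
        reader.close()
      }
    }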
diff --git a/sql/hive/src/test/resources/data/scripts/test_transript.py b/sql/hive/src/test/resources/data/scripts/test_transript.py
new file mode 100644
index 000000000000..e8110b4a9015
--- /dev/null
+++ b/sql/hive/src/test/resources/data/scripts/test_transript.py
@@ -0,0 +1,7 @@
+import sys
+
+for line in sys.stdin:
+    arr = line.strip().split("\t")
+    for i in range(len(arr)):
+        arr[i] = arr[i] + "#"
+    print("\t".join(arr))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 83f9f3eaa3a5..ba6a6a5a4ed1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -429,7 +429,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
       |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' FROM src;
     """.stripMargin.replaceAll(System.lineSeparator(), " "))
 
-  test("transform with SerDe2") {
+  // TODO: Only SerDes compatible with TextRecordReader are supported at the moment.
+  ignore("transform with SerDe2") {
     sql("CREATE TABLE small_src(key INT, value STRING)")
     sql("INSERT OVERWRITE TABLE small_src SELECT key, value FROM src LIMIT 10")
 
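(Not part of the patch.) The "transform with SerDe2" test is ignored above because the feed thread now always goes through a RecordWriter, and the default TextRecordWriter is line oriented: every Writable becomes one newline-terminated record on the script's stdin, so only SerDes whose output survives that framing and can be re-read by TextRecordReader keep working. A sketch of the writer contract under that default configuration (object name is illustrative only):

    import java.io.{ByteArrayOutputStream, DataOutputStream}

    import org.apache.hadoop.hive.conf.HiveConf
    import org.apache.hadoop.hive.ql.exec.TextRecordWriter
    import org.apache.hadoop.io.Text

    object RecordWriterContractSketch {
      def main(args: Array[String]): Unit = {
        // TextRecordWriter is Hive's default for hive.script.recordwriter, i.e.
        // the class getRecordWriter resolves when nothing else is configured.
        val writer = new TextRecordWriter()
        val stdin = new ByteArrayOutputStream()
        writer.initialize(new DataOutputStream(stdin), new HiveConf())

        // Each Writable is written verbatim and terminated with '\n'; binary SerDe
        // output containing stray newlines would be split mid-record here, which is
        // why the SerDe2 test above is disabled for now.
        writer.write(new Text("1\tvalue_1"))
        writer.write(new Text("2\tvalue_2"))
        writer.close()

        print(stdin.toString("UTF-8")) // "1\tvalue_1\n2\tvalue_2\n"
      }
    }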
+ ignore("transform with SerDe2") { sql("CREATE TABLE small_src(key INT, value STRING)") sql("INSERT OVERWRITE TABLE small_src SELECT key, value FROM src LIMIT 10") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1ff1d9a2934c..ff188e2d16ea 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -763,6 +763,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { """.stripMargin), (2 to 6).map(i => Row(i))) } + test("test script transform script input and output format") { + val data = (1 to 5).map { i => (i, i) } + data.toDF("a", "b").registerTempTable("test") + checkAnswer( + sql("""FROM + |(FROM test SELECT TRANSFORM(a, b) + |USING 'python src/test/resources/data/scripts/test_transript.py' + |AS (thing1 string, thing2 string)) t + |SELECT thing1 + """.stripMargin), (1 to 5).map(i => Row(i + "#"))) + } + test("window function: udaf with aggregate expressin") { val data = Seq( WindowData(1, "a", 5),