diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 5e29e57d93585..095ed5bdf4c5d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -633,14 +633,28 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
             Token(script, Nil) ::
             Token("TOK_SERDE", serdeClause) ::
             Token("TOK_RECORDREADER", readerClause) ::
-            outputClause :: Nil) :: Nil) =>
+            outputClause) :: Nil) =>
+        // TODO: the output should be bound to the output clause or the RecordReader
         val output = outputClause match {
-          case Token("TOK_ALIASLIST", aliases) =>
+          case Token("TOK_ALIASLIST", aliases) :: Nil =>
             aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() }
-          case Token("TOK_TABCOLLIST", attributes) =>
+          case Token("TOK_TABCOLLIST", attributes) :: Nil =>
             attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) =>
               AttributeReference(name, nodeToDataType(dataType))() }
+          case Nil =>
+            // No output field names were specified, so derive them from the input. To match
+            // Hive, the first field is named "key" and the second "value"; Hive appears to
+            // emit a null string for the remaining field names, which looks like a Hive bug.
+            (0 until inputExprs.length).map { idx =>
+              if (idx == 0) {
+                AttributeReference("key", StringType)()
+              } else if (idx == 1) {
+                AttributeReference("value", StringType)()
+              } else {
+                AttributeReference(s"_col${idx - 2}", StringType)()
+              }
+            }
         }
 
         val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 0c8f676e9c5c8..31e4eb5ae982a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution
 import java.io.{BufferedReader, InputStreamReader}
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.types.StringType
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.hive.HiveContext
@@ -54,23 +55,44 @@ case class ScriptTransformation(
     val outputStream = proc.getOutputStream
     val reader = new BufferedReader(new InputStreamReader(inputStream))
 
+    // This projection produces the rows that are written to the script, which runs
+    // in a single process. TODO: a Writer SerDe will be placed here.
+    val inputProjection = new InterpretedProjection(input, child.output)
+
+    // This projection casts the script's output into the user-specified data types.
+    // TODO: a Reader SerDe will be placed here to cast the output into the
+    // required data types.
+    val outputProjection = new InterpretedProjection(output.zipWithIndex.map {
+      case (attr, idx) if (attr.dataType == StringType) => BoundReference(idx, StringType, true)
+      case (attr, idx) => Cast(BoundReference(idx, StringType, true), attr.dataType)
+    }, output)
+
     // TODO: This should be exposed as an iterator instead of reading in all the data at once.
     val outputLines = collection.mutable.ArrayBuffer[Row]()
     val readerThread = new Thread("Transform OutputReader") {
+      val row = new GenericMutableRow(output.length)
       override def run() {
         var curLine = reader.readLine()
         while (curLine != null) {
-          // TODO: Use SerDe
-          outputLines += new GenericRow(curLine.split("\t").asInstanceOf[Array[Any]])
+          // TODO: a Reader SerDe will be placed here.
+          val splits = curLine.split("\t")
+          var idx = 0
+          while (idx < output.length) {
+            row(idx) = if (idx < splits.length) splits(idx) else null
+            idx += 1
+          }
+
+          outputLines += outputProjection(row)
           curLine = reader.readLine()
         }
       }
     }
+
     readerThread.start()
-    val outputProjection = new InterpretedProjection(input, child.output)
+
     iter
-      .map(outputProjection)
-      // TODO: Use SerDe
+      .map(inputProjection)
+      // TODO: use the Writer SerDe
       .map(_.mkString("", "\t", "\n").getBytes("utf-8")).foreach(outputStream.write)
     outputStream.close()
     readerThread.join()
diff --git a/sql/hive/src/test/resources/golden/TRANSFORM #1 (without serde specified)-0-e9b26dfb0d994cded154f5c4488d88fd b/sql/hive/src/test/resources/golden/TRANSFORM #1 (without serde specified)-0-e9b26dfb0d994cded154f5c4488d88fd
new file mode 100644
index 0000000000000..65989ffc2b59d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/TRANSFORM #1 (without serde specified)-0-e9b26dfb0d994cded154f5c4488d88fd
@@ -0,0 +1,3 @@
+1	val_0
+1	val_0
+1	val_0
diff --git a/sql/hive/src/test/resources/golden/TRANSFORM #2 (without output field names specified)-0-9b03785805377ca184093e467928898e b/sql/hive/src/test/resources/golden/TRANSFORM #2 (without output field names specified)-0-9b03785805377ca184093e467928898e
new file mode 100644
index 0000000000000..65989ffc2b59d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/TRANSFORM #2 (without output field names specified)-0-9b03785805377ca184093e467928898e
@@ -0,0 +1,3 @@
+1	val_0
+1	val_0
+1	val_0
diff --git a/sql/hive/src/test/resources/golden/TRANSFORM #3 (with data type specified)-0-ec79190a8f8c3124fcde7122e1fc8b3f b/sql/hive/src/test/resources/golden/TRANSFORM #3 (with data type specified)-0-ec79190a8f8c3124fcde7122e1fc8b3f
new file mode 100644
index 0000000000000..2524aa511f40e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/TRANSFORM #3 (with data type specified)-0-ec79190a8f8c3124fcde7122e1fc8b3f
@@ -0,0 +1,2 @@
+239.0	val_238
+239.0	val_238
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index df72be7746ac6..514fc4467ada2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -62,7 +62,26 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
       sql("SHOW TABLES")
     }
   }
-
+
+  createQueryTest("TRANSFORM #1 (without serde specified)",
+    """
+      |SELECT transform(key + 1, value) USING '/bin/cat' AS a, b
+      |FROM src ORDER BY a, b DESC LIMIT 3
+    """.stripMargin)
+
+  createQueryTest("TRANSFORM #2 (without output field names specified)",
+    """
+      |SELECT transform(key + 1, value) USING '/bin/cat'
+      |FROM src ORDER BY key, value DESC LIMIT 3
+    """.stripMargin)
+
+  createQueryTest("TRANSFORM #3 (with data type specified)",
+    """
+      |SELECT a, b FROM (SELECT transform(key + 1, value)
+      |USING '/bin/cat' AS (a FLOAT, b STRING)
+      |FROM src) t WHERE a = 239.0
+    """.stripMargin)
+
   createQueryTest("! operator",
     """
       |SELECT a FROM (
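Note, appended for context rather than as part of the patch: the new ScriptTransformation body follows a classic process-pipe pattern. Below is a minimal, self-contained Scala sketch of that pattern under the same assumptions as the tests above (tab-delimited text, '/bin/cat' as the script). The object name TransformPipeSketch, the Array[String] row type, and the final println are illustrative only; the real operator works with Catalyst Rows and the two InterpretedProjections this patch adds.

import java.io.{BufferedReader, InputStreamReader, PrintWriter}
import scala.collection.mutable.ArrayBuffer

object TransformPipeSketch {
  def main(args: Array[String]): Unit = {
    val rows = Seq(Array("1", "val_0"), Array("2", "val_1"))

    // Launch the external script; '/bin/cat' simply echoes its input.
    val proc = new ProcessBuilder("/bin/cat").start()
    val writer = new PrintWriter(proc.getOutputStream)
    val reader = new BufferedReader(new InputStreamReader(proc.getInputStream))

    // Collect the script's stdout on a dedicated thread, as the patch does.
    val output = ArrayBuffer[Array[String]]()
    val readerThread = new Thread("Transform OutputReader") {
      override def run(): Unit = {
        var line = reader.readLine()
        while (line != null) {
          // Split on tabs; the real operator pads missing trailing fields
          // with null and then applies outputProjection to cast the strings.
          output += line.split("\t")
          line = reader.readLine()
        }
      }
    }
    readerThread.start()

    // Serialize each row as a tab-delimited line, mirroring inputProjection
    // followed by mkString("", "\t", "\n") in the patch.
    rows.foreach(r => writer.println(r.mkString("\t")))
    writer.close() // signals EOF to the child process
    readerThread.join()

    output.foreach(r => println(r.mkString(", ")))
  }
}

Writing the script's stdin and reading its stdout on the same thread can deadlock once the OS pipe buffer fills, which is why both the patch and this sketch start the reader thread before the writer loop runs.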