
Commit b4c4b40

[SPARK-32106][SQL] Implement script transform in sql/core
1 parent 4cf8c1d commit b4c4b40

File tree

10 files changed, +762 -60 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 47 additions & 2 deletions
@@ -744,8 +744,30 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
     selectClause.hints.asScala.foldRight(withWindow)(withHints)
   }
 
+  // Script Transform's input/output format.
+  type ScriptIOFormat =
+    (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String])
+
+  protected def getRowFormatDelimited(ctx: RowFormatDelimitedContext): ScriptIOFormat = {
+    // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema
+    // expects a seq of pairs in which the old parsers' token names are used as keys.
+    // Transforming the result of visitRowFormatDelimited would be quite a bit messier than
+    // retrieving the key value pairs ourselves.
+    def entry(key: String, value: Token): Seq[(String, String)] = {
+      Option(value).map(t => key -> t.getText).toSeq
+    }
+
+    val entries = entry("TOK_TABLEROWFORMATFIELD", ctx.fieldsTerminatedBy) ++
+      entry("TOK_TABLEROWFORMATCOLLITEMS", ctx.collectionItemsTerminatedBy) ++
+      entry("TOK_TABLEROWFORMATMAPKEYS", ctx.keysTerminatedBy) ++
+      entry("TOK_TABLEROWFORMATLINES", ctx.linesSeparatedBy) ++
+      entry("TOK_TABLEROWFORMATNULL", ctx.nullDefinedAs)
+
+    (entries, None, Seq.empty, None)
+  }
+
   /**
-   * Create a (Hive based) [[ScriptInputOutputSchema]].
+   * Create a [[ScriptInputOutputSchema]].
    */
   protected def withScriptIOSchema(
       ctx: ParserRuleContext,

@@ -754,7 +776,30 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
       outRowFormat: RowFormatContext,
       recordReader: Token,
       schemaLess: Boolean): ScriptInputOutputSchema = {
-    throw new ParseException("Script Transform is not supported", ctx)
+
+    def format(fmt: RowFormatContext): ScriptIOFormat = fmt match {
+      case c: RowFormatDelimitedContext =>
+        getRowFormatDelimited(c)
+
+      case c: RowFormatSerdeContext =>
+        throw new ParseException("TRANSFORM with serde is only supported in hive mode", ctx)
+
+      // SPARK-32106: When there is no definition about format, we return empty result
+      // to use a built-in default Serde in SparkScriptTransformationExec.
+      case null =>
+        (Nil, None, Seq.empty, None)
+    }
+
+    val (inFormat, inSerdeClass, inSerdeProps, reader) = format(inRowFormat)
+
+    val (outFormat, outSerdeClass, outSerdeProps, writer) = format(outRowFormat)
+
+    ScriptInputOutputSchema(
+      inFormat, outFormat,
+      inSerdeClass, outSerdeClass,
+      inSerdeProps, outSerdeProps,
+      reader, writer,
+      schemaLess)
   }
 
   /**
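As a worked illustration (not part of the commit): for a clause such as
ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' NULL DEFINED AS 'null',
getRowFormatDelimited keeps only the delimiters that were actually written,
using the old parser's token names as keys and the raw quoted token text as
values, while the serde class, serde properties, and record reader/writer
slots stay empty:

// Sketch of the ScriptIOFormat tuple built for the clause above; the
// literal values mirror what Token.getText returns (quotes included).
val ioFormat: (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) =
  (Seq("TOK_TABLEROWFORMATFIELD" -> "'\\t'",
       "TOK_TABLEROWFORMATNULL" -> "'null'"),
   None, Seq.empty, None)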

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala

Lines changed: 112 additions & 1 deletion
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType}
 
 /**
  * Parser test cases for rules defined in [[CatalystSqlParser]] / [[AstBuilder]].

@@ -1031,4 +1031,115 @@ class PlanParserSuite extends AnalysisTest {
     assertEqual("select a, b from db.c;;;", table("db", "c").select('a, 'b))
     assertEqual("select a, b from db.c; ;; ;", table("db", "c").select('a, 'b))
   }
+
+  test("SPARK-32106: TRANSFORM plan") {
+    // verify schema less
+    assertEqual(
+      """
+        |SELECT TRANSFORM(a, b, c)
+        |USING 'cat'
+        |FROM testData
+      """.stripMargin,
+      ScriptTransformation(
+        Seq('a, 'b, 'c),
+        "cat",
+        Seq(AttributeReference("key", StringType)(),
+          AttributeReference("value", StringType)()),
+        UnresolvedRelation(TableIdentifier("testData")),
+        ScriptInputOutputSchema(List.empty, List.empty, None, None,
+          List.empty, List.empty, None, None, true))
+    )
+
+    // verify without output schema
+    assertEqual(
+      """
+        |SELECT TRANSFORM(a, b, c)
+        |USING 'cat' AS (a, b, c)
+        |FROM testData
+      """.stripMargin,
+      ScriptTransformation(
+        Seq('a, 'b, 'c),
+        "cat",
+        Seq(AttributeReference("a", StringType)(),
+          AttributeReference("b", StringType)(),
+          AttributeReference("c", StringType)()),
+        UnresolvedRelation(TableIdentifier("testData")),
+        ScriptInputOutputSchema(List.empty, List.empty, None, None,
+          List.empty, List.empty, None, None, false)))
+
+    // verify with output schema
+    assertEqual(
+      """
+        |SELECT TRANSFORM(a, b, c)
+        |USING 'cat' AS (a int, b string, c long)
+        |FROM testData
+      """.stripMargin,
+      ScriptTransformation(
+        Seq('a, 'b, 'c),
+        "cat",
+        Seq(AttributeReference("a", IntegerType)(),
+          AttributeReference("b", StringType)(),
+          AttributeReference("c", LongType)()),
+        UnresolvedRelation(TableIdentifier("testData")),
+        ScriptInputOutputSchema(List.empty, List.empty, None, None,
+          List.empty, List.empty, None, None, false)))
+
+    // verify with ROW FORMAT DELIMITED
+    assertEqual(
+      """
+        |SELECT TRANSFORM(a, b, c)
+        |ROW FORMAT DELIMITED
+        |FIELDS TERMINATED BY '\t'
+        |COLLECTION ITEMS TERMINATED BY '\u0002'
+        |MAP KEYS TERMINATED BY '\u0003'
+        |LINES TERMINATED BY '\n'
+        |NULL DEFINED AS 'null'
+        |USING 'cat' AS (a, b, c)
+        |ROW FORMAT DELIMITED
+        |FIELDS TERMINATED BY '\t'
+        |COLLECTION ITEMS TERMINATED BY '\u0004'
+        |MAP KEYS TERMINATED BY '\u0005'
+        |LINES TERMINATED BY '\n'
+        |NULL DEFINED AS 'NULL'
+        |FROM testData
+      """.stripMargin,
+      ScriptTransformation(
+        Seq('a, 'b, 'c),
+        "cat",
+        Seq(AttributeReference("a", StringType)(),
+          AttributeReference("b", StringType)(),
+          AttributeReference("c", StringType)()),
+        UnresolvedRelation(TableIdentifier("testData")),
+        ScriptInputOutputSchema(
+          Seq(("TOK_TABLEROWFORMATFIELD", "'\\t'"),
+            ("TOK_TABLEROWFORMATCOLLITEMS", "'\u0002'"),
+            ("TOK_TABLEROWFORMATMAPKEYS", "'\u0003'"),
+            ("TOK_TABLEROWFORMATLINES", "'\\n'"),
+            ("TOK_TABLEROWFORMATNULL", "'null'")),
+          Seq(("TOK_TABLEROWFORMATFIELD", "'\\t'"),
+            ("TOK_TABLEROWFORMATCOLLITEMS", "'\u0004'"),
+            ("TOK_TABLEROWFORMATMAPKEYS", "'\u0005'"),
+            ("TOK_TABLEROWFORMATLINES", "'\\n'"),
+            ("TOK_TABLEROWFORMATNULL", "'NULL'")), None, None,
+          List.empty, List.empty, None, None, false)))
+
+    // verify with ROW FORMAT SERDE
+    intercept(
+      """
+        |SELECT TRANSFORM(a, b, c)
+        |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde'
+        |WITH SERDEPROPERTIES(
+        |  "separatorChar" = "\t",
+        |  "quoteChar" = "'",
+        |  "escapeChar" = "\\")
+        |USING 'cat' AS (a, b, c)
+        |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde'
+        |WITH SERDEPROPERTIES(
+        |  "separatorChar" = "\t",
+        |  "quoteChar" = "'",
+        |  "escapeChar" = "\\")
+        |FROM testData
+      """.stripMargin,
+      "TRANSFORM with serde is only supported in hive mode")
+  }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ class SparkPlanner(
       Window ::
       JoinSelection ::
       InMemoryScans ::
+      SparkScripts ::
       BasicOperators :: Nil)
 
   /**
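The SparkScripts strategy registered here is defined in a file not shown in
this excerpt. A minimal sketch of the shape such a strategy takes — assuming
the logical ScriptTransformation node and a ScriptTransformationIOSchema
factory, as the commit's other files suggest:

// Sketch only: pattern-matches the logical script-transform node and plans
// it as the new sql/core physical operator; planLater defers the child plan.
object SparkScripts extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case logical.ScriptTransformation(input, script, output, child, ioschema) =>
      SparkScriptTransformationExec(
        input, script, output, planLater(child),
        ScriptTransformationIOSchema(ioschema)) :: Nil
    case _ => Nil
  }
}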
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala

Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io._
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.CircularBuffer
+
+/**
+ * Transforms the input by forking and running the specified script.
+ *
+ * @param input the set of expressions that should be passed to the script.
+ * @param script the command that should be executed.
+ * @param output the attributes that are produced by the script.
+ */
+case class SparkScriptTransformationExec(
+    input: Seq[Expression],
+    script: String,
+    output: Seq[Attribute],
+    child: SparkPlan,
+    ioschema: ScriptTransformationIOSchema)
+  extends BaseScriptTransformationExec {
+
+  override def processIterator(
+      inputIterator: Iterator[InternalRow],
+      hadoopConf: Configuration): Iterator[InternalRow] = {
+
+    val (outputStream, proc, inputStream, stderrBuffer) = initProc
+
+    val outputProjection = new InterpretedProjection(inputExpressionsWithoutSerde, child.output)
+
+    // This new thread will consume the ScriptTransformation's input rows and write them to the
+    // external process. That process's output will be read by this current thread.
+    val writerThread = SparkScriptTransformationWriterThread(
+      inputIterator.map(outputProjection),
+      inputExpressionsWithoutSerde.map(_.dataType),
+      ioschema,
+      outputStream,
+      proc,
+      stderrBuffer,
+      TaskContext.get(),
+      hadoopConf
+    )
+
+    val outputIterator =
+      createOutputIteratorWithoutSerde(writerThread, inputStream, proc, stderrBuffer)
+
+    writerThread.start()
+
+    outputIterator
+  }
+}
+
+case class SparkScriptTransformationWriterThread(
+    iter: Iterator[InternalRow],
+    inputSchema: Seq[DataType],
+    ioSchema: ScriptTransformationIOSchema,
+    outputStream: OutputStream,
+    proc: Process,
+    stderrBuffer: CircularBuffer,
+    taskContext: TaskContext,
+    conf: Configuration)
+  extends BaseScriptTransformationWriterThread {
+
+  override def processRows(): Unit = {
+    processRowsWithoutSerde()
+  }
+}
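With the parser change, the planner rule, and this physical operator in
place, a TRANSFORM query can run without Hive support enabled. A hypothetical
end-to-end use — the session setup and table name are illustrative, not from
the commit:

// Illustrative usage only: pipes each input row through /bin/cat and reads
// the script's stdout back as two string columns.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("transform-demo").getOrCreate()
spark.range(5).selectExpr("id AS a", "id * 2 AS b").createOrReplaceTempView("testData")

spark.sql(
  """SELECT TRANSFORM(a, b)
    |USING 'cat' AS (a string, b string)
    |FROM testData
  """.stripMargin).show()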
