
Commit 83a9791

davidm-db authored and cloud-fan committed
[SPARK-48388][SQL] Fix SET statement behavior for SQL Scripts
### What changes were proposed in this pull request?

The `SET` statement is used to set config values, and it has a poorly designed grammar rule `#setConfiguration` that matches everything after `SET` (`SET .*?`). This conflicts with the use of `SET` for setting session variables, so the `SET (VAR | VARIABLE)` grammar rule had to be introduced to distinguish setting config values from setting session variables ([SET VAR pull request](#40474)). However, this is not SQL-standard, so for SQL scripting ([JIRA](https://issues.apache.org/jira/browse/SPARK-48338)) we are opting to disable `SET` for configs and use it only for session variables. This enables us to use plain `SET` for assigning values to session variables. Config values can still be set from SQL scripts using `EXECUTE IMMEDIATE`.

This change simply reorders grammar rules to achieve the behavior above, and alters only the visitor functions where a rule was renamed or a completely new rule was added.

### Why are the changes needed?

These changes resolve the issues the poorly designed `SET` statement causes for SQL scripts.

### Does this PR introduce _any_ user-facing change?

No. This PR is part of a series of PRs that will introduce changes to the sql() API to add support for SQL scripting, but for now the API remains unchanged. In the future the API will remain the same as well, but it will gain the ability to execute SQL scripts.

### How was this patch tested?

Existing tests should cover the changes. New tests for SQL scripts were added to:
- `SqlScriptingParserSuite`
- `SqlScriptingInterpreterSuite`

### Was this patch authored or co-authored using generative AI tooling?

Closes #47272 from davidm-db/sql_scripting_set_statement.

Authored-by: David Milicevic <david.milicevic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
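The behavior described above can be illustrated with a short SQL script (an illustrative sketch; the variable name and config key are made up):

```sql
BEGIN
  DECLARE total INT DEFAULT 0;
  -- Inside a script, plain SET now assigns the session variable,
  -- exactly like SET VAR / SET VARIABLE:
  SET total = total + 1;
  -- Config values can no longer be set with plain SET inside a script;
  -- use EXECUTE IMMEDIATE instead:
  EXECUTE IMMEDIATE 'SET spark.sql.ansi.enabled=true';
END
```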
1 parent 1a428c1 commit 83a9791

File tree

7 files changed: +143 −45 lines changed


sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4

Lines changed: 21 additions & 11 deletions
```diff
@@ -61,11 +61,18 @@ compoundBody
 
 compoundStatement
     : statement
+    | setStatementWithOptionalVarKeyword
     | beginEndCompoundBlock
     ;
 
+setStatementWithOptionalVarKeyword
+    : SET (VARIABLE | VAR)? assignmentList                             #setVariableWithOptionalKeyword
+    | SET (VARIABLE | VAR)? LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ
+        LEFT_PAREN query RIGHT_PAREN                                   #setVariableWithOptionalKeyword
+    ;
+
 singleStatement
-    : statement SEMICOLON* EOF
+    : (statement|setResetStatement) SEMICOLON* EOF
     ;
 
 beginLabel
@@ -212,7 +219,7 @@ statement
       identifierReference dataType? variableDefaultExpression?         #createVariable
     | DROP TEMPORARY VARIABLE (IF EXISTS)? identifierReference         #dropVariable
     | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)?
-        statement                                                      #explain
+        (statement|setResetStatement)                                  #explain
     | SHOW TABLES ((FROM | IN) identifierReference)?
        (LIKE? pattern=stringLit)?                                      #showTables
     | SHOW TABLE EXTENDED ((FROM | IN) ns=identifierReference)?
@@ -251,26 +258,29 @@ statement
     | (MSCK)? REPAIR TABLE identifierReference
        (option=(ADD|DROP|SYNC) PARTITIONS)?                            #repairTable
     | op=(ADD | LIST) identifier .*?                                   #manageResource
-    | SET COLLATION collationName=identifier                           #setCollation
-    | SET ROLE .*?                                                     #failNativeCommand
+    | CREATE INDEX (IF errorCapturingNot EXISTS)? identifier ON TABLE?
+        identifierReference (USING indexType=identifier)?
+        LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN
+        (OPTIONS options=propertyList)?                                #createIndex
+    | DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex
+    | unsupportedHiveNativeCommands .*?                                #failNativeCommand
+    ;
+
+setResetStatement
+    : SET COLLATION collationName=identifier                           #setCollation
+    | SET ROLE .*?                                                     #failSetRole
     | SET TIME ZONE interval                                           #setTimeZone
     | SET TIME ZONE timezone                                           #setTimeZone
     | SET TIME ZONE .*?                                                #setTimeZone
     | SET (VARIABLE | VAR) assignmentList                              #setVariable
     | SET (VARIABLE | VAR) LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ
-        LEFT_PAREN query RIGHT_PAREN                                   #setVariable
+        LEFT_PAREN query RIGHT_PAREN                                   #setVariable
     | SET configKey EQ configValue                                     #setQuotedConfiguration
     | SET configKey (EQ .*?)?                                          #setConfiguration
     | SET .*? EQ configValue                                           #setQuotedConfiguration
     | SET .*?                                                          #setConfiguration
     | RESET configKey                                                  #resetQuotedConfiguration
     | RESET .*?                                                        #resetConfiguration
-    | CREATE INDEX (IF errorCapturingNot EXISTS)? identifier ON TABLE?
-        identifierReference (USING indexType=identifier)?
-        LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN
-        (OPTIONS options=propertyList)?                                #createIndex
-    | DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex
-    | unsupportedHiveNativeCommands .*?                                #failNativeCommand
     ;
 
 executeImmediate
```
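The split between `compoundStatement` and `setResetStatement` means the same text parses differently in and out of a script. A sketch of which alternative each form matches (identifiers and the config key are hypothetical):

```sql
-- Inside a compound body, all of these match setStatementWithOptionalVarKeyword:
SET VAR v1 = 1, v2 = 2;
SET VARIABLE v1 = 1;
SET v1 = 1;
SET (v1, v2) = (SELECT a, b FROM t);

-- Outside scripts (setResetStatement), the old behavior is preserved:
SET spark.sql.shuffle.partitions = 10;  -- #setConfiguration / #setQuotedConfiguration
SET VAR v1 = 1;                         -- #setVariable (keyword still required)
```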

sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -402,7 +402,7 @@ case class UnclosedCommentProcessor(
   override def exitSingleStatement(ctx: SqlBaseParser.SingleStatementContext): Unit = {
     // SET command uses a wildcard to match anything, and we shouldn't parse the comments, e.g.
     // `SET myPath =/a/*`.
-    if (!ctx.statement().isInstanceOf[SqlBaseParser.SetConfigurationContext]) {
+    if (!ctx.setResetStatement().isInstanceOf[SqlBaseParser.SetConfigurationContext]) {
       checkUnclosedComment(tokenStream, command)
     }
   }
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 38 additions & 20 deletions
```diff
@@ -170,15 +170,20 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
 
   override def visitCompoundStatement(ctx: CompoundStatementContext): CompoundPlanStatement =
     withOrigin(ctx) {
-      Option(ctx.statement()).map { s =>
-        SingleStatement(parsedPlan = visit(s).asInstanceOf[LogicalPlan])
-      }.getOrElse {
-        visit(ctx.beginEndCompoundBlock()).asInstanceOf[CompoundPlanStatement]
-      }
+      Option(ctx.statement().asInstanceOf[ParserRuleContext])
+        .orElse(Option(ctx.setStatementWithOptionalVarKeyword().asInstanceOf[ParserRuleContext]))
+        .map { s =>
+          SingleStatement(parsedPlan = visit(s).asInstanceOf[LogicalPlan])
+        }.getOrElse {
+          visit(ctx.beginEndCompoundBlock()).asInstanceOf[CompoundPlanStatement]
+        }
   }
 
   override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
-    visit(ctx.statement).asInstanceOf[LogicalPlan]
+    Option(ctx.statement().asInstanceOf[ParserRuleContext])
+      .orElse(Option(ctx.setResetStatement().asInstanceOf[ParserRuleContext]))
+      .map { s => visit(s).asInstanceOf[LogicalPlan] }
+      .get
   }
 
   override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) {
@@ -5461,26 +5466,20 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
     )
   }
 
-  /**
-   * Create a [[SetVariable]] command.
-   *
-   * For example:
-   * {{{
-   *   SET VARIABLE var1 = v1, var2 = v2, ...
-   *   SET VARIABLE (var1, var2, ...) = (SELECT ...)
-   * }}}
-   */
-  override def visitSetVariable(ctx: SetVariableContext): LogicalPlan = withOrigin(ctx) {
-    if (ctx.query() != null) {
+  private def visitSetVariableImpl(
+      query: QueryContext,
+      multipartIdentifierList: MultipartIdentifierListContext,
+      assignmentList: AssignmentListContext): LogicalPlan = {
+    if (query != null) {
       // The SET variable source is a query
-      val variables = ctx.multipartIdentifierList.multipartIdentifier.asScala.map { variableIdent =>
+      val variables = multipartIdentifierList.multipartIdentifier.asScala.map { variableIdent =>
         val varName = visitMultipartIdentifier(variableIdent)
         UnresolvedAttribute(varName)
       }.toSeq
-      SetVariable(variables, visitQuery(ctx.query()))
+      SetVariable(variables, visitQuery(query))
     } else {
       // The SET variable source is list of expressions.
-      val (variables, values) = ctx.assignmentList().assignment().asScala.map { assign =>
+      val (variables, values) = assignmentList.assignment().asScala.map { assign =>
         val varIdent = visitMultipartIdentifier(assign.key)
         val varExpr = expression(assign.value)
         val varNamedExpr = varExpr match {
@@ -5492,4 +5491,23 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
     SetVariable(variables, Project(values, OneRowRelation()))
   }
+
+  /**
+   * Create a [[SetVariable]] command.
+   *
+   * For example:
+   * {{{
+   *   SET VARIABLE var1 = v1, var2 = v2, ...
+   *   SET VARIABLE (var1, var2, ...) = (SELECT ...)
+   * }}}
+   */
+  override def visitSetVariable(ctx: SetVariableContext): LogicalPlan = withOrigin(ctx) {
+    visitSetVariableImpl(ctx.query(), ctx.multipartIdentifierList(), ctx.assignmentList())
+  }
+
+  override def visitSetVariableWithOptionalKeyword(
+      ctx: SetVariableWithOptionalKeywordContext): LogicalPlan =
+    withOrigin(ctx) {
+      visitSetVariableImpl(ctx.query(), ctx.multipartIdentifierList(), ctx.assignmentList())
+    }
 }
```
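Both visitor overrides in AstBuilder delegate to the shared `visitSetVariableImpl`, which handles the two source forms of a variable assignment. Sketched in SQL (variable and table names are made up):

```sql
-- Source is a list of expressions (the assignmentList branch):
SET VARIABLE var1 = 1, var2 = 'a';
-- Source is a query (the multipartIdentifierList + query branch):
SET VARIABLE (var1, var2) = (SELECT c1, c2 FROM t WHERE id = 1);
```

The same two branches apply to the optional-keyword form inside scripts, since `visitSetVariableWithOptionalKeyword` passes the identical contexts.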

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -29,7 +29,7 @@ class ParserUtilsSuite extends SparkFunSuite {
   import ParserUtils._
 
   val setConfContext = buildContext("set example.setting.name=setting.value") { parser =>
-    parser.statement().asInstanceOf[SetConfigurationContext]
+    parser.setResetStatement().asInstanceOf[SetConfigurationContext]
   }
 
   val showFuncContext = buildContext("show functions foo.bar") { parser =>
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala

Lines changed: 53 additions & 0 deletions
```diff
@@ -263,6 +263,59 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
     assert(tree.label.nonEmpty)
   }
 
+  test("SET VAR statement test") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        | DECLARE totalInsCnt = 0;
+        | SET VAR totalInsCnt = (SELECT x FROM y WHERE id = 1);
+        |END""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 2)
+    assert(tree.collection.head.isInstanceOf[SingleStatement])
+    assert(tree.collection(1).isInstanceOf[SingleStatement])
+  }
+
+  test("SET VARIABLE statement test") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        | DECLARE totalInsCnt = 0;
+        | SET VARIABLE totalInsCnt = (SELECT x FROM y WHERE id = 1);
+        |END""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 2)
+    assert(tree.collection.head.isInstanceOf[SingleStatement])
+    assert(tree.collection(1).isInstanceOf[SingleStatement])
+  }
+
+  test("SET statement test") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        | DECLARE totalInsCnt = 0;
+        | SET totalInsCnt = (SELECT x FROM y WHERE id = 1);
+        |END""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 2)
+    assert(tree.collection.head.isInstanceOf[SingleStatement])
+    assert(tree.collection(1).isInstanceOf[SingleStatement])
+  }
+
+  test("SET statement test - should fail") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        | DECLARE totalInsCnt = 0;
+        | SET totalInsCnt = (SELECT x FROMERROR y WHERE id = 1);
+        |END""".stripMargin
+    val e = intercept[ParseException] {
+      parseScript(sqlScriptText)
+    }
+    assert(e.getErrorClass === "PARSE_SYNTAX_ERROR")
+    assert(e.getMessage.contains("Syntax error"))
+  }
+
   // Helper methods
   def cleanupStatementString(statementStr: String): String = {
     statementStr
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala

Lines changed: 10 additions & 11 deletions
```diff
@@ -238,7 +238,7 @@ class SparkSqlAstBuilder extends AstBuilder {
       invalidStatement("EXPLAIN LOGICAL", ctx)
     }
 
-    val statement = plan(ctx.statement)
+    val statement = plan(Option(ctx.statement()).getOrElse(ctx.setResetStatement()))
     if (statement == null) {
       null  // This is enough since ParseException will raise later.
     } else {
@@ -399,21 +399,20 @@ class SparkSqlAstBuilder extends AstBuilder {
   }
 
   /**
-   * Fail an unsupported Hive native command.
+   * Fail an unsupported Hive native command (SET ROLE is handled separately).
    */
   override def visitFailNativeCommand(
-    ctx: FailNativeCommandContext): LogicalPlan = withOrigin(ctx) {
-    val keywords = if (ctx.unsupportedHiveNativeCommands != null) {
-      ctx.unsupportedHiveNativeCommands.children.asScala.collect {
-        case n: TerminalNode => n.getText
-      }.mkString(" ")
-    } else {
-      // SET ROLE is the exception to the rule, because we handle this before other SET commands.
-      "SET ROLE"
-    }
+      ctx: FailNativeCommandContext): LogicalPlan = withOrigin(ctx) {
+    val keywords = ctx.unsupportedHiveNativeCommands.children.asScala.collect {
+      case n: TerminalNode => n.getText
+    }.mkString(" ")
     invalidStatement(keywords, ctx)
   }
 
+  override def visitFailSetRole(ctx: FailSetRoleContext): LogicalPlan = withOrigin(ctx) {
+    invalidStatement("SET ROLE", ctx);
+  }
+
   /**
    * Create a [[AddFilesCommand]], [[AddJarsCommand]], [[AddArchivesCommand]],
    * [[ListFilesCommand]], [[ListJarsCommand]] or [[ListArchivesCommand]]
```

sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala

Lines changed: 19 additions & 1 deletion
```diff
@@ -98,7 +98,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
     }
   }
 
-  test("session vars - set and read") {
+  test("session vars - set and read (SET VAR)") {
     val sqlScript =
       """
        |BEGIN
@@ -116,6 +116,24 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
     verifySqlScriptResult(sqlScript, expected)
   }
 
+  test("session vars - set and read (SET)") {
+    val sqlScript =
+      """
+        |BEGIN
+        |DECLARE var = 1;
+        |SET var = var + 1;
+        |SELECT var;
+        |END
+        |""".stripMargin
+    val expected = Seq(
+      Seq.empty[Row],  // declare var
+      Seq.empty[Row],  // set var
+      Seq(Row(2)),     // select
+      Seq.empty[Row]   // drop var
+    )
+    verifySqlScriptResult(sqlScript, expected)
+  }
+
   test("session vars - set and read scoped") {
     val sqlScript =
       """
```

0 commit comments