@@ -48,8 +48,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, con
4848import org .apache .spark .sql .connector .catalog .{CatalogV2Util , SupportsNamespaces , TableCatalog }
4949import org .apache .spark .sql .connector .catalog .TableChange .ColumnPosition
5050import org .apache .spark .sql .connector .expressions .{ApplyTransform , BucketTransform , DaysTransform , Expression => V2Expression , FieldReference , HoursTransform , IdentityTransform , LiteralValue , MonthsTransform , Transform , YearsTransform }
51- import org .apache .spark .sql .errors .{QueryCompilationErrors , QueryParsingErrors , SqlScriptingErrors }
52- import org .apache .spark .sql .errors .DataTypeErrors .toSQLStmt
51+ import org .apache .spark .sql .errors .{DataTypeErrorsBase , QueryCompilationErrors , QueryParsingErrors , SqlScriptingErrors }
5352import org .apache .spark .sql .internal .SQLConf
5453import org .apache .spark .sql .internal .SQLConf .LEGACY_BANG_EQUALS_NOT
5554import org .apache .spark .sql .types ._
@@ -62,7 +61,8 @@ import org.apache.spark.util.random.RandomSampler
6261 * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
6362 * TableIdentifier.
6463 */
65- class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
64+ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
65+ with Logging with DataTypeErrorsBase {
6666 import org .apache .spark .sql .connector .catalog .CatalogV2Implicits ._
6767 import ParserUtils ._
6868
@@ -142,37 +142,31 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
142142
143143 val compoundStatements = buff.toList
144144
145- if (allowVarDeclare) {
146- val declareVarStatement = compoundStatements
147- .dropWhile(statement => statement.isInstanceOf [SingleStatement ] &&
148- statement.asInstanceOf [SingleStatement ].parsedPlan.isInstanceOf [CreateVariable ])
149- .filter(_.isInstanceOf [SingleStatement ])
150- .find(_.asInstanceOf [SingleStatement ].parsedPlan.isInstanceOf [CreateVariable ])
151-
152- declareVarStatement match {
153- case Some (SingleStatement (parsedPlan)) =>
154- throw SqlScriptingErrors .variableDeclarationOnlyAtBeginning(
155- parsedPlan.asInstanceOf [CreateVariable ]
156- .name.asInstanceOf [UnresolvedIdentifier ]
157- .nameParts.last,
158- parsedPlan.origin.line.get.toString)
159- case _ =>
145+ val candidates = if (allowVarDeclare) {
146+ compoundStatements.dropWhile {
147+ case SingleStatement (_ : CreateVariable ) => true
148+ case _ => false
160149 }
161-
162150 } else {
163- val declareVarStatement = compoundStatements
164- .filter(_.isInstanceOf [SingleStatement ])
165- .find(_.asInstanceOf [SingleStatement ].parsedPlan.isInstanceOf [CreateVariable ])
151+ compoundStatements
152+ }
153+
154+ val declareVarStatement = candidates.collectFirst {
155+ case SingleStatement (c : CreateVariable ) => c
156+ }
166157
167- declareVarStatement match {
168- case Some (SingleStatement (parsedPlan)) =>
158+ declareVarStatement match {
159+ case Some (c : CreateVariable ) =>
160+ if (allowVarDeclare) {
169161 throw SqlScriptingErrors .variableDeclarationOnlyAtBeginning(
170- parsedPlan.asInstanceOf [CreateVariable ]
171- .name.asInstanceOf [UnresolvedIdentifier ]
172- .nameParts.last,
173- parsedPlan.origin.line.get.toString)
174- case _ =>
175- }
162+ toSQLId(c.name.asInstanceOf [UnresolvedIdentifier ].nameParts),
163+ c.origin.line.get.toString)
164+ } else {
165+ throw SqlScriptingErrors .variableDeclarationNotAllowedInScope(
166+ toSQLId(c.name.asInstanceOf [UnresolvedIdentifier ].nameParts),
167+ c.origin.line.get.toString)
168+ }
169+ case _ =>
176170 }
177171
178172 CompoundBody (buff.toSeq, label)
0 commit comments