From 8d4278b7194d2fbb3f179154933f0d5fb31450c6 Mon Sep 17 00:00:00 2001
From: ifeanyi
Date: Sun, 9 Feb 2025 12:13:14 +0100
Subject: [PATCH] Add `CASE` and `IF` statement support

Add support for scripting statements

```sql
CASE product_id
    WHEN 1 THEN
        SELECT 1;
    WHEN 2 THEN
        SELECT 2;
    ELSE
        SELECT 3;
END CASE;
```

```sql
IF EXISTS(SELECT 1) THEN
    SELECT 1;
ELSEIF EXISTS(SELECT 2) THEN
    SELECT 2;
ELSE
    SELECT 3;
END IF;
```

[BigQuery CASE](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#case)
[BigQuery IF](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if)
[Snowflake CASE](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/case)
[Snowflake IF](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if)
---
 src/ast/mod.rs            | 198 ++++++++++++++++++++++++++++++++++++--
 src/ast/spans.rs          |  78 ++++++++++++---
 src/keywords.rs           |   1 +
 src/parser/mod.rs         | 104 ++++++++++++++++++++
 tests/sqlparser_common.rs | 114 ++++++++++++++++++++++
 5 files changed, 473 insertions(+), 22 deletions(-)

diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index 139850e86..d976df876 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -151,6 +151,15 @@ where
     DisplaySeparated { slice, sep: ", " }
 }
 
+/// Writes the given statements to the formatter, each ending with
+/// a semicolon and space separated.
+fn format_statement_list(f: &mut fmt::Formatter, statements: &[Statement]) -> fmt::Result {
+    write!(f, "{}", display_separated(statements, "; "))?;
+    // We manually insert semicolon for the last statement,
+    // since display_separated doesn't handle that case.
+    write!(f, ";")
+}
+
 /// An identifier, decomposed into its value or character data and the quote style.
 #[derive(Debug, Clone, PartialOrd, Ord)]
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -2080,6 +2089,173 @@ pub enum Password {
     NullPassword,
 }
 
+/// A `CASE` statement.
+///
+/// Examples:
+/// ```sql
+/// CASE
+///     WHEN EXISTS(SELECT 1)
+///         THEN SELECT 1 FROM T;
+///     WHEN EXISTS(SELECT 2)
+///         THEN SELECT 1 FROM U;
+///     ELSE
+///         SELECT 1 FROM V;
+/// END CASE;
+/// ```
+///
+/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#case_search_expression)
+/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/case)
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub struct CaseStatement {
+    pub match_expr: Option<Expr>,
+    pub when_blocks: Vec<ConditionalStatements>,
+    pub else_block: Option<Vec<Statement>>,
+    /// TRUE if the statement ends with `END CASE` (vs `END`).
+    pub has_end_case: bool,
+}
+
+impl fmt::Display for CaseStatement {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let CaseStatement {
+            match_expr,
+            when_blocks,
+            else_block,
+            has_end_case,
+        } = self;
+
+        write!(f, "CASE")?;
+
+        if let Some(expr) = match_expr {
+            write!(f, " {expr}")?;
+        }
+
+        if !when_blocks.is_empty() {
+            write!(f, " {}", display_separated(when_blocks, " "))?;
+        }
+
+        if let Some(else_block) = else_block {
+            write!(f, " ELSE ")?;
+            format_statement_list(f, else_block)?;
+        }
+
+        write!(f, " END")?;
+        if *has_end_case {
+            write!(f, " CASE")?;
+        }
+
+        Ok(())
+    }
+}
+
+/// An `IF` statement.
+///
+/// Examples:
+/// ```sql
+/// IF TRUE THEN
+///     SELECT 1;
+///     SELECT 2;
+/// ELSEIF TRUE THEN
+///     SELECT 3;
+/// ELSE
+///     SELECT 4;
+/// END IF
+/// ```
+///
+/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if)
+/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if)
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub struct IfStatement {
+    pub if_block: ConditionalStatements,
+    pub elseif_blocks: Vec<ConditionalStatements>,
+    pub else_block: Option<Vec<Statement>>,
+}
+
+impl fmt::Display for IfStatement {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let IfStatement {
+            if_block,
+            elseif_blocks,
+            else_block,
+        } = self;
+
+        write!(f, "{if_block}")?;
+
+        if !elseif_blocks.is_empty() {
+            write!(f, " {}", display_separated(elseif_blocks, " "))?;
+        }
+
+        if let Some(else_block) = else_block {
+            write!(f, " ELSE ")?;
+            format_statement_list(f, else_block)?;
+        }
+
+        write!(f, " END IF")?;
+
+        Ok(())
+    }
+}
+
+/// Represents a type of [ConditionalStatements]
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub enum ConditionalStatementKind {
+    /// `WHEN <condition> THEN <statements>`
+    When,
+    /// `IF <condition> THEN <statements>`
+    If,
+    /// `ELSEIF <condition> THEN <statements>`
+    ElseIf,
+}
+
+/// A block within a [Statement::Case] or [Statement::If]-like statement
+///
+/// Examples:
+/// ```sql
+/// WHEN EXISTS(SELECT 1) THEN SELECT 1;
+///
+/// IF TRUE THEN SELECT 1; SELECT 2;
+/// ```
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub struct ConditionalStatements {
+    /// The condition expression.
+    pub condition: Expr,
+    /// Statement list of the `THEN` clause.
+    pub statements: Vec<Statement>,
+    pub kind: ConditionalStatementKind,
+}
+
+impl fmt::Display for ConditionalStatements {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let ConditionalStatements {
+            condition: expr,
+            statements,
+            kind,
+        } = self;
+
+        let kind = match kind {
+            ConditionalStatementKind::When => "WHEN",
+            ConditionalStatementKind::If => "IF",
+            ConditionalStatementKind::ElseIf => "ELSEIF",
+        };
+
+        write!(f, "{kind} {expr} THEN")?;
+
+        if !statements.is_empty() {
+            write!(f, " ")?;
+            format_statement_list(f, statements)?;
+        }
+
+        Ok(())
+    }
+}
+
 /// Represents an expression assignment within a variable `DECLARE` statement.
 ///
 /// Examples:
@@ -2647,6 +2823,10 @@ pub enum Statement {
         file_format: Option<FileFormat>,
         source: Box<Query>,
     },
+    /// A `CASE` statement.
+    Case(CaseStatement),
+    /// An `IF` statement.
+    If(IfStatement),
     /// ```sql
     /// CALL <function>
     /// ```
@@ -3940,6 +4120,12 @@ impl fmt::Display for Statement {
                 }
                 Ok(())
             }
+            Statement::Case(stmt) => {
+                write!(f, "{stmt}")
+            }
+            Statement::If(stmt) => {
+                write!(f, "{stmt}")
+            }
             Statement::AttachDatabase {
                 schema_name,
                 database_file_name,
@@ -4942,18 +5128,14 @@ impl fmt::Display for Statement {
                     write!(f, " {}", display_comma_separated(modes))?;
                 }
                 if !statements.is_empty() {
-                    write!(f, " {}", display_separated(statements, "; "))?;
-                    // We manually insert semicolon for the last statement,
-                    // since display_separated doesn't handle that case.
-                    write!(f, ";")?;
+                    write!(f, " ")?;
+                    format_statement_list(f, statements)?;
                 }
                 if let Some(exception_statements) = exception_statements {
                     write!(f, " EXCEPTION WHEN ERROR THEN")?;
                     if !exception_statements.is_empty() {
-                        write!(f, " {}", display_separated(exception_statements, "; "))?;
-                        // We manually insert semicolon for the last statement,
-                        // since display_separated doesn't handle that case.
-                        write!(f, ";")?;
+                        write!(f, " ")?;
+                        format_statement_list(f, exception_statements)?;
                     }
                 }
                 if *has_end_keyword {
diff --git a/src/ast/spans.rs b/src/ast/spans.rs
index a4f5eb46c..0ee11f23f 100644
--- a/src/ast/spans.rs
+++ b/src/ast/spans.rs
@@ -22,20 +22,21 @@ use crate::tokenizer::Span;
 
 use super::{
     dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation,
-    AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, CloseCursor,
-    ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConflictTarget, ConnectBy,
-    ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte,
-    Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable,
-    Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList,
-    FunctionArguments, GroupByExpr, HavingBound, IlikeSelectItem, Insert, Interpolate,
-    InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView,
-    LimitClause, MatchRecognizePattern, Measure, NamedWindowDefinition, ObjectName, ObjectNamePart,
-    Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition,
-    PivotValueSource, ProjectionSelect, Query, ReferentialAction, RenameSelectItem,
-    ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption,
-    Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint,
-    TableFactor, TableObject, TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use,
-    Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill,
+    AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, CaseStatement,
+    CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConditionalStatements,
+    ConflictTarget, ConnectBy, ConstraintCharacteristics, CopySource, CreateIndex, CreateTable,
+    CreateTableOptions, Cte, Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr,
+    ExprWithAlias, Fetch, FromTable, Function, FunctionArg, FunctionArgExpr,
+    FunctionArgumentClause, FunctionArgumentList, FunctionArguments, GroupByExpr, HavingBound,
+    IfStatement, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join, JoinConstraint,
+    JoinOperator, JsonPath, JsonPathElem, LateralView, LimitClause, MatchRecognizePattern, Measure,
+    NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict, OnConflictAction,
+    OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource, ProjectionSelect,
+    Query, ReferentialAction, RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select,
+    SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias,
+    TableAliasColumnDef, TableConstraint, TableFactor, TableObject, TableOptionsClustered,
+    TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
+    WildcardAdditionalOptions, With, WithFill,
 };
 
 /// Given an iterator of spans, return the [Span::union] of all spans.
@@ -334,6 +335,8 @@ impl Spanned for Statement {
                 file_format: _,
                 source,
             } => source.span(),
+            Statement::Case(stmt) => stmt.span(),
+            Statement::If(stmt) => stmt.span(),
             Statement::Call(function) => function.span(),
             Statement::Copy {
                 source,
@@ -732,6 +735,53 @@ impl Spanned for CreateIndex {
     }
 }
 
+impl Spanned for CaseStatement {
+    fn span(&self) -> Span {
+        let CaseStatement {
+            match_expr,
+            when_blocks,
+            else_block,
+            has_end_case: _,
+        } = self;
+
+        union_spans(
+            match_expr
+                .iter()
+                .map(|e| e.span())
+                .chain(when_blocks.iter().map(|b| b.span()))
+                .chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
+        )
+    }
+}
+
+impl Spanned for IfStatement {
+    fn span(&self) -> Span {
+        let IfStatement {
+            if_block,
+            elseif_blocks,
+            else_block,
+        } = self;
+
+        union_spans(
+            iter::once(if_block.span())
+                .chain(elseif_blocks.iter().map(|b| b.span()))
+                .chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
+        )
+    }
+}
+
+impl Spanned for ConditionalStatements {
+    fn span(&self) -> Span {
+        let ConditionalStatements {
+            condition,
+            statements,
+            kind: _,
+        } = self;
+
+        union_spans(iter::once(condition.span()).chain(statements.iter().map(|s| s.span())))
+    }
+}
+
 /// # partial span
 ///
 /// Missing spans:
diff --git a/src/keywords.rs b/src/keywords.rs
index 195bbb172..47da10096 100644
--- a/src/keywords.rs
+++ b/src/keywords.rs
@@ -297,6 +297,7 @@ define_keywords!(
     ELEMENT,
     ELEMENTS,
     ELSE,
+    ELSEIF,
     EMPTY,
     ENABLE,
     ENABLE_SCHEMA_EVOLUTION,
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index d3c48a6e7..a46ad2316 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -528,6 +528,14 @@ impl<'a> Parser<'a> {
                 Keyword::DESCRIBE => self.parse_explain(DescribeAlias::Describe),
                 Keyword::EXPLAIN => self.parse_explain(DescribeAlias::Explain),
                 Keyword::ANALYZE => self.parse_analyze(),
+                Keyword::CASE => {
+                    self.prev_token();
+                    self.parse_case_stmt()
+                }
+                Keyword::IF => {
+                    self.prev_token();
+                    self.parse_if_stmt()
+                }
                 Keyword::SELECT | Keyword::WITH | Keyword::VALUES | Keyword::FROM => {
                     self.prev_token();
                     self.parse_query().map(Statement::Query)
@@ -615,6 +623,102 @@ impl<'a> Parser<'a> {
         }
     }
 
+    /// Parse a `CASE` statement.
+    ///
+    /// See [Statement::Case]
+    pub fn parse_case_stmt(&mut self) -> Result<Statement, ParserError> {
+        self.expect_keyword_is(Keyword::CASE)?;
+
+        let match_expr = if self.peek_keyword(Keyword::WHEN) {
+            None
+        } else {
+            Some(self.parse_expr()?)
+        };
+
+        self.expect_keyword_is(Keyword::WHEN)?;
+        let when_blocks = self.parse_keyword_separated(Keyword::WHEN, |parser| {
+            parser.parse_conditional_statements(
+                ConditionalStatementKind::When,
+                &[Keyword::WHEN, Keyword::ELSE, Keyword::END],
+            )
+        })?;
+
+        let else_block = if self.parse_keyword(Keyword::ELSE) {
+            Some(self.parse_statement_list(&[Keyword::END])?)
+        } else {
+            None
+        };
+
+        self.expect_keyword_is(Keyword::END)?;
+        let has_end_case = self.parse_keyword(Keyword::CASE);
+
+        Ok(Statement::Case(CaseStatement {
+            match_expr,
+            when_blocks,
+            else_block,
+            has_end_case,
+        }))
+    }
+
+    /// Parse an `IF` statement.
+    ///
+    /// See [Statement::If]
+    pub fn parse_if_stmt(&mut self) -> Result<Statement, ParserError> {
+        self.expect_keyword_is(Keyword::IF)?;
+        let if_block = self.parse_conditional_statements(
+            ConditionalStatementKind::If,
+            &[Keyword::ELSE, Keyword::ELSEIF, Keyword::END],
+        )?;
+
+        let elseif_blocks = if self.parse_keyword(Keyword::ELSEIF) {
+            self.parse_keyword_separated(Keyword::ELSEIF, |parser| {
+                parser.parse_conditional_statements(
+                    ConditionalStatementKind::ElseIf,
+                    &[Keyword::ELSEIF, Keyword::ELSE, Keyword::END],
+                )
+            })?
+        } else {
+            vec![]
+        };
+
+        let else_block = if self.parse_keyword(Keyword::ELSE) {
+            Some(self.parse_statement_list(&[Keyword::END])?)
+        } else {
+            None
+        };
+
+        self.expect_keywords(&[Keyword::END, Keyword::IF])?;
+
+        Ok(Statement::If(IfStatement {
+            if_block,
+            elseif_blocks,
+            else_block,
+        }))
+    }
+
+    /// Parses an expression and associated list of statements
+    /// belonging to a conditional statement like `IF` or `WHEN`.
+    ///
+    /// Example:
+    /// ```sql
+    /// IF condition THEN statement1; statement2;
+    /// ```
+    fn parse_conditional_statements(
+        &mut self,
+        kind: ConditionalStatementKind,
+        terminal_keywords: &[Keyword],
+    ) -> Result<ConditionalStatements, ParserError> {
+        let condition = self.parse_expr()?;
+        self.expect_keyword_is(Keyword::THEN)?;
+        let statements = self.parse_statement_list(terminal_keywords)?;
+
+        Ok(ConditionalStatements {
+            condition,
+            statements,
+            kind,
+        })
+    }
+
     pub fn parse_comment(&mut self) -> Result<Statement, ParserError> {
         let if_exists = self.parse_keywords(&[Keyword::IF, Keyword::EXISTS]);
 
diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs
index b5d42ea6c..905b6d392 100644
--- a/tests/sqlparser_common.rs
+++ b/tests/sqlparser_common.rs
@@ -14176,6 +14176,120 @@ fn test_visit_order() {
     );
 }
 
+#[test]
+fn parse_case_statement() {
+    let sql = "CASE 1 WHEN 2 THEN SELECT 1; SELECT 2; ELSE SELECT 3; END CASE";
+    let Statement::Case(stmt) = verified_stmt(sql) else {
+        unreachable!()
+    };
+
+    assert_eq!(Some(Expr::value(number("1"))), stmt.match_expr);
+    assert_eq!(Expr::value(number("2")), stmt.when_blocks[0].condition);
+    assert_eq!(2, stmt.when_blocks[0].statements.len());
+    assert_eq!(1, stmt.else_block.unwrap().len());
+
+    verified_stmt(concat!(
+        "CASE 1",
+        " WHEN a THEN",
+        " SELECT 1; SELECT 2; SELECT 3;",
+        " WHEN b THEN",
+        " SELECT 4; SELECT 5;",
+        " ELSE",
+        " SELECT 7; SELECT 8;",
+        " END CASE"
+    ));
+    verified_stmt(concat!(
+        "CASE 1",
+        " WHEN a THEN",
+        " SELECT 1; SELECT 2; SELECT 3;",
+        " WHEN b THEN",
+        " SELECT 4; SELECT 5;",
+        " END CASE"
+    ));
+    verified_stmt(concat!(
+        "CASE 1",
+        " WHEN a THEN",
+        " SELECT 1; SELECT 2; SELECT 3;",
+        " END CASE"
+    ));
+    verified_stmt(concat!(
+        "CASE 1",
+        " WHEN a THEN",
+        " SELECT 1; SELECT 2; SELECT 3;",
+        " END"
+    ));
+
+    assert_eq!(
+        ParserError::ParserError("Expected: THEN, found: END".to_string()),
+        parse_sql_statements("CASE 1 WHEN a END").unwrap_err()
+    );
+    assert_eq!(
+        ParserError::ParserError("Expected: WHEN, found: ELSE".to_string()),
+        parse_sql_statements("CASE 1 ELSE SELECT 1; END").unwrap_err()
+    );
+}
+
+#[test]
+fn parse_if_statement() {
+    let sql = "IF 1 THEN SELECT 1; ELSEIF 2 THEN SELECT 2; ELSE SELECT 3; END IF";
+    let Statement::If(stmt) = verified_stmt(sql) else {
+        unreachable!()
+    };
+    assert_eq!(Expr::value(number("1")), stmt.if_block.condition);
+    assert_eq!(Expr::value(number("2")), stmt.elseif_blocks[0].condition);
+    assert_eq!(1, stmt.else_block.unwrap().len());
+
+    verified_stmt(concat!(
+        "IF 1 THEN",
+        " SELECT 1;",
+        " SELECT 2;",
+        " SELECT 3;",
+        " ELSEIF 2 THEN",
+        " SELECT 4;",
+        " SELECT 5;",
+        " ELSEIF 3 THEN",
+        " SELECT 6;",
+        " SELECT 7;",
+        " ELSE",
+        " SELECT 8;",
+        " SELECT 9;",
+        " END IF"
+    ));
+    verified_stmt(concat!(
+        "IF 1 THEN",
+        " SELECT 1;",
+        " SELECT 2;",
+        " ELSE",
+        " SELECT 3;",
+        " SELECT 4;",
+        " END IF"
+    ));
+    verified_stmt(concat!(
+        "IF 1 THEN",
+        " SELECT 1;",
+        " SELECT 2;",
+        " SELECT 3;",
+        " ELSEIF 2 THEN",
+        " SELECT 3;",
+        " SELECT 4;",
+        " END IF"
+    ));
+    verified_stmt(concat!("IF 1 THEN", " SELECT 1;", " SELECT 2;", " END IF"));
+    verified_stmt(concat!(
+        "IF (1) THEN",
+        " SELECT 1;",
+        " SELECT 2;",
+        " END IF"
+    ));
+    verified_stmt("IF 1 THEN END IF");
+    verified_stmt("IF 1 THEN SELECT 1; ELSEIF 1 THEN END IF");
+
+    assert_eq!(
+        ParserError::ParserError("Expected: IF, found: EOF".to_string()),
+        parse_sql_statements("IF 1 THEN SELECT 1; ELSEIF 1 THEN SELECT 2; END").unwrap_err()
+    );
+}
+
 #[test]
 fn test_lambdas() {
     let dialects = all_dialects_where(|d| d.supports_lambda_functions());