Skip to content

Add CASE and IF statement support #1741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 190 additions & 8 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,15 @@ where
DisplaySeparated { slice, sep: ", " }
}

/// Writes the given statements to the formatter, each ending with
/// a semicolon and space separated.
fn format_statement_list(f: &mut fmt::Formatter, statements: &[Statement]) -> fmt::Result {
write!(f, "{}", display_separated(statements, "; "))?;
// We manually insert semicolon for the last statement,
// since display_separated doesn't handle that case.
write!(f, ";")
}

/// An identifier, decomposed into its value or character data and the quote style.
#[derive(Debug, Clone, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down Expand Up @@ -2080,6 +2089,173 @@ pub enum Password {
NullPassword,
}

/// A `CASE` statement.
///
/// Examples:
/// ```sql
/// CASE
/// WHEN EXISTS(SELECT 1)
/// THEN SELECT 1 FROM T;
/// WHEN EXISTS(SELECT 2)
/// THEN SELECT 1 FROM U;
/// ELSE
/// SELECT 1 FROM V;
/// END CASE;
/// ```
///
/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#case_search_expression)
/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/case)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CaseStatement {
pub match_expr: Option<Expr>,
pub when_blocks: Vec<ConditionalStatements>,
pub else_block: Option<Vec<Statement>>,
/// TRUE if the statement ends with `END CASE` (vs `END`).
pub has_end_case: bool,
}

impl fmt::Display for CaseStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let CaseStatement {
match_expr,
when_blocks,
else_block,
has_end_case,
} = self;

write!(f, "CASE")?;

if let Some(expr) = match_expr {
write!(f, " {expr}")?;
}

if !when_blocks.is_empty() {
write!(f, " {}", display_separated(when_blocks, " "))?;
}

if let Some(else_block) = else_block {
write!(f, " ELSE ")?;
format_statement_list(f, else_block)?;
}

write!(f, " END")?;
if *has_end_case {
write!(f, " CASE")?;
}

Ok(())
}
}

/// An `IF` statement.
///
/// Examples:
/// ```sql
/// IF TRUE THEN
/// SELECT 1;
/// SELECT 2;
/// ELSEIF TRUE THEN
/// SELECT 3;
/// ELSE
/// SELECT 4;
/// END IF
/// ```
///
/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if)
/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct IfStatement {
pub if_block: ConditionalStatements,
pub elseif_blocks: Vec<ConditionalStatements>,
pub else_block: Option<Vec<Statement>>,
}

impl fmt::Display for IfStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let IfStatement {
if_block,
elseif_blocks,
else_block,
} = self;

write!(f, "{if_block}")?;

if !elseif_blocks.is_empty() {
write!(f, " {}", display_separated(elseif_blocks, " "))?;
}

if let Some(else_block) = else_block {
write!(f, " ELSE ")?;
format_statement_list(f, else_block)?;
}

write!(f, " END IF")?;

Ok(())
}
}

/// Represents a type of [ConditionalStatements]
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ConditionalStatementKind {
/// `WHEN <condition> THEN <statements>`
When,
/// `IF <condition> THEN <statements>`
If,
/// `ELSEIF <condition> THEN <statements>`
ElseIf,
}

/// A block within a [Statement::Case] or [Statement::If]-like statement
///
/// Examples:
/// ```sql
/// WHEN EXISTS(SELECT 1) THEN SELECT 1;
///
/// IF TRUE THEN SELECT 1; SELECT 2;
/// ```
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ConditionalStatements {
/// The condition expression.
pub condition: Expr,
/// Statement list of the `THEN` clause.
pub statements: Vec<Statement>,
pub kind: ConditionalStatementKind,
}

impl fmt::Display for ConditionalStatements {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let ConditionalStatements {
condition: expr,
statements,
kind,
} = self;

let kind = match kind {
ConditionalStatementKind::When => "WHEN",
ConditionalStatementKind::If => "IF",
ConditionalStatementKind::ElseIf => "ELSEIF",
};

write!(f, "{kind} {expr} THEN")?;

if !statements.is_empty() {
write!(f, " ")?;
format_statement_list(f, statements)?;
}

Ok(())
}
}

/// Represents an expression assignment within a variable `DECLARE` statement.
///
/// Examples:
Expand Down Expand Up @@ -2647,6 +2823,10 @@ pub enum Statement {
file_format: Option<FileFormat>,
source: Box<Query>,
},
/// A `CASE` statement.
Case(CaseStatement),
/// An `IF` statement.
If(IfStatement),
/// ```sql
/// CALL <function>
/// ```
Expand Down Expand Up @@ -3940,6 +4120,12 @@ impl fmt::Display for Statement {
}
Ok(())
}
Statement::Case(stmt) => {
write!(f, "{stmt}")
}
Statement::If(stmt) => {
write!(f, "{stmt}")
}
Statement::AttachDatabase {
schema_name,
database_file_name,
Expand Down Expand Up @@ -4942,18 +5128,14 @@ impl fmt::Display for Statement {
write!(f, " {}", display_comma_separated(modes))?;
}
if !statements.is_empty() {
write!(f, " {}", display_separated(statements, "; "))?;
// We manually insert semicolon for the last statement,
// since display_separated doesn't handle that case.
write!(f, ";")?;
write!(f, " ")?;
format_statement_list(f, statements)?;
}
if let Some(exception_statements) = exception_statements {
write!(f, " EXCEPTION WHEN ERROR THEN")?;
if !exception_statements.is_empty() {
write!(f, " {}", display_separated(exception_statements, "; "))?;
// We manually insert semicolon for the last statement,
// since display_separated doesn't handle that case.
write!(f, ";")?;
write!(f, " ")?;
format_statement_list(f, exception_statements)?;
}
}
if *has_end_keyword {
Expand Down
78 changes: 64 additions & 14 deletions src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,21 @@ use crate::tokenizer::Span;

use super::{
dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation,
AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, CloseCursor,
ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConflictTarget, ConnectBy,
ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte,
Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable,
Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList,
FunctionArguments, GroupByExpr, HavingBound, IlikeSelectItem, Insert, Interpolate,
InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView,
LimitClause, MatchRecognizePattern, Measure, NamedWindowDefinition, ObjectName, ObjectNamePart,
Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition,
PivotValueSource, ProjectionSelect, Query, ReferentialAction, RenameSelectItem,
ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption,
Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint,
TableFactor, TableObject, TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use,
Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill,
AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, CaseStatement,
CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConditionalStatements,
ConflictTarget, ConnectBy, ConstraintCharacteristics, CopySource, CreateIndex, CreateTable,
CreateTableOptions, Cte, Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr,
ExprWithAlias, Fetch, FromTable, Function, FunctionArg, FunctionArgExpr,
FunctionArgumentClause, FunctionArgumentList, FunctionArguments, GroupByExpr, HavingBound,
IfStatement, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join, JoinConstraint,
JoinOperator, JsonPath, JsonPathElem, LateralView, LimitClause, MatchRecognizePattern, Measure,
NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict, OnConflictAction,
OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource, ProjectionSelect,
Query, ReferentialAction, RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select,
SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias,
TableAliasColumnDef, TableConstraint, TableFactor, TableObject, TableOptionsClustered,
TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
WildcardAdditionalOptions, With, WithFill,
};

/// Given an iterator of spans, return the [Span::union] of all spans.
Expand Down Expand Up @@ -334,6 +335,8 @@ impl Spanned for Statement {
file_format: _,
source,
} => source.span(),
Statement::Case(stmt) => stmt.span(),
Statement::If(stmt) => stmt.span(),
Statement::Call(function) => function.span(),
Statement::Copy {
source,
Expand Down Expand Up @@ -732,6 +735,53 @@ impl Spanned for CreateIndex {
}
}

impl Spanned for CaseStatement {
fn span(&self) -> Span {
let CaseStatement {
match_expr,
when_blocks,
else_block,
has_end_case: _,
} = self;

union_spans(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is quite fancy 👌

match_expr
.iter()
.map(|e| e.span())
.chain(when_blocks.iter().map(|b| b.span()))
.chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
)
}
}

impl Spanned for IfStatement {
fn span(&self) -> Span {
let IfStatement {
if_block,
elseif_blocks,
else_block,
} = self;

union_spans(
iter::once(if_block.span())
.chain(elseif_blocks.iter().map(|b| b.span()))
.chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
)
}
}

impl Spanned for ConditionalStatements {
fn span(&self) -> Span {
let ConditionalStatements {
condition,
statements,
kind: _,
} = self;

union_spans(iter::once(condition.span()).chain(statements.iter().map(|s| s.span())))
}
}

/// # partial span
///
/// Missing spans:
Expand Down
1 change: 1 addition & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ define_keywords!(
ELEMENT,
ELEMENTS,
ELSE,
ELSEIF,
EMPTY,
ENABLE,
ENABLE_SCHEMA_EVOLUTION,
Expand Down
Loading