Skip to content

Allow stored procedures to be defined without BEGIN/END #1834

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
57 changes: 42 additions & 15 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2380,11 +2380,16 @@ impl fmt::Display for BeginEndStatements {
end_token: AttachedToken(end_token),
} = self;

write!(f, "{begin_token} ")?;
if begin_token.token != Token::EOF {
write!(f, "{begin_token} ")?;
}
if !statements.is_empty() {
format_statement_list(f, statements)?;
}
write!(f, " {end_token}")
if end_token.token != Token::EOF {
write!(f, " {end_token}")?;
}
Ok(())
}
}

Expand Down Expand Up @@ -3729,7 +3734,12 @@ pub enum Statement {
/// ```
///
/// Postgres: <https://www.postgresql.org/docs/current/sql-createtrigger.html>
/// SQL Server: <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql>
CreateTrigger {
/// True if this is a `CREATE OR ALTER TRIGGER` statement
///
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql?view=sql-server-ver16#arguments)
or_alter: bool,
/// The `OR REPLACE` clause is used to re-create the trigger if it already exists.
///
/// Example:
Expand Down Expand Up @@ -3790,7 +3800,9 @@ pub enum Statement {
/// Triggering conditions
condition: Option<Expr>,
/// Execute logic block
exec_body: TriggerExecBody,
exec_body: Option<TriggerExecBody>,
/// For SQL dialects with statement(s) for a body
statements: Option<ConditionalStatements>,
/// The characteristic of the trigger, which include whether the trigger is `DEFERRABLE`, `INITIALLY DEFERRED`, or `INITIALLY IMMEDIATE`,
characteristics: Option<ConstraintCharacteristics>,
},
Expand All @@ -3814,7 +3826,7 @@ pub enum Statement {
or_alter: bool,
name: ObjectName,
params: Option<Vec<ProcedureParam>>,
body: Vec<Statement>,
body: ConditionalStatements,
},
/// ```sql
/// CREATE MACRO
Expand Down Expand Up @@ -4587,6 +4599,7 @@ impl fmt::Display for Statement {
}
Statement::CreateFunction(create_function) => create_function.fmt(f),
Statement::CreateTrigger {
or_alter,
or_replace,
is_constraint,
name,
Expand All @@ -4599,19 +4612,30 @@ impl fmt::Display for Statement {
condition,
include_each,
exec_body,
statements,
characteristics,
} => {
write!(
f,
"CREATE {or_replace}{is_constraint}TRIGGER {name} {period}",
"CREATE {or_alter}{or_replace}{is_constraint}TRIGGER {name} ",
or_alter = if *or_alter { "OR ALTER " } else { "" },
or_replace = if *or_replace { "OR REPLACE " } else { "" },
is_constraint = if *is_constraint { "CONSTRAINT " } else { "" },
)?;

if !events.is_empty() {
write!(f, " {}", display_separated(events, " OR "))?;
if exec_body.is_some() {
write!(f, "{period}")?;
if !events.is_empty() {
write!(f, " {}", display_separated(events, " OR "))?;
}
write!(f, " ON {table_name}")?;
} else {
write!(f, "ON {table_name}")?;
write!(f, " {period}")?;
if !events.is_empty() {
write!(f, " {}", display_separated(events, ", "))?;
}
}
write!(f, " ON {table_name}")?;

if let Some(referenced_table_name) = referenced_table_name {
write!(f, " FROM {referenced_table_name}")?;
Expand All @@ -4627,13 +4651,19 @@ impl fmt::Display for Statement {

if *include_each {
write!(f, " FOR EACH {trigger_object}")?;
} else {
} else if exec_body.is_some() {
write!(f, " FOR {trigger_object}")?;
}
if let Some(condition) = condition {
write!(f, " WHEN {condition}")?;
}
write!(f, " EXECUTE {exec_body}")
if let Some(exec_body) = exec_body {
write!(f, " EXECUTE {exec_body}")?;
}
if let Some(statements) = statements {
write!(f, " AS {statements}")?;
}
Ok(())
}
Statement::DropTrigger {
if_exists,
Expand Down Expand Up @@ -4672,11 +4702,8 @@ impl fmt::Display for Statement {
write!(f, " ({})", display_comma_separated(p))?;
}
}
write!(
f,
" AS BEGIN {body} END",
body = display_separated(body, "; ")
)

write!(f, " AS {body}")
}
Statement::CreateMacro {
or_replace,
Expand Down
2 changes: 2 additions & 0 deletions src/ast/trigger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ impl fmt::Display for TriggerEvent {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TriggerPeriod {
For,
After,
Before,
InsteadOf,
Expand All @@ -118,6 +119,7 @@ pub enum TriggerPeriod {
impl fmt::Display for TriggerPeriod {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
TriggerPeriod::For => write!(f, "FOR"),
TriggerPeriod::After => write!(f, "AFTER"),
TriggerPeriod::Before => write!(f, "BEFORE"),
TriggerPeriod::InsteadOf => write!(f, "INSTEAD OF"),
Expand Down
46 changes: 46 additions & 0 deletions src/dialect/mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use crate::ast::helpers::attached_token::AttachedToken;
use crate::ast::{
BeginEndStatements, ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement,
TriggerObject,
};
use crate::dialect::Dialect;
use crate::keywords::{self, Keyword};
Expand Down Expand Up @@ -125,6 +126,15 @@ impl Dialect for MsSqlDialect {
fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.peek_keyword(Keyword::IF) {
Some(self.parse_if_stmt(parser))
} else if parser.parse_keywords(&[Keyword::CREATE, Keyword::TRIGGER]) {
Some(self.parse_create_trigger(parser, false))
} else if parser.parse_keywords(&[
Keyword::CREATE,
Keyword::OR,
Keyword::ALTER,
Keyword::TRIGGER,
]) {
Some(self.parse_create_trigger(parser, true))
} else {
None
}
Expand Down Expand Up @@ -215,6 +225,42 @@ impl MsSqlDialect {
}))
}

/// Parse `CREATE TRIGGER` for [MsSql]
///
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql
fn parse_create_trigger(
&self,
parser: &mut Parser,
or_alter: bool,
) -> Result<Statement, ParserError> {
let name = parser.parse_object_name(false)?;
parser.expect_keyword_is(Keyword::ON)?;
let table_name = parser.parse_object_name(false)?;
let period = parser.parse_trigger_period()?;
let events = parser.parse_comma_separated(Parser::parse_trigger_event)?;

parser.expect_keyword_is(Keyword::AS)?;
let statements = Some(parser.parse_conditional_statements(&[Keyword::END])?);

Ok(Statement::CreateTrigger {
or_alter,
or_replace: false,
is_constraint: false,
name,
period,
events,
table_name,
referenced_table_name: None,
referencing: Vec::new(),
trigger_object: TriggerObject::Statement,
include_each: false,
condition: None,
exec_body: None,
statements,
characteristics: None,
})
}

/// Parse a sequence of statements, optionally separated by semicolon.
///
/// Stops parsing when reaching EOF or the given keyword.
Expand Down
51 changes: 34 additions & 17 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -745,26 +745,38 @@ impl<'a> Parser<'a> {
}
};

let conditional_statements = self.parse_conditional_statements(terminal_keywords)?;

Ok(ConditionalStatementBlock {
start_token: AttachedToken(start_token),
condition,
then_token,
conditional_statements,
})
}

/// Parse a BEGIN/END block or a sequence of statements
/// This could be inside of a conditional (IF, CASE, WHILE etc.) or an object body defined optionally BEGIN/END and one or more statements.
pub(crate) fn parse_conditional_statements(
&mut self,
terminal_keywords: &[Keyword],
) -> Result<ConditionalStatements, ParserError> {
let conditional_statements = if self.peek_keyword(Keyword::BEGIN) {
let begin_token = self.expect_keyword(Keyword::BEGIN)?;
let statements = self.parse_statement_list(terminal_keywords)?;
let end_token = self.expect_keyword(Keyword::END)?;

ConditionalStatements::BeginEnd(BeginEndStatements {
begin_token: AttachedToken(begin_token),
statements,
end_token: AttachedToken(end_token),
})
} else {
let statements = self.parse_statement_list(terminal_keywords)?;
ConditionalStatements::Sequence { statements }
ConditionalStatements::Sequence {
statements: self.parse_statement_list(terminal_keywords)?,
}
};

Ok(ConditionalStatementBlock {
start_token: AttachedToken(start_token),
condition,
then_token,
conditional_statements,
})
Ok(conditional_statements)
}

/// Parse a `RAISE` statement.
Expand Down Expand Up @@ -4614,9 +4626,9 @@ impl<'a> Parser<'a> {
} else if self.parse_keyword(Keyword::FUNCTION) {
self.parse_create_function(or_alter, or_replace, temporary)
} else if self.parse_keyword(Keyword::TRIGGER) {
self.parse_create_trigger(or_replace, false)
self.parse_create_trigger(or_alter, or_replace, false)
} else if self.parse_keywords(&[Keyword::CONSTRAINT, Keyword::TRIGGER]) {
self.parse_create_trigger(or_replace, true)
self.parse_create_trigger(or_alter, or_replace, true)
} else if self.parse_keyword(Keyword::MACRO) {
self.parse_create_macro(or_replace, temporary)
} else if self.parse_keyword(Keyword::SECRET) {
Expand Down Expand Up @@ -5314,10 +5326,11 @@ impl<'a> Parser<'a> {

pub fn parse_create_trigger(
&mut self,
or_alter: bool,
or_replace: bool,
is_constraint: bool,
) -> Result<Statement, ParserError> {
if !dialect_of!(self is PostgreSqlDialect | GenericDialect | MySqlDialect) {
if !dialect_of!(self is PostgreSqlDialect | GenericDialect | MySqlDialect | MsSqlDialect) {
self.prev_token();
return self.expected("an object type after CREATE", self.peek_token());
}
Expand Down Expand Up @@ -5363,6 +5376,7 @@ impl<'a> Parser<'a> {
let exec_body = self.parse_trigger_exec_body()?;

Ok(Statement::CreateTrigger {
or_alter,
or_replace,
is_constraint,
name,
Expand All @@ -5374,18 +5388,21 @@ impl<'a> Parser<'a> {
trigger_object,
include_each,
condition,
exec_body,
exec_body: Some(exec_body),
statements: None,
characteristics,
})
}

pub fn parse_trigger_period(&mut self) -> Result<TriggerPeriod, ParserError> {
Ok(
match self.expect_one_of_keywords(&[
Keyword::FOR,
Keyword::BEFORE,
Keyword::AFTER,
Keyword::INSTEAD,
])? {
Keyword::FOR => TriggerPeriod::For,
Keyword::BEFORE => TriggerPeriod::Before,
Keyword::AFTER => TriggerPeriod::After,
Keyword::INSTEAD => self
Expand Down Expand Up @@ -15457,14 +15474,14 @@ impl<'a> Parser<'a> {
let name = self.parse_object_name(false)?;
let params = self.parse_optional_procedure_parameters()?;
self.expect_keyword_is(Keyword::AS)?;
self.expect_keyword_is(Keyword::BEGIN)?;
let statements = self.parse_statements()?;
self.expect_keyword_is(Keyword::END)?;

let body = self.parse_conditional_statements(&[Keyword::END])?;

Ok(Statement::CreateProcedure {
name,
or_alter,
params,
body: statements,
body,
})
}

Expand Down
Loading