Skip to content

Commit

Permalink
Handle quoting identifiers properly
Browse files Browse the repository at this point in the history
  • Loading branch information
madejejej committed Dec 18, 2024
1 parent 783ec65 commit 9e01c22
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 19 deletions.
32 changes: 19 additions & 13 deletions core/translate/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,13 @@ fn bind_column_references(
return Ok(());
}
let mut match_result = None;
let normalized_id = normalize_ident(id.0.as_str());
for (tbl_idx, table) in referenced_tables.iter().enumerate() {
let col_idx = table
.table
.columns
.iter()
.position(|c| c.name.eq_ignore_ascii_case(&id.0));
.position(|c| c.name.eq_ignore_ascii_case(&normalized_id));
if col_idx.is_some() {
if match_result.is_some() {
crate::bail_parse_error!("Column {} is ambiguous", id.0);
Expand All @@ -124,20 +125,23 @@ fn bind_column_references(
Ok(())
}
ast::Expr::Qualified(tbl, id) => {
let matching_tbl_idx = referenced_tables
.iter()
.position(|t| t.table_identifier.eq_ignore_ascii_case(&tbl.0));
let normalized_table_name = normalize_ident(tbl.0.as_str());
let matching_tbl_idx = referenced_tables.iter().position(|t| {
t.table_identifier
.eq_ignore_ascii_case(&normalized_table_name)
});
if matching_tbl_idx.is_none() {
crate::bail_parse_error!("Table {} not found", tbl.0);
crate::bail_parse_error!("Table {} not found", normalized_table_name);
}
let tbl_idx = matching_tbl_idx.unwrap();
let normalized_id = normalize_ident(id.0.as_str());
let col_idx = referenced_tables[tbl_idx]
.table
.columns
.iter()
.position(|c| c.name.eq_ignore_ascii_case(&id.0));
.position(|c| c.name.eq_ignore_ascii_case(&normalized_id));
if col_idx.is_none() {
crate::bail_parse_error!("Column {} not found", id.0);
crate::bail_parse_error!("Column {} not found", normalized_id);
}
let col = referenced_tables[tbl_idx]
.table
Expand Down Expand Up @@ -504,8 +508,9 @@ fn parse_from(

let first_table = match *from.select.unwrap() {
ast::SelectTable::Table(qualified_name, maybe_alias, _) => {
let Some(table) = schema.get_table(&qualified_name.name.0) else {
crate::bail_parse_error!("Table {} not found", qualified_name.name.0);
let normalized_qualified_name = normalize_ident(qualified_name.name.0.as_str());
let Some(table) = schema.get_table(&normalized_qualified_name) else {
crate::bail_parse_error!("Table {} not found", normalized_qualified_name);
};
let alias = maybe_alias
.map(|a| match a {
Expand All @@ -516,7 +521,7 @@ fn parse_from(

BTreeTableReference {
table: table.clone(),
table_identifier: alias.unwrap_or(qualified_name.name.0),
table_identifier: alias.unwrap_or(normalized_qualified_name),
table_index: 0,
}
}
Expand Down Expand Up @@ -570,8 +575,9 @@ fn parse_join(

let table = match table {
ast::SelectTable::Table(qualified_name, maybe_alias, _) => {
let Some(table) = schema.get_table(&qualified_name.name.0) else {
crate::bail_parse_error!("Table {} not found", qualified_name.name.0);
let normalized_name = normalize_ident(qualified_name.name.0.as_str());
let Some(table) = schema.get_table(&normalized_name) else {
crate::bail_parse_error!("Table {} not found", normalized_name);
};
let alias = maybe_alias
.map(|a| match a {
Expand All @@ -581,7 +587,7 @@ fn parse_join(
.map(|a| a.0);
BTreeTableReference {
table: table.clone(),
table_identifier: alias.unwrap_or(qualified_name.name.0),
table_identifier: alias.unwrap_or(normalized_name),
table_index,
}
}
Expand Down
28 changes: 22 additions & 6 deletions core/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@ use crate::{
Result, RowResult, Rows, IO,
};

pub fn normalize_ident(ident: &str) -> String {
(if ident.starts_with('"') && ident.ends_with('"') {
&ident[1..ident.len() - 1]
// https://sqlite.org/lang_keywords.html
const QUOTE_PAIRS: &[(char, char)] = &[('"', '"'), ('[', ']'), ('`', '`')];

pub fn normalize_ident(identifier: &str) -> String {
let quote_pair = QUOTE_PAIRS
.iter()
.find(|&(start, end)| identifier.starts_with(*start) && identifier.ends_with(*end));

if let Some(&(start, end)) = quote_pair {
&identifier[1..identifier.len() - 1]
} else {
ident
})
identifier
}
.to_lowercase()
}

Expand Down Expand Up @@ -65,7 +72,6 @@ fn cmp_numeric_strings(num_str: &str, other: &str) -> bool {
}
}

const QUOTE_PAIRS: &[(char, char)] = &[('"', '"'), ('[', ']'), ('`', '`')];
pub fn check_ident_equivalency(ident1: &str, ident2: &str) -> bool {
fn strip_quotes(identifier: &str) -> &str {
for &(start, end) in QUOTE_PAIRS {
Expand Down Expand Up @@ -276,7 +282,17 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool {

#[cfg(test)]
pub mod tests {
use super::*;
use sqlite3_parser::ast::{self, Expr, Id, Literal, Operator::*, Type};

#[test]
fn test_normalize_ident() {
assert_eq!(normalize_ident("foo"), "foo");
assert_eq!(normalize_ident("`foo`"), "foo");
assert_eq!(normalize_ident("[foo]"), "foo");
assert_eq!(normalize_ident("\"foo\""), "foo");
}

#[test]
fn test_basic_addition_exprs_are_equivalent() {
let expr1 = Expr::Binary(
Expand Down
6 changes: 6 additions & 0 deletions testing/join.test
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,12 @@ do_execsql_test join-using-multiple {
Cindy|Salazar|cap
Tommy|Perry|shirt"}

do_execsql_test join-using-multiple-with-quoting {
select u.first_name, u.last_name, p.name from users u join users u2 using(id) join [products] p using(`id`) limit 3;
} {"Jamie|Foster|hat
Cindy|Salazar|cap
Tommy|Perry|shirt"}

# NATURAL JOIN desugars to JOIN USING (common_column1, common_column2...)
do_execsql_test join-using {
select * from users natural join products limit 3;
Expand Down
8 changes: 8 additions & 0 deletions testing/select.test
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ do_execsql_test table-star-2 {
select p.*, u.first_name from users u join products p on u.id = p.id limit 1;
} {1|hat|79.0|Jamie}

do_execsql_test select_with_quoting {
select `users`.id from [users] where users.[id] = 5;
} {5}

do_execsql_test select_with_quoting_2 {
select "users".`id` from users where `users`.[id] = 5;
} {5}

do_execsql_test seekrowid {
select * from users u where u.id = 5;
} {"5|Edward|Miller|christiankramer@example.com|725-281-1033|08522 English Plain|Lake Keith|ID|23283|15"}
Expand Down

0 comments on commit 9e01c22

Please sign in to comment.