Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sqlite: Add split module #1772

Merged
merged 2 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions crates/sqlite-libsql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ edition = { workspace = true }
[dependencies]
async-trait = "0.1.68"
anyhow = "1.0"
# We only use rusqlite for its re-exported `ffi` module (libsqlite3-sys); depending
# on it gives us the same bundled libsqlite3-sys as used by spin-sqlite-inproc.
rusqlite = { version = "0.29.0", features = [ "bundled" ] }
spin-sqlite = { path = "../sqlite" }
spin-world = { path = "../world" }
sqlparser = "0.34"
Expand Down
14 changes: 7 additions & 7 deletions crates/sqlite-libsql/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
mod split;

use spin_world::sqlite::{self, RowResult};

#[derive(Clone)]
Expand Down Expand Up @@ -43,13 +45,11 @@ impl spin_sqlite::Connection for LibsqlClient {

// Unfortunately, the libsql library requires that the statements are already split
// into individual statement strings which requires us to parse the supplied SQL string.
let stmts = sqlparser::parser::Parser::parse_sql(
&sqlparser::dialect::SQLiteDialect {},
statements,
)?
.into_iter()
.map(|st| st.to_string())
.map(libsql_client::Statement::from);
let stmts = split::split_sql(statements)
.map(|res| Ok(res?.to_string()))
.collect::<anyhow::Result<Vec<_>>>()?
.into_iter()
.map(libsql_client::Statement::from);

let _ = client.batch(stmts).await?;

Expand Down
168 changes: 168 additions & 0 deletions crates/sqlite-libsql/src/split.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
use std::ffi::CStr;

use rusqlite::ffi;

/// Splits the given SQL into complete Sqlite statements.
///
/// Yields an error if the SQL includes incomplete Sqlite statements or if
/// Sqlite returns an error.
pub fn split_sql(mut sql: &str) -> impl Iterator<Item = Result<&str, Error>> {
std::iter::from_fn(move || {
if sql.is_empty() {
return None;
}
match split_sql_once(sql) {
Ok((stmt, tail)) => {
sql = tail;
Some(Ok(stmt))
}
Err(err) => {
sql = "";
Some(Err(err))
}
}
})
}

/// Splits the given SQL into one complete Sqlite statement and any remaining
/// text after the ending semicolon.
///
/// Returns an error if the SQL is an _incomplete_ Sqlite statement or if Sqlite
/// returns an error.
pub fn split_sql_once(sql: &str) -> Result<(&str, &str), Error> {
for (idx, _) in sql.match_indices(';') {
let (candidate, tail) = sql.split_at(idx + 1);
match ensure_complete(candidate) {
Ok(()) => return Ok((candidate, tail)),
Err(Error::Incomplete) => {
// May be a semicolon inside e.g. a string literal.
continue;
}
Err(err) => return Err(err),
}
}
ensure_complete(sql)?;
Ok((sql, ""))
}

// Validates that the given SQL is complete.
// Returns an error if the SQL is an incomplete Sqlite statement or if Sqlite
// returns an error.
fn ensure_complete(sql: &str) -> Result<(), Error> {
let mut bytes: Vec<u8> = sql.into();
if !bytes.ends_with(b";") {
bytes.extend_from_slice(b"\n;");
}
bytes.push(b'\0');
let c_str = CStr::from_bytes_with_nul(&bytes).unwrap();
let c_ptr = c_str.as_ptr() as *const std::os::raw::c_char;
match unsafe { ffi::sqlite3_complete(c_ptr) } {
1 => Ok(()),
0 => Err(Error::Incomplete),
code => Err(Error::Sqlite(ffi::Error::new(code))),
}
}

/// The error type for splitting SQL.
///
/// `PartialEq` is derived so that results can be compared directly, as the
/// tests in this module do with `assert_eq!`.
#[derive(Debug, PartialEq)]
pub enum Error {
/// Returned for incomplete Sqlite statements, e.g. an unterminated string
/// literal or an unterminated block comment.
Incomplete,
/// Returned for errors from Sqlite itself; carries the Sqlite result code
/// describing a failure other than incompleteness.
Sqlite(ffi::Error),
}

// User-facing rendering: a fixed message for incomplete SQL, otherwise the
// wrapped Sqlite error's own message.
impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Sqlite(sqlite_err) => write!(f, "{sqlite_err}"),
            Self::Incomplete => f.write_str("not a complete SQL statement"),
        }
    }
}

// No source error is exposed; the default trait methods are sufficient.
impl std::error::Error for Error {}

#[cfg(test)]
mod tests {
    use super::*;

    /// `split_sql` yields every statement of the input, in order.
    #[test]
    fn test_split_sql() {
        for (input, want_stmts) in [
            ("", &[][..]),
            ("/* comment */", &["/* comment */"]),
            ("SELECT 1;", &["SELECT 1;"]),
            ("SELECT 1;SELECT 2", &["SELECT 1;", "SELECT 2"]),
            // A semicolon inside a string literal must not end the statement.
            ("SELECT ';'; SELECT 2", &["SELECT ';';", " SELECT 2"]),
            (
                "CREATE TABLE fiteme23 (num INT);\nINSERT INTO fiteme23(num) VALUES(55);",
                &[
                    "CREATE TABLE fiteme23 (num INT);",
                    "\nINSERT INTO fiteme23(num) VALUES(55);",
                ],
            ),
        ] {
            let stmts = split_sql(input)
                .collect::<Result<Vec<_>, Error>>()
                .unwrap_or_else(|err| panic!("Failed to split {input:?}: {err}"));
            assert_eq!(stmts, want_stmts, "for {input:?}");
        }
    }

    /// Inputs that form a single statement come back whole, with no tail.
    #[test]
    fn test_split_sql_once_no_tail() {
        for input in [
            "",
            " ",
            "SELECT 1",
            "SELECT 1;",
            "SELECT * From some_table",
            "SELECT 1 -- trailing comment",
            "SELECT 1 -- trailing comment\n;",
            "SELECT 1 /* trailing comment */",
            "SELECT 1 /* trailing comment */;",
            "-- leading comment\nSELECT 1",
            "/* leading comment */ SELECT 1",
            " -- Just a comment",
            "/* comment one */ -- comment two",
            "CREATE virtual TABLE vss_blog_posts3 USING vss0(embedding(384))",
            "CREATE TRIGGER update_customer_address UPDATE OF address ON customers \n BEGIN\n UPDATE orders SET address = new.address WHERE customer_name = old.name;\n END;",
        ] {
            let (stmt, tail) = split_sql_once(input)
                .unwrap_or_else(|err| panic!("Failed to split {input:?}: {err}"));
            assert_eq!(stmt, input, "for {input:?}");
            assert_eq!(tail, "", "for {input:?}");
        }
    }

    /// Text after the first complete statement is returned untouched as the tail.
    #[test]
    fn test_split_sql_once_tail() {
        for (input, want_stmt, want_tail) in [
            ("SELECT 1; ", "SELECT 1;", " "),
            ("SELECT 1;SELECT 2", "SELECT 1;", "SELECT 2"),
            ("SELECT 1; -- tail", "SELECT 1;", " -- tail"),
            ("--leading\n; SELECT 1", "--leading\n;", " SELECT 1"),
            ("/* leading */; SELECT 1", "/* leading */;", " SELECT 1"),
        ] {
            let (stmt, tail) = split_sql_once(input)
                .unwrap_or_else(|err| panic!("Failed to split {input:?}: {err}"));
            assert_eq!(stmt, want_stmt, "for {input:?}");
            assert_eq!(tail, want_tail, "for {input:?}");
        }
    }

    /// Unterminated strings and comments are reported as incomplete.
    #[test]
    fn test_split_sql_once_incomplete() {
        for input in [
            "SELECT 'incomplete string",
            "/* incomplete comment",
            "SELECT /* tricky comment '*/ '",
        ] {
            assert_eq!(
                split_sql_once(input),
                Err(Error::Incomplete),
                "for {input:?}"
            );
        }
    }
}