-
Notifications
You must be signed in to change notification settings - Fork 265
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Lann Martin <lann.martin@fermyon.com>
- Loading branch information
Showing
2 changed files
with
97 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
pub mod split; | ||
|
||
use std::{ | ||
path::PathBuf, | ||
sync::{Arc, Mutex}, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
use std::os::raw::c_char; | ||
|
||
use rusqlite::{ffi, Connection}; | ||
|
||
pub struct SqlSplitter { | ||
conn: Connection, | ||
} | ||
|
||
impl SqlSplitter { | ||
pub fn new() -> rusqlite::Result<Self> { | ||
let conn = Connection::open_in_memory()?; | ||
Ok(Self { conn }) | ||
} | ||
|
||
pub fn split<'a>(&self, sql: &'a str) -> rusqlite::Result<Option<(&'a str, &'a str)>> { | ||
// Adapted from rusqlite::InnerConnection::prepare | ||
let c_sql = sql.as_ptr().cast::<c_char>(); | ||
let len = sql.len().try_into().expect("sql len must be < 2GiB"); | ||
let mut c_stmt: *mut ffi::sqlite3_stmt = std::ptr::null_mut(); | ||
let mut c_tail: *const c_char = std::ptr::null(); | ||
let ret = unsafe { | ||
ffi::sqlite3_prepare(self.conn.handle(), c_sql, len, &mut c_stmt, &mut c_tail) | ||
}; | ||
if ret != ffi::SQLITE_OK { | ||
// Note: This *should* fail and return a nicely-formatted error, | ||
// but isn't exactly the same code path so could succeed somehow. | ||
self.conn.execute_batch(sql)?; | ||
return Err(rusqlite::Error::SqliteFailure(ffi::Error::new(ret), None)); | ||
} | ||
if c_stmt.is_null() { | ||
return Ok(None); | ||
} | ||
let parsed_len = if c_tail.is_null() { | ||
sql.len() | ||
} else { | ||
((c_tail as isize) - (c_sql as isize)) as usize | ||
}; | ||
let (parsed, tail) = sql.split_at(parsed_len); | ||
Ok(Some((parsed.trim(), tail.trim()))) | ||
} | ||
|
||
pub fn split_all<'a>(&self, mut sql: &'a str) -> rusqlite::Result<Vec<&'a str>> { | ||
let mut splits = Vec::with_capacity(1); | ||
while let Some((parsed, tail)) = self.split(sql)? { | ||
splits.push(parsed); | ||
sql = tail; | ||
} | ||
Ok(splits) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use rusqlite::Error; | ||
|
||
use super::*; | ||
|
||
#[test] | ||
fn happy_tests() { | ||
let splitter = SqlSplitter::new().unwrap(); | ||
for (input, want_splits) in [ | ||
(" ", &[][..]), | ||
("SELECT 1", &["SELECT 1"]), | ||
("SELECT 1;", &["SELECT 1;"]), | ||
("SELECT 1; -- trailing comment", &["SELECT 1;"]), | ||
("SELECT 1;SELECT 2", &["SELECT 1;", "SELECT 2"]), | ||
("SELECT 1 ; SELECT 2 ", &["SELECT 1 ;", "SELECT 2"]), | ||
("SELECT 1 ; SELECT 2; ", &["SELECT 1 ;", "SELECT 2;"]), | ||
] { | ||
let got_splits = splitter | ||
.split_all(input) | ||
.unwrap_or_else(|err| panic!("failed to parse {input:?}: {err:?}")); | ||
assert_eq!( | ||
got_splits, want_splits, | ||
"unexpected output for input {input:?}" | ||
); | ||
} | ||
} | ||
|
||
#[test] | ||
fn sad_tests() { | ||
let splitter = SqlSplitter::new().unwrap(); | ||
|
||
assert!(matches!( | ||
splitter.split_all("NOT A SQL STATEMENT").unwrap_err(), | ||
Error::SqlInputError { .. } | ||
)); | ||
assert!(matches!( | ||
splitter | ||
.split_all("SELECT 1; NOT A SQL STATEMENT") | ||
.unwrap_err(), | ||
Error::SqlInputError { .. } | ||
)); | ||
} | ||
} |