Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test against sqlite's test suite #123

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "tests/sqllogictest-sqlite"]
path = tests/sqllogictest-sqlite
url = https://github.com/risinglightdb/sqllogictest-sqlite
45 changes: 37 additions & 8 deletions sqllogictest/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,10 @@ impl<D: AsyncDB> Runner<D> {
self.validator = validator;
}

pub fn with_hash_threshold(&mut self, hash_threshold: usize) {
self.hash_threshold = hash_threshold;
}

pub async fn apply_record(&mut self, record: Record) -> RecordOutput {
match record {
Record::Statement { conditions, .. } if self.should_skip(&conditions) => {
Expand Down Expand Up @@ -496,15 +500,30 @@ impl<D: AsyncDB> Runner<D> {
}
};

let mut value_sort = false;
match sort_mode.as_ref().or(self.sort_mode.as_ref()) {
None | Some(SortMode::NoSort) => {}
Some(SortMode::RowSort) => {
rows.sort_unstable();
}
Some(SortMode::ValueSort) => todo!("value sort"),
Some(SortMode::ValueSort) => {
rows = rows
.iter()
.flat_map(|row| row.iter())
.map(|s| vec![s.to_owned()])
.collect();
rows.sort_unstable();
value_sort = true;
}
};

if self.hash_threshold > 0 && rows.len() > self.hash_threshold {
let num_values = if value_sort {
rows.len()
} else {
rows.len() * types.len()
};

if self.hash_threshold > 0 && num_values > self.hash_threshold {
let mut md5 = md5::Context::new();
for line in &rows {
for value in line {
Expand Down Expand Up @@ -688,17 +707,27 @@ impl<D: AsyncDB> Runner<D> {
}

// We compare normalized results. Whitespace characters are ignored.
let normalized_rows = rows
.into_iter()
.map(|strs| strs.iter().map(normalize_string).join(" "))
.collect_vec();

let expected_results = expected_results.iter().map(normalize_string).collect_vec();
if !(self.validator)(&normalized_rows, &expected_results) {

let actual_results =
if types.len() > 1 && rows.len() * types.len() == expected_results.len() {
// value-wise mode
rows.into_iter()
.flat_map(|strs| strs.iter().map(normalize_string).collect_vec())
.collect_vec()
} else {
// row-wise mode
rows.into_iter()
.map(|strs| strs.iter().map(normalize_string).join(" "))
.collect_vec()
};

if !(self.validator)(&actual_results, &expected_results) {
return Err(TestErrorKind::QueryResultMismatch {
sql,
expected: expected_results.join("\n"),
actual: normalized_rows.join("\n"),
actual: actual_results.join("\n"),
}
.at(loc));
}
Expand Down
8 changes: 7 additions & 1 deletion tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@ publish = false

[dependencies]
sqllogictest = { path = "../sqllogictest" }
rusqlite = { version = "0.28", features = ["bundled"] }

[[test]]
name = "harness"
path = "./harness.rs"
harness = false
harness = false

[[test]]
name = "sqlite"
path = "./sqlite.rs"
harness = false
87 changes: 87 additions & 0 deletions tests/sqlite.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use rusqlite::{types::ValueRef, Connection, Error};

use sqllogictest::{harness, ColumnType, DBOutput, Runner, DB};

fn hash_threshold(filename: &str) -> usize {
match filename {
"sqlite/select1.test" => 8,
"sqlite/select4.test" => 8,
"sqlite/select5.test" => 8,
_ => 0,
}
}

fn main() {
let paths = harness::glob("sqllogictest-sqlite/test/**/select*.test").unwrap();
let mut tests = vec![];
for entry in paths {
let path = entry.unwrap();
let filename = path.to_str().unwrap().to_string();
tests.push(harness::Trial::test(filename.clone(), move || {
let mut tester = Runner::new(db_fn());
tester.with_hash_threshold(hash_threshold(&filename));
tester.run_file(path)?;
Ok(())
}));
}
harness::run(&harness::Arguments::from_args(), tests).exit();
}

struct ConnectionWrapper(Connection);

fn db_fn() -> ConnectionWrapper {
let c = Connection::open_in_memory().unwrap();
ConnectionWrapper(c)
}

fn value_to_string(v: ValueRef) -> String {
match v {
ValueRef::Null => "NULL".to_string(),
ValueRef::Integer(i) => i.to_string(),
ValueRef::Real(r) => r.to_string(),
ValueRef::Text(s) => std::str::from_utf8(s).unwrap().to_string(),
ValueRef::Blob(_) => todo!(),
}
}

impl DB for ConnectionWrapper {
type Error = Error;

fn run(&mut self, sql: &str) -> Result<DBOutput, Self::Error> {
let mut output = vec![];

let is_query_sql = {
let lower_sql = sql.trim_start().to_ascii_lowercase();
lower_sql.starts_with("select")
|| lower_sql.starts_with("values")
|| lower_sql.starts_with("show")
|| lower_sql.starts_with("with")
|| lower_sql.starts_with("describe")
};

if is_query_sql {
let mut stmt = self.0.prepare(sql)?;
let column_count = stmt.column_count();
let mut rows = stmt.query([])?;
while let Some(row) = rows.next()? {
let mut row_output = vec![];
for i in 0..column_count {
let row = row.get_ref(i)?;
row_output.push(value_to_string(row));
}
output.push(row_output);
}
Ok(DBOutput::Rows {
types: vec![ColumnType::Any; column_count],
rows: output,
})
} else {
let cnt = self.0.execute(sql, [])?;
Ok(DBOutput::StatementComplete(cnt as u64))
}
}

fn engine_name(&self) -> &str {
"sqlite"
}
}
1 change: 1 addition & 0 deletions tests/sqllogictest-sqlite
Submodule sqllogictest-sqlite added at 4ab49f