Skip to content

Commit

Permalink
Merge pull request #4115 from weiznich/fix/sqlite_row_iter_shouldn_t_…
Browse files Browse the repository at this point in the history
…panic_if_called_again_after_error

Fixed a potential panic in SQLite row iterators
  • Loading branch information
weiznich committed Jul 16, 2024
1 parent e380c52 commit 2653a27
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 110 deletions.
1 change: 1 addition & 0 deletions diesel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ cfg-if = "1"
dotenvy = "0.15"
ipnetwork = ">=0.12.2, <0.21.0"
quickcheck = "1.0.3"
tempfile = "3.10.1"

[features]
default = ["with-deprecated", "32-column-tables"]
Expand Down
289 changes: 186 additions & 103 deletions diesel/src/sqlite/connection/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,126 +200,209 @@ impl<'stmt, 'query> Field<'stmt, Sqlite> for SqliteField<'stmt, 'query> {
}
}

#[test]
fn fun_with_row_iters() {
crate::table! {
#[allow(unused_parens)]
users(id) {
id -> Integer,
name -> Text,
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn fun_with_row_iters() {
crate::table! {
#[allow(unused_parens)]
users(id) {
id -> Integer,
name -> Text,
}
}
}

use crate::connection::LoadConnection;
use crate::deserialize::{FromSql, FromSqlRow};
use crate::prelude::*;
use crate::row::{Field, Row};
use crate::sql_types;
use crate::connection::LoadConnection;
use crate::deserialize::{FromSql, FromSqlRow};
use crate::prelude::*;
use crate::row::{Field, Row};
use crate::sql_types;

let conn = &mut crate::test_helpers::connection();
let conn = &mut crate::test_helpers::connection();

crate::sql_query("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);")
.execute(conn)
.unwrap();
crate::sql_query("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);")
.execute(conn)
.unwrap();

crate::insert_into(users::table)
.values(vec![
(users::id.eq(1), users::name.eq("Sean")),
(users::id.eq(2), users::name.eq("Tess")),
])
.execute(conn)
.unwrap();
crate::insert_into(users::table)
.values(vec![
(users::id.eq(1), users::name.eq("Sean")),
(users::id.eq(2), users::name.eq("Tess")),
])
.execute(conn)
.unwrap();

let query = users::table.select((users::id, users::name));
let query = users::table.select((users::id, users::name));

let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];
let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];

let row_iter = conn.load(query).unwrap();
for (row, expected) in row_iter.zip(&expected) {
let row = row.unwrap();
let row_iter = conn.load(query).unwrap();
for (row, expected) in row_iter.zip(&expected) {
let row = row.unwrap();

let deserialized = <(i32, String) as FromSqlRow<
(sql_types::Integer, sql_types::Text),
_,
>>::build_from_row(&row)
.unwrap();
let deserialized = <(i32, String) as FromSqlRow<
(sql_types::Integer, sql_types::Text),
_,
>>::build_from_row(&row)
.unwrap();

assert_eq!(&deserialized, expected);
}
assert_eq!(&deserialized, expected);
}

{
let collected_rows = conn.load(query).unwrap().collect::<Vec<_>>();
{
let collected_rows = conn.load(query).unwrap().collect::<Vec<_>>();

for (row, expected) in collected_rows.iter().zip(&expected) {
let deserialized = row
.as_ref()
.map(|row| {
<(i32, String) as FromSqlRow<
for (row, expected) in collected_rows.iter().zip(&expected) {
let deserialized = row
.as_ref()
.map(|row| {
<(i32, String) as FromSqlRow<
(sql_types::Integer, sql_types::Text),
_,
>>::build_from_row(row).unwrap()
})
.unwrap();
})
.unwrap();

assert_eq!(&deserialized, expected);
assert_eq!(&deserialized, expected);
}
}

let mut row_iter = conn.load(query).unwrap();

let first_row = row_iter.next().unwrap().unwrap();
let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
let first_values = (first_fields.0.value(), first_fields.1.value());

assert!(row_iter.next().unwrap().is_err());
std::mem::drop(first_values);
assert!(row_iter.next().unwrap().is_err());
std::mem::drop(first_fields);

let second_row = row_iter.next().unwrap().unwrap();
let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap());
let second_values = (second_fields.0.value(), second_fields.1.value());

assert!(row_iter.next().unwrap().is_err());
std::mem::drop(second_values);
assert!(row_iter.next().unwrap().is_err());
std::mem::drop(second_fields);

assert!(row_iter.next().is_none());

let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap());

let first_values = (first_fields.0.value(), first_fields.1.value());
let second_values = (second_fields.0.value(), second_fields.1.value());

assert_eq!(
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(first_values.0)
.unwrap(),
expected[0].0
);
assert_eq!(
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(first_values.1)
.unwrap(),
expected[0].1
);

assert_eq!(
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(second_values.0)
.unwrap(),
expected[1].0
);
assert_eq!(
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(second_values.1)
.unwrap(),
expected[1].1
);

let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
let first_values = (first_fields.0.value(), first_fields.1.value());

assert_eq!(
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(first_values.0)
.unwrap(),
expected[0].0
);
assert_eq!(
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(first_values.1)
.unwrap(),
expected[0].1
);
}

let mut row_iter = conn.load(query).unwrap();

let first_row = row_iter.next().unwrap().unwrap();
let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
let first_values = (first_fields.0.value(), first_fields.1.value());

assert!(row_iter.next().unwrap().is_err());
std::mem::drop(first_values);
assert!(row_iter.next().unwrap().is_err());
std::mem::drop(first_fields);

let second_row = row_iter.next().unwrap().unwrap();
let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap());
let second_values = (second_fields.0.value(), second_fields.1.value());

assert!(row_iter.next().unwrap().is_err());
std::mem::drop(second_values);
assert!(row_iter.next().unwrap().is_err());
std::mem::drop(second_fields);

assert!(row_iter.next().is_none());

let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap());

let first_values = (first_fields.0.value(), first_fields.1.value());
let second_values = (second_fields.0.value(), second_fields.1.value());

assert_eq!(
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(first_values.0).unwrap(),
expected[0].0
);
assert_eq!(
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(first_values.1).unwrap(),
expected[0].1
);

assert_eq!(
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(second_values.0).unwrap(),
expected[1].0
);
assert_eq!(
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(second_values.1).unwrap(),
expected[1].1
);

let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
let first_values = (first_fields.0.value(), first_fields.1.value());

assert_eq!(
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(first_values.0).unwrap(),
expected[0].0
);
assert_eq!(
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(first_values.1).unwrap(),
expected[0].1
);
#[cfg(feature = "returning_clauses_for_sqlite_3_35")]
crate::define_sql_function! {fn sleep(a: diesel::sql_types::Integer) -> diesel::sql_types::Integer}

#[test]
#[cfg(feature = "returning_clauses_for_sqlite_3_35")]
fn parallel_iter_with_error() {
use crate::connection::Connection;
use crate::connection::LoadConnection;
use crate::connection::SimpleConnection;
use crate::expression_methods::ExpressionMethods;
use crate::SqliteConnection;
use std::sync::{Arc, Barrier};
use std::time::Duration;

let temp_dir = tempfile::tempdir().unwrap();
let db_path = format!("{}/test.db", temp_dir.path().display());
let mut conn1 = SqliteConnection::establish(&db_path).unwrap();
let mut conn2 = SqliteConnection::establish(&db_path).unwrap();

crate::table! {
users {
id -> Integer,
name -> Text,
}
}

conn1
.batch_execute("CREATE TABLE users(id INTEGER NOT NULL PRIMARY KEY, name TEXT)")
.unwrap();

let barrier = Arc::new(Barrier::new(2));
let barrier2 = barrier.clone();

// we unblock the main thread from the sleep function
sleep_utils::register_impl(&mut conn2, move |a: i32| {
barrier.wait();
std::thread::sleep(Duration::from_secs(a as u64));
a
})
.unwrap();

// spawn a background thread that locks the database file
let handle = std::thread::spawn(move || {
use crate::query_dsl::RunQueryDsl;

conn2
.immediate_transaction(|conn| diesel::select(sleep(1)).execute(conn))
.unwrap();
});
barrier2.wait();

// execute some action that also requires a lock
let mut iter = conn1
.load(
diesel::insert_into(users::table)
.values((users::id.eq(1), users::name.eq("John")))
.returning(users::id),
)
.unwrap();

// get the first iterator result, that should return the lock error
let n = iter.next().unwrap();
assert!(n.is_err());

// check that the iterator is now empty
let n = iter.next();
assert!(n.is_none());

// join the background thread
handle.join().unwrap();
}
}
14 changes: 7 additions & 7 deletions diesel/src/sqlite/connection/statement_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl<'stmt, 'query> Iterator for StatementIterator<'stmt, 'query> {
fn next(&mut self) -> Option<Self::Item> {
use PrivateStatementIterator::{NotStarted, Started};
match &mut self.inner {
NotStarted(ref mut stmt) if stmt.is_some() => {
NotStarted(ref mut stmt @ Some(_)) => {
let mut stmt = stmt
.take()
.expect("It must be there because we checked that above");
Expand Down Expand Up @@ -161,12 +161,12 @@ impl<'stmt, 'query> Iterator for StatementIterator<'stmt, 'query> {
)
}
}
NotStarted(_) => unreachable!(
"You've reached an impossible internal state. \
If you ever see this error message please open \
an issue at https://github.com/diesel-rs/diesel \
providing example code how to trigger this error."
),
NotStarted(_s) => {
// we likely got an error while executing the other
// `NotStarted` branch above. In this case we just want to stop
// iterating here
None
}
}
}
}

0 comments on commit 2653a27

Please sign in to comment.