diff --git a/sqlx-core/src/sqlite/connection/explain.rs b/sqlx-core/src/sqlite/connection/explain.rs index 6179125577..310d803227 100644 --- a/sqlx-core/src/sqlite/connection/explain.rs +++ b/sqlx-core/src/sqlite/connection/explain.rs @@ -97,6 +97,50 @@ fn opcode_to_type(op: &str) -> DataType { } } +fn root_block_columns( + conn: &mut ConnectionState, +) -> Result>, Error> { + let table_block_columns: Vec<(i64, i64, String)> = execute::iter( + conn, + "SELECT s.rootpage, col.cid as colnum, col.type + FROM sqlite_schema s + JOIN pragma_table_info(s.name) AS col + WHERE s.type = 'table'", + None, + false, + )? + .filter_map(|res| res.map(|either| either.right()).transpose()) + .map(|row| FromRow::from_row(&row?)) + .collect::, Error>>()?; + + let index_block_columns: Vec<(i64, i64, String)> = execute::iter( + conn, + "SELECT s.rootpage, idx.seqno as colnum, col.type + FROM sqlite_schema s + JOIN pragma_index_info(s.name) AS idx + LEFT JOIN pragma_table_info(s.tbl_name) as col + ON col.cid = idx.cid + WHERE s.type = 'index'", + None, + false, + )? + .filter_map(|res| res.map(|either| either.right()).transpose()) + .map(|row| FromRow::from_row(&row?)) + .collect::, Error>>()?; + + let mut row_info: HashMap> = HashMap::new(); + for (block, colnum, datatype) in table_block_columns { + let row_info = row_info.entry(block).or_default(); + row_info.insert(colnum, datatype.parse().unwrap_or(DataType::Null)); + } + for (block, colnum, datatype) in index_block_columns { + let row_info = row_info.entry(block).or_default(); + row_info.insert(colnum, datatype.parse().unwrap_or(DataType::Null)); + } + + return Ok(row_info); +} + // Opcode Reference: https://sqlite.org/opcode.html pub(super) fn explain( conn: &mut ConnectionState, @@ -112,6 +156,8 @@ pub(super) fn explain( // Nullable columns let mut n = HashMap::::with_capacity(6); + let root_block_cols = root_block_columns(conn)?; + let program: Vec<(i64, String, i64, i64, i64, Vec)> = execute::iter(conn, &format!("EXPLAIN {}", query), None, false)? .filter_map(|res| res.map(|either| either.right()).transpose()) @@ -190,7 +236,23 @@ pub(super) fn explain( OP_OPEN_READ | OP_OPEN_WRITE | OP_OPEN_EPHEMERAL | OP_OPEN_AUTOINDEX => { //Create a new pointer which is referenced by p1 - p.insert(p1, HashMap::with_capacity(6)); + + //Create a new pointer which is referenced by p1, take column metadata from db schema if found + if p3 == 0 { + if let Some(columns) = root_block_cols.get(&p2) { + p.insert( + p1, + columns + .iter() + .map(|(&colnum, &datatype)| (colnum, datatype)) + .collect(), + ); + } else { + p.insert(p1, HashMap::with_capacity(6)); + } + } else { + p.insert(p1, HashMap::with_capacity(6)); + } } OP_VARIABLE => { @@ -339,3 +401,126 @@ pub(super) fn explain( Ok((output, nullable)) } + +#[test] +fn test_root_block_columns_has_types() { + use crate::sqlite::SqliteConnectOptions; + use std::str::FromStr; + let conn_options = SqliteConnectOptions::from_str("sqlite::memory:").unwrap(); + let mut conn = super::EstablishParams::from_options(&conn_options) + .unwrap() + .establish() + .unwrap(); + + assert!(execute::iter( + &mut conn, + r"CREATE TABLE t(a INTEGER PRIMARY KEY, b_null TEXT NULL, b TEXT NOT NULL);", + None, + false + ) + .unwrap() + .next() + .is_some()); + assert!( + execute::iter(&mut conn, r"CREATE INDEX i1 on t (a,b_null);", None, false) + .unwrap() + .next() + .is_some() + ); + assert!(execute::iter( + &mut conn, + r"CREATE UNIQUE INDEX i2 on t (a,b_null);", + None, + false + ) + .unwrap() + .next() + .is_some()); + assert!(execute::iter( + &mut conn, + r"CREATE TABLE t2(a INTEGER, b_null NUMERIC NULL, b NUMERIC NOT NULL);", + None, + false + ) + .unwrap() + .next() + .is_some()); + assert!(execute::iter( + &mut conn, + r"CREATE INDEX t2i1 on t2 (a,b_null);", + None, + false + ) + .unwrap() + .next() + .is_some()); + assert!(execute::iter( + &mut conn, + r"CREATE UNIQUE INDEX t2i2 on t2 (a,b);", + None, + false + ) + .unwrap() + .next() + .is_some()); + + let table_block_nums: HashMap = execute::iter( + &mut conn, + r"select name, rootpage from sqlite_master", + None, + false, + ) + .unwrap() + .filter_map(|res| res.map(|either| either.right()).transpose()) + .map(|row| FromRow::from_row(row.as_ref().unwrap())) + .collect::, Error>>() + .unwrap(); + + let root_block_cols = root_block_columns(&mut conn).unwrap(); + + assert_eq!(6, root_block_cols.len()); + + //prove that we have some information for each table & index + for blocknum in table_block_nums.values() { + assert!(root_block_cols.contains_key(blocknum)); + } + + //prove that each block has the correct information + { + let blocknum = table_block_nums["t"]; + assert_eq!((DataType::Int64), root_block_cols[&blocknum][&0]); + assert_eq!((DataType::Text), root_block_cols[&blocknum][&1]); + assert_eq!((DataType::Text), root_block_cols[&blocknum][&2]); + } + + { + let blocknum = table_block_nums["i1"]; + assert_eq!((DataType::Int64), root_block_cols[&blocknum][&0]); + assert_eq!((DataType::Text), root_block_cols[&blocknum][&1]); + } + + { + let blocknum = table_block_nums["i2"]; + assert_eq!((DataType::Int64), root_block_cols[&blocknum][&0]); + assert_eq!((DataType::Text), root_block_cols[&blocknum][&1]); + } + + { + let blocknum = table_block_nums["t2"]; + assert_eq!((DataType::Int64), root_block_cols[&blocknum][&0]); + assert_eq!((DataType::Null), root_block_cols[&blocknum][&1]); + assert_eq!((DataType::Null), root_block_cols[&blocknum][&2]); + } + + { + let blocknum = table_block_nums["t2i1"]; + assert_eq!((DataType::Int64), root_block_cols[&blocknum][&0]); + assert_eq!((DataType::Null), root_block_cols[&blocknum][&1]); + } + + { + let blocknum = table_block_nums["t2i2"]; + assert_eq!((DataType::Int64), root_block_cols[&blocknum][&0]); + assert_eq!((DataType::Null), root_block_cols[&blocknum][&1]); + } +} diff --git a/tests/sqlite/describe.rs b/tests/sqlite/describe.rs index 90d59284ea..e75d606223 100644 --- a/tests/sqlite/describe.rs +++ b/tests/sqlite/describe.rs @@ -242,5 +242,17 @@ async fn it_describes_left_join() -> anyhow::Result<()> { assert_eq!(d.column(1).type_info().name(), "INTEGER"); assert_eq!(d.nullable(1), Some(false)); + let d = conn + .describe( + "select tweet.id, accounts.id from accounts left join tweet on tweet.id = accounts.id", + ) + .await?; + + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(true)); + + assert_eq!(d.column(1).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(1), Some(false)); + Ok(()) }