diff --git a/sqlx-core/src/sqlite/connection/explain.rs b/sqlx-core/src/sqlite/connection/explain.rs index b470ce1ecf..7d54bfddfc 100644 --- a/sqlx-core/src/sqlite/connection/explain.rs +++ b/sqlx-core/src/sqlite/connection/explain.rs @@ -4,6 +4,7 @@ use crate::sqlite::connection::{execute, ConnectionState}; use crate::sqlite::type_info::DataType; use crate::sqlite::SqliteTypeInfo; use crate::HashMap; +use std::collections::HashSet; use std::str::from_utf8; // affinity @@ -121,7 +122,7 @@ const OP_CONCAT: &str = "Concat"; const OP_RESULT_ROW: &str = "ResultRow"; const OP_HALT: &str = "Halt"; -#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] struct ColumnType { pub datatype: DataType, pub nullable: Option, @@ -145,7 +146,7 @@ impl ColumnType { } } -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq, Hash)] enum RegDataType { Single(ColumnType), Record(Vec), @@ -341,6 +342,49 @@ struct QueryState { pub result: Option, Option)>>, } +#[derive(Debug, Hash, PartialEq, Eq)] +struct BranchStateHash { + instruction: usize, + registers: Vec<(i64, RegDataType)>, + cursors: Vec<(i64, i64, Option)>, +} + +impl BranchStateHash { + pub fn from_query_state(st: &QueryState) -> Self { + let mut reg = vec![]; + for (k, v) in &st.r { + reg.push((*k, v.clone())); + } + reg.sort_by_key(|v| v.0); + + let mut cur = vec![]; + for (k, v) in &st.p { + match v { + CursorDataType::Normal(hm) => { + for (i, col) in hm { + cur.push((*k, *i, Some(col.clone()))); + } + } + CursorDataType::Pseudo(i) => { + cur.push((*k, *i, None)); + } + } + } + cur.sort_by(|a, b| { + if a.0 == b.0 { + a.1.cmp(&b.1) + } else { + a.0.cmp(&b.0) + } + }); + Self { + instruction: st.program_i, + registers: reg, + cursors: cur, + } + } +} + // Opcode Reference: https://sqlite.org/opcode.html pub(super) fn explain( conn: &mut ConnectionState, @@ -366,6 +410,8 @@ pub(super) fn explain( result: None, }]; + let mut visited_branch_state: HashSet = HashSet::new(); + let mut result_states = Vec::new(); while let Some(mut state) = states.pop() { @@ -407,7 +453,12 @@ pub(super) fn explain( let mut branch_state = state.clone(); branch_state.program_i = p2 as usize; - states.push(branch_state); + + let bs_hash = BranchStateHash::from_query_state(&branch_state); + if !visited_branch_state.contains(&bs_hash) { + visited_branch_state.insert(bs_hash); + states.push(branch_state); + } state.program_i += 1; continue; @@ -520,15 +571,27 @@ pub(super) fn explain( let mut branch_state = state.clone(); branch_state.program_i = p1 as usize; - states.push(branch_state); + let bs_hash = BranchStateHash::from_query_state(&branch_state); + if !visited_branch_state.contains(&bs_hash) { + visited_branch_state.insert(bs_hash); + states.push(branch_state); + } let mut branch_state = state.clone(); branch_state.program_i = p2 as usize; - states.push(branch_state); + let bs_hash = BranchStateHash::from_query_state(&branch_state); + if !visited_branch_state.contains(&bs_hash) { + visited_branch_state.insert(bs_hash); + states.push(branch_state); + } let mut branch_state = state.clone(); branch_state.program_i = p3 as usize; - states.push(branch_state); + let bs_hash = BranchStateHash::from_query_state(&branch_state); + if !visited_branch_state.contains(&bs_hash) { + visited_branch_state.insert(bs_hash); + states.push(branch_state); + } } OP_COLUMN => {