From f4060d3bd786aacdc6bb76ee3a559f46967ab05a Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Wed, 21 Aug 2024 13:36:42 -0400 Subject: [PATCH 1/2] override returned types for some functions like COUNT or when casting occurs --- crates/corro-pg/src/lib.rs | 154 +++++++++++++++++++++++++++++++++---- 1 file changed, 141 insertions(+), 13 deletions(-) diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs index 1c868682..9c3bfc15 100644 --- a/crates/corro-pg/src/lib.rs +++ b/crates/corro-pg/src/lib.rs @@ -48,7 +48,7 @@ use rusqlite::{ use spawn::spawn_counted; use sqlite3_parser::ast::{ As, Cmd, ColumnDefinition, CreateTableBody, Expr, FromClause, Id, InsertBody, Limit, Name, - OneSelect, ResultColumn, Select, SelectTable, Stmt, With, + OneSelect, ResultColumn, Select, SelectBody, SelectTable, Stmt, With, }; use sqlparser::ast::Statement as PgStatement; use tokio::{ @@ -848,21 +848,144 @@ pub async fn start( .collect(); } + let mut field_type_overrides = HashMap::new(); + + match &parsed_cmd { + ParsedCmd::Sqlite(Cmd::Stmt(stmt)) => match stmt { + Stmt::Select(Select { + body: + SelectBody { + select: + OneSelect::Select { + columns: cols, .. + }, + .. + }, + .. + }) + | Stmt::Delete { + returning: Some(cols), + .. + } + | Stmt::Insert { + returning: Some(cols), + .. + } + | Stmt::Update { + returning: Some(cols), + .. + } => { + for (i, col) in cols.iter().enumerate() { + if let ResultColumn::Expr(expr, _as) = col { + let type_override = match expr { + Expr::Cast { + type_name, .. + } => { + match name_to_type( + &type_name.name, + ) { + Ok(t) => Some(t), + Err(e) => { + back_tx.blocking_send( + (e.into(), true) + .into(), + )?; + discard_until_sync = + true; + continue 'outer; + } + } + } + Expr::FunctionCall { + name, .. + } + | Expr::FunctionCallStar { + name, + .. + } => match name + .0 + .as_str() + .to_uppercase() + .as_ref() + { + "COUNT" => Some(Type::INT8), + _ => None, + }, + _ => None, + }; + if let Some(type_override) = + type_override + { + match prepped.column_name(i) { + Ok(col_name) => { + field_type_overrides + .insert( + col_name, + type_override, + ); + } + Err(e) => { + error!("col index didn't exist at {i}, attempted to override type as: {type_override}: {e}"); + } + } + } + } else { + break; + } + } + } + _ => {} + }, + ParsedCmd::Postgres(_stmt) => { + // TODO: handle type overrides here too + // let cols = match stmt { + // PgStatement::Insert { returning, .. } + // | PgStatement::Update { returning, .. } + // | PgStatement::Delete { returning, .. } => { + // returning + // } + // PgStatement::Query(query) => { + // match *query.body { + // sqlparser::ast::SetExpr::Select( + // select, + // ) => Some(select.projection), + // _ => None, + // } + // } + // _ => None, + // }; + + // if let Some(cols) = cols { + + // } + } + _ => {} + } + let mut fields = vec![]; for col in prepped.columns() { - let col_type = match name_to_type( - col.decl_type().unwrap_or("text"), - ) { - Ok(t) => t, - Err(e) => { - back_tx - .blocking_send((e.into(), true).into())?; - discard_until_sync = true; - continue 'outer; - } - }; + let col_name = col.name(); + let col_type = + match field_type_overrides.remove(col_name) { + Some(t) => t, + None => match col.decl_type() { + None => Type::TEXT, + Some(decl_type) => { + match name_to_type(decl_type) { + Ok(t) => t, + Err(e) => { + back_tx.blocking_send( + (e.into(), true).into(), + )?; + discard_until_sync = true; + continue 'outer; + } + } + } + }, + }; fields.push(FieldInfo::new( - col.name().to_string(), + col_name.to_string(), None, None, col_type, @@ -3253,6 +3376,11 @@ mod tests { println!("updated_at: {updated_at:?}"); assert_eq!(future, updated_at); + + let row = client + .query_one("SELECT COUNT(*) FROM kitchensink", &[]) + .await?; + println!("COUNT ROW: {row:?}"); } tripwire_tx.send(()).await.ok(); From 03daefd51cc9f2396fbaf3b9735a00237c5a5150 Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Wed, 21 Aug 2024 14:32:07 -0400 Subject: [PATCH 2/2] handle field types in a standard way across all functions --- crates/corro-pg/src/lib.rs | 367 +++++++++++++++++-------------------- 1 file changed, 171 insertions(+), 196 deletions(-) diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs index 9c3bfc15..9b0caf4a 100644 --- a/crates/corro-pg/src/lib.rs +++ b/crates/corro-pg/src/lib.rs @@ -47,8 +47,8 @@ use rusqlite::{ }; use spawn::spawn_counted; use sqlite3_parser::ast::{ - As, Cmd, ColumnDefinition, CreateTableBody, Expr, FromClause, Id, InsertBody, Limit, Name, - OneSelect, ResultColumn, Select, SelectBody, SelectTable, Stmt, With, + As, Cmd, ColumnDefinition, CreateTableBody, Expr, FromClause, Id, InsertBody, Limit, Literal, + Name, OneSelect, ResultColumn, Select, SelectBody, SelectTable, Stmt, With, }; use sqlparser::ast::Statement as PgStatement; use tokio::{ @@ -848,150 +848,18 @@ pub async fn start( .collect(); } - let mut field_type_overrides = HashMap::new(); - - match &parsed_cmd { - ParsedCmd::Sqlite(Cmd::Stmt(stmt)) => match stmt { - Stmt::Select(Select { - body: - SelectBody { - select: - OneSelect::Select { - columns: cols, .. - }, - .. - }, - .. - }) - | Stmt::Delete { - returning: Some(cols), - .. - } - | Stmt::Insert { - returning: Some(cols), - .. - } - | Stmt::Update { - returning: Some(cols), - .. - } => { - for (i, col) in cols.iter().enumerate() { - if let ResultColumn::Expr(expr, _as) = col { - let type_override = match expr { - Expr::Cast { - type_name, .. - } => { - match name_to_type( - &type_name.name, - ) { - Ok(t) => Some(t), - Err(e) => { - back_tx.blocking_send( - (e.into(), true) - .into(), - )?; - discard_until_sync = - true; - continue 'outer; - } - } - } - Expr::FunctionCall { - name, .. - } - | Expr::FunctionCallStar { - name, - .. - } => match name - .0 - .as_str() - .to_uppercase() - .as_ref() - { - "COUNT" => Some(Type::INT8), - _ => None, - }, - _ => None, - }; - if let Some(type_override) = - type_override - { - match prepped.column_name(i) { - Ok(col_name) => { - field_type_overrides - .insert( - col_name, - type_override, - ); - } - Err(e) => { - error!("col index didn't exist at {i}, attempted to override type as: {type_override}: {e}"); - } - } - } - } else { - break; - } - } - } - _ => {} - }, - ParsedCmd::Postgres(_stmt) => { - // TODO: handle type overrides here too - // let cols = match stmt { - // PgStatement::Insert { returning, .. } - // | PgStatement::Update { returning, .. } - // | PgStatement::Delete { returning, .. } => { - // returning - // } - // PgStatement::Query(query) => { - // match *query.body { - // sqlparser::ast::SetExpr::Select( - // select, - // ) => Some(select.projection), - // _ => None, - // } - // } - // _ => None, - // }; - - // if let Some(cols) = cols { - - // } + let fields = match field_types( + &prepped, + &parsed_cmd, + FieldFormats::All(FieldFormat::Text), + ) { + Ok(fields) => fields, + Err(e) => { + back_tx.blocking_send((e.into(), true).into())?; + discard_until_sync = true; + continue 'outer; } - _ => {} - } - - let mut fields = vec![]; - for col in prepped.columns() { - let col_name = col.name(); - let col_type = - match field_type_overrides.remove(col_name) { - Some(t) => t, - None => match col.decl_type() { - None => Type::TEXT, - Some(decl_type) => { - match name_to_type(decl_type) { - Ok(t) => t, - Err(e) => { - back_tx.blocking_send( - (e.into(), true).into(), - )?; - discard_until_sync = true; - continue 'outer; - } - } - } - }, - }; - fields.push(FieldInfo::new( - col_name.to_string(), - None, - None, - col_type, - FieldFormat::Text, - )); - } + }; prepared.insert( name.into(), @@ -1108,38 +976,29 @@ pub async fn start( Some(Portal::Parsed { stmt, result_formats, + cmd, .. }) => { - let mut oids = vec![]; - let mut fields = vec![]; - for (i, col) in stmt.columns().into_iter().enumerate() { - let col_type = - match name_to_type( - col.decl_type().unwrap_or("text"), - ) { - Ok(t) => t, - Err(e) => { - back_tx.blocking_send(( + let fields = match field_types( + stmt, + cmd, + FieldFormats::Each(&result_formats), + ) { + Ok(fields) => fields, + Err(e) => { + back_tx.blocking_send( + ( PgWireBackendMessage::ErrorResponse( e.into(), ), true, - ).into())?; - continue 'outer; - } - }; - oids.push(col_type.oid()); - fields.push(FieldInfo::new( - col.name().to_string(), - None, - None, - col_type, - result_formats - .get(i) - .copied() - .unwrap_or(FieldFormat::Text), - )); - } + ) + .into(), + )?; + continue 'outer; + } + }; + back_tx.blocking_send( ( PgWireBackendMessage::RowDescription( @@ -1931,17 +1790,7 @@ impl Session { conn.prepare(&cmd.to_string())? }; - let mut fields = vec![]; - for col in prepped.columns() { - let col_type = name_to_type(col.decl_type().unwrap_or("text"))?; - fields.push(FieldInfo::new( - col.name().to_string(), - None, - None, - col_type, - FieldFormat::Text, - )); - } + let fields = field_types(&prepped, cmd, FieldFormats::All(FieldFormat::Text))?; if send_row_desc { back_tx @@ -2037,19 +1886,7 @@ impl Session { back_tx: &Sender, ) -> Result<(), QueryError> { // TODO: maybe we don't need to recompute this... - let mut fields = vec![]; - for (i, col) in prepped.columns().into_iter().enumerate() { - trace!("col decl_type: {:?}", col.decl_type()); - let col_type = name_to_type(col.decl_type().unwrap_or("any"))?; - - fields.push(FieldInfo::new( - col.name().to_string(), - None, - None, - col_type, - result_formats.get(i).copied().unwrap_or(FieldFormat::Text), - )); - } + let fields = field_types(prepped, cmd, FieldFormats::Each(result_formats))?; trace!("fields: {fields:?}"); @@ -3171,6 +3008,141 @@ fn parameter_types<'schema, 'stmt>( params } +enum FieldFormats<'a> { + All(FieldFormat), + Each(&'a [FieldFormat]), +} + +impl<'a> FieldFormats<'a> { + fn get(&self, i: usize) -> FieldFormat { + match self { + FieldFormats::All(format) => *format, + FieldFormats::Each(formats) => formats.get(i).copied().unwrap_or(FieldFormat::Text), + } + } +} + +fn field_types( + prepped: &Statement, + parsed_cmd: &ParsedCmd, + field_formats: FieldFormats<'_>, +) -> Result, UnsupportedSqliteToPostgresType> { + let mut field_type_overrides = HashMap::new(); + + match parsed_cmd { + ParsedCmd::Sqlite(Cmd::Stmt(stmt)) => match stmt { + Stmt::Select(Select { + body: + SelectBody { + select: OneSelect::Select { columns: cols, .. }, + .. + }, + .. + }) + | Stmt::Delete { + returning: Some(cols), + .. + } + | Stmt::Insert { + returning: Some(cols), + .. + } + | Stmt::Update { + returning: Some(cols), + .. + } => { + for (i, col) in cols.iter().enumerate() { + if let ResultColumn::Expr(expr, _as) = col { + let type_override = match expr { + Expr::Cast { type_name, .. } => Some(name_to_type(&type_name.name)?), + Expr::FunctionCall { name, .. } + | Expr::FunctionCallStar { name, .. } => { + match name.0.as_str().to_uppercase().as_ref() { + "COUNT" => Some(Type::INT8), + _ => None, + } + } + Expr::Literal(lit) => match lit { + Literal::Numeric(s) => Some(if s.contains('.') { + Type::FLOAT8 + } else { + Type::INT8 + }), + Literal::String(_) => Some(Type::TEXT), + Literal::Blob(_) => Some(Type::BYTEA), + Literal::Keyword(_) => None, + Literal::Null => None, + Literal::CurrentDate => Some(Type::DATE), + Literal::CurrentTime => Some(Type::TIME), + Literal::CurrentTimestamp => Some(Type::TIMESTAMP), + }, + _ => None, + }; + if let Some(type_override) = type_override { + match prepped.column_name(i) { + Ok(col_name) => { + field_type_overrides.insert(col_name, type_override); + } + Err(e) => { + error!("col index didn't exist at {i}, attempted to override type as: {type_override}: {e}"); + } + } + } + } else { + break; + } + } + } + _ => {} + }, + ParsedCmd::Postgres(_stmt) => { + // TODO: handle type overrides here too + // let cols = match stmt { + // PgStatement::Insert { returning, .. } + // | PgStatement::Update { returning, .. } + // | PgStatement::Delete { returning, .. } => { + // returning + // } + // PgStatement::Query(query) => { + // match *query.body { + // sqlparser::ast::SetExpr::Select( + // select, + // ) => Some(select.projection), + // _ => None, + // } + // } + // _ => None, + // }; + + // if let Some(cols) = cols { + + // } + } + _ => {} + } + + let mut fields = vec![]; + for (i, col) in prepped.columns().iter().enumerate() { + let col_name = col.name(); + let col_type = match field_type_overrides.remove(col_name) { + Some(t) => t, + None => match col.decl_type() { + None => Type::TEXT, + Some(decl_type) => name_to_type(decl_type)?, + }, + }; + fields.push(FieldInfo::new( + col_name.to_string(), + None, + None, + col_type, + field_formats.get(i), + )); + } + + Ok(fields) +} + #[cfg(test)] mod tests { use std::time::{Duration, Instant}; @@ -3378,7 +3350,10 @@ mod tests { assert_eq!(future, updated_at); let row = client - .query_one("SELECT COUNT(*) FROM kitchensink", &[]) + .query_one( + "SELECT COUNT(*) AS yep, COUNT(id) yeppers FROM kitchensink", + &[], + ) .await?; println!("COUNT ROW: {row:?}"); }