diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs index 1c868682..9b0caf4a 100644 --- a/crates/corro-pg/src/lib.rs +++ b/crates/corro-pg/src/lib.rs @@ -47,8 +47,8 @@ use rusqlite::{ }; use spawn::spawn_counted; use sqlite3_parser::ast::{ - As, Cmd, ColumnDefinition, CreateTableBody, Expr, FromClause, Id, InsertBody, Limit, Name, - OneSelect, ResultColumn, Select, SelectTable, Stmt, With, + As, Cmd, ColumnDefinition, CreateTableBody, Expr, FromClause, Id, InsertBody, Limit, Literal, + Name, OneSelect, ResultColumn, Select, SelectBody, SelectTable, Stmt, With, }; use sqlparser::ast::Statement as PgStatement; use tokio::{ @@ -848,27 +848,18 @@ pub async fn start( .collect(); } - let mut fields = vec![]; - for col in prepped.columns() { - let col_type = match name_to_type( - col.decl_type().unwrap_or("text"), - ) { - Ok(t) => t, - Err(e) => { - back_tx - .blocking_send((e.into(), true).into())?; - discard_until_sync = true; - continue 'outer; - } - }; - fields.push(FieldInfo::new( - col.name().to_string(), - None, - None, - col_type, - FieldFormat::Text, - )); - } + let fields = match field_types( + &prepped, + &parsed_cmd, + FieldFormats::All(FieldFormat::Text), + ) { + Ok(fields) => fields, + Err(e) => { + back_tx.blocking_send((e.into(), true).into())?; + discard_until_sync = true; + continue 'outer; + } + }; prepared.insert( name.into(), @@ -985,38 +976,29 @@ pub async fn start( Some(Portal::Parsed { stmt, result_formats, + cmd, .. }) => { - let mut oids = vec![]; - let mut fields = vec![]; - for (i, col) in stmt.columns().into_iter().enumerate() { - let col_type = - match name_to_type( - col.decl_type().unwrap_or("text"), - ) { - Ok(t) => t, - Err(e) => { - back_tx.blocking_send(( + let fields = match field_types( + stmt, + cmd, + FieldFormats::Each(&result_formats), + ) { + Ok(fields) => fields, + Err(e) => { + back_tx.blocking_send( + ( PgWireBackendMessage::ErrorResponse( e.into(), ), true, - ).into())?; - continue 'outer; - } - }; - oids.push(col_type.oid()); - fields.push(FieldInfo::new( - col.name().to_string(), - None, - None, - col_type, - result_formats - .get(i) - .copied() - .unwrap_or(FieldFormat::Text), - )); - } + ) + .into(), + )?; + continue 'outer; + } + }; + back_tx.blocking_send( ( PgWireBackendMessage::RowDescription( @@ -1808,17 +1790,7 @@ impl Session { conn.prepare(&cmd.to_string())? }; - let mut fields = vec![]; - for col in prepped.columns() { - let col_type = name_to_type(col.decl_type().unwrap_or("text"))?; - fields.push(FieldInfo::new( - col.name().to_string(), - None, - None, - col_type, - FieldFormat::Text, - )); - } + let fields = field_types(&prepped, cmd, FieldFormats::All(FieldFormat::Text))?; if send_row_desc { back_tx @@ -1914,19 +1886,7 @@ impl Session { back_tx: &Sender, ) -> Result<(), QueryError> { // TODO: maybe we don't need to recompute this... - let mut fields = vec![]; - for (i, col) in prepped.columns().into_iter().enumerate() { - trace!("col decl_type: {:?}", col.decl_type()); - let col_type = name_to_type(col.decl_type().unwrap_or("any"))?; - - fields.push(FieldInfo::new( - col.name().to_string(), - None, - None, - col_type, - result_formats.get(i).copied().unwrap_or(FieldFormat::Text), - )); - } + let fields = field_types(prepped, cmd, FieldFormats::Each(result_formats))?; trace!("fields: {fields:?}"); @@ -3048,6 +3008,141 @@ fn parameter_types<'schema, 'stmt>( params } +enum FieldFormats<'a> { + All(FieldFormat), + Each(&'a [FieldFormat]), +} + +impl<'a> FieldFormats<'a> { + fn get(&self, i: usize) -> FieldFormat { + match self { + FieldFormats::All(format) => *format, + FieldFormats::Each(formats) => formats.get(i).copied().unwrap_or(FieldFormat::Text), + } + } +} + +fn field_types( + prepped: &Statement, + parsed_cmd: &ParsedCmd, + field_formats: FieldFormats<'_>, +) -> Result, UnsupportedSqliteToPostgresType> { + let mut field_type_overrides = HashMap::new(); + + match parsed_cmd { + ParsedCmd::Sqlite(Cmd::Stmt(stmt)) => match stmt { + Stmt::Select(Select { + body: + SelectBody { + select: OneSelect::Select { columns: cols, .. }, + .. + }, + .. + }) + | Stmt::Delete { + returning: Some(cols), + .. + } + | Stmt::Insert { + returning: Some(cols), + .. + } + | Stmt::Update { + returning: Some(cols), + .. + } => { + for (i, col) in cols.iter().enumerate() { + if let ResultColumn::Expr(expr, _as) = col { + let type_override = match expr { + Expr::Cast { type_name, .. } => Some(name_to_type(&type_name.name)?), + Expr::FunctionCall { name, .. } + | Expr::FunctionCallStar { name, .. } => { + match name.0.as_str().to_uppercase().as_ref() { + "COUNT" => Some(Type::INT8), + _ => None, + } + } + Expr::Literal(lit) => match lit { + Literal::Numeric(s) => Some(if s.contains('.') { + Type::FLOAT8 + } else { + Type::INT8 + }), + Literal::String(_) => Some(Type::TEXT), + Literal::Blob(_) => Some(Type::BYTEA), + Literal::Keyword(_) => None, + Literal::Null => None, + Literal::CurrentDate => Some(Type::DATE), + Literal::CurrentTime => Some(Type::TIME), + Literal::CurrentTimestamp => Some(Type::TIMESTAMP), + }, + _ => None, + }; + if let Some(type_override) = type_override { + match prepped.column_name(i) { + Ok(col_name) => { + field_type_overrides.insert(col_name, type_override); + } + Err(e) => { + error!("col index didn't exist at {i}, attempted to override type as: {type_override}: {e}"); + } + } + } + } else { + break; + } + } + } + _ => {} + }, + ParsedCmd::Postgres(_stmt) => { + // TODO: handle type overrides here too + // let cols = match stmt { + // PgStatement::Insert { returning, .. } + // | PgStatement::Update { returning, .. } + // | PgStatement::Delete { returning, .. } => { + // returning + // } + // PgStatement::Query(query) => { + // match *query.body { + // sqlparser::ast::SetExpr::Select( + // select, + // ) => Some(select.projection), + // _ => None, + // } + // } + // _ => None, + // }; + + // if let Some(cols) = cols { + + // } + } + _ => {} + } + + let mut fields = vec![]; + for (i, col) in prepped.columns().iter().enumerate() { + let col_name = col.name(); + let col_type = match field_type_overrides.remove(col_name) { + Some(t) => t, + None => match col.decl_type() { + None => Type::TEXT, + Some(decl_type) => name_to_type(decl_type)?, + }, + }; + fields.push(FieldInfo::new( + col_name.to_string(), + None, + None, + col_type, + field_formats.get(i), + )); + } + + Ok(fields) +} + #[cfg(test)] mod tests { use std::time::{Duration, Instant}; @@ -3253,6 +3348,14 @@ mod tests { println!("updated_at: {updated_at:?}"); assert_eq!(future, updated_at); + + let row = client + .query_one( + "SELECT COUNT(*) AS yep, COUNT(id) yeppers FROM kitchensink", + &[], + ) + .await?; + println!("COUNT ROW: {row:?}"); } tripwire_tx.send(()).await.ok();