Skip to content

Commit

Permalink
Merge pull request #253 from superfly/pg-override-types
Browse files Browse the repository at this point in the history
Override some PG field types from a parsed query
  • Loading branch information
somtochiama committed Aug 26, 2024
2 parents 8617910 + 03daefd commit ddc4184
Showing 1 changed file with 176 additions and 73 deletions.
249 changes: 176 additions & 73 deletions crates/corro-pg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ use rusqlite::{
};
use spawn::spawn_counted;
use sqlite3_parser::ast::{
As, Cmd, ColumnDefinition, CreateTableBody, Expr, FromClause, Id, InsertBody, Limit, Name,
OneSelect, ResultColumn, Select, SelectTable, Stmt, With,
As, Cmd, ColumnDefinition, CreateTableBody, Expr, FromClause, Id, InsertBody, Limit, Literal,
Name, OneSelect, ResultColumn, Select, SelectBody, SelectTable, Stmt, With,
};
use sqlparser::ast::Statement as PgStatement;
use tokio::{
Expand Down Expand Up @@ -848,27 +848,18 @@ pub async fn start(
.collect();
}

let mut fields = vec![];
for col in prepped.columns() {
let col_type = match name_to_type(
col.decl_type().unwrap_or("text"),
) {
Ok(t) => t,
Err(e) => {
back_tx
.blocking_send((e.into(), true).into())?;
discard_until_sync = true;
continue 'outer;
}
};
fields.push(FieldInfo::new(
col.name().to_string(),
None,
None,
col_type,
FieldFormat::Text,
));
}
let fields = match field_types(
&prepped,
&parsed_cmd,
FieldFormats::All(FieldFormat::Text),
) {
Ok(fields) => fields,
Err(e) => {
back_tx.blocking_send((e.into(), true).into())?;
discard_until_sync = true;
continue 'outer;
}
};

prepared.insert(
name.into(),
Expand Down Expand Up @@ -985,38 +976,29 @@ pub async fn start(
Some(Portal::Parsed {
stmt,
result_formats,
cmd,
..
}) => {
let mut oids = vec![];
let mut fields = vec![];
for (i, col) in stmt.columns().into_iter().enumerate() {
let col_type =
match name_to_type(
col.decl_type().unwrap_or("text"),
) {
Ok(t) => t,
Err(e) => {
back_tx.blocking_send((
let fields = match field_types(
stmt,
cmd,
FieldFormats::Each(&result_formats),
) {
Ok(fields) => fields,
Err(e) => {
back_tx.blocking_send(
(
PgWireBackendMessage::ErrorResponse(
e.into(),
),
true,
).into())?;
continue 'outer;
}
};
oids.push(col_type.oid());
fields.push(FieldInfo::new(
col.name().to_string(),
None,
None,
col_type,
result_formats
.get(i)
.copied()
.unwrap_or(FieldFormat::Text),
));
}
)
.into(),
)?;
continue 'outer;
}
};

back_tx.blocking_send(
(
PgWireBackendMessage::RowDescription(
Expand Down Expand Up @@ -1808,17 +1790,7 @@ impl Session {
conn.prepare(&cmd.to_string())?
};

let mut fields = vec![];
for col in prepped.columns() {
let col_type = name_to_type(col.decl_type().unwrap_or("text"))?;
fields.push(FieldInfo::new(
col.name().to_string(),
None,
None,
col_type,
FieldFormat::Text,
));
}
let fields = field_types(&prepped, cmd, FieldFormats::All(FieldFormat::Text))?;

if send_row_desc {
back_tx
Expand Down Expand Up @@ -1914,19 +1886,7 @@ impl Session {
back_tx: &Sender<BackendResponse>,
) -> Result<(), QueryError> {
// TODO: maybe we don't need to recompute this...
let mut fields = vec![];
for (i, col) in prepped.columns().into_iter().enumerate() {
trace!("col decl_type: {:?}", col.decl_type());
let col_type = name_to_type(col.decl_type().unwrap_or("any"))?;

fields.push(FieldInfo::new(
col.name().to_string(),
None,
None,
col_type,
result_formats.get(i).copied().unwrap_or(FieldFormat::Text),
));
}
let fields = field_types(prepped, cmd, FieldFormats::Each(result_formats))?;

trace!("fields: {fields:?}");

Expand Down Expand Up @@ -3048,6 +3008,141 @@ fn parameter_types<'schema, 'stmt>(
params
}

enum FieldFormats<'a> {
All(FieldFormat),
Each(&'a [FieldFormat]),
}

impl<'a> FieldFormats<'a> {
fn get(&self, i: usize) -> FieldFormat {
match self {
FieldFormats::All(format) => *format,
FieldFormats::Each(formats) => formats.get(i).copied().unwrap_or(FieldFormat::Text),
}
}
}

fn field_types(
prepped: &Statement,
parsed_cmd: &ParsedCmd,
field_formats: FieldFormats<'_>,
) -> Result<Vec<FieldInfo>, UnsupportedSqliteToPostgresType> {
let mut field_type_overrides = HashMap::new();

match parsed_cmd {
ParsedCmd::Sqlite(Cmd::Stmt(stmt)) => match stmt {
Stmt::Select(Select {
body:
SelectBody {
select: OneSelect::Select { columns: cols, .. },
..
},
..
})
| Stmt::Delete {
returning: Some(cols),
..
}
| Stmt::Insert {
returning: Some(cols),
..
}
| Stmt::Update {
returning: Some(cols),
..
} => {
for (i, col) in cols.iter().enumerate() {
if let ResultColumn::Expr(expr, _as) = col {
let type_override = match expr {
Expr::Cast { type_name, .. } => Some(name_to_type(&type_name.name)?),
Expr::FunctionCall { name, .. }
| Expr::FunctionCallStar { name, .. } => {
match name.0.as_str().to_uppercase().as_ref() {
"COUNT" => Some(Type::INT8),
_ => None,
}
}
Expr::Literal(lit) => match lit {
Literal::Numeric(s) => Some(if s.contains('.') {
Type::FLOAT8
} else {
Type::INT8
}),
Literal::String(_) => Some(Type::TEXT),
Literal::Blob(_) => Some(Type::BYTEA),
Literal::Keyword(_) => None,
Literal::Null => None,
Literal::CurrentDate => Some(Type::DATE),
Literal::CurrentTime => Some(Type::TIME),
Literal::CurrentTimestamp => Some(Type::TIMESTAMP),
},
_ => None,
};
if let Some(type_override) = type_override {
match prepped.column_name(i) {
Ok(col_name) => {
field_type_overrides.insert(col_name, type_override);
}
Err(e) => {
error!("col index didn't exist at {i}, attempted to override type as: {type_override}: {e}");
}
}
}
} else {
break;
}
}
}
_ => {}
},
ParsedCmd::Postgres(_stmt) => {
// TODO: handle type overrides here too
// let cols = match stmt {
// PgStatement::Insert { returning, .. }
// | PgStatement::Update { returning, .. }
// | PgStatement::Delete { returning, .. } => {
// returning
// }
// PgStatement::Query(query) => {
// match *query.body {
// sqlparser::ast::SetExpr::Select(
// select,
// ) => Some(select.projection),
// _ => None,
// }
// }
// _ => None,
// };

// if let Some(cols) = cols {

// }
}
_ => {}
}

let mut fields = vec![];
for (i, col) in prepped.columns().iter().enumerate() {
let col_name = col.name();
let col_type = match field_type_overrides.remove(col_name) {
Some(t) => t,
None => match col.decl_type() {
None => Type::TEXT,
Some(decl_type) => name_to_type(decl_type)?,
},
};
fields.push(FieldInfo::new(
col_name.to_string(),
None,
None,
col_type,
field_formats.get(i),
));
}

Ok(fields)
}

#[cfg(test)]
mod tests {
use std::time::{Duration, Instant};
Expand Down Expand Up @@ -3253,6 +3348,14 @@ mod tests {
println!("updated_at: {updated_at:?}");

assert_eq!(future, updated_at);

let row = client
.query_one(
"SELECT COUNT(*) AS yep, COUNT(id) yeppers FROM kitchensink",
&[],
)
.await?;
println!("COUNT ROW: {row:?}");
}

tripwire_tx.send(()).await.ok();
Expand Down

0 comments on commit ddc4184

Please sign in to comment.