Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep track of column typing in SQLite EXPLAIN parsing #1323

Merged
merged 3 commits into from
Jul 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 97 additions & 19 deletions sqlx-core/src/sqlite/connection/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ const SQLITE_AFF_REAL: u8 = 0x45; /* 'E' */
const OP_INIT: &str = "Init";
const OP_GOTO: &str = "Goto";
const OP_COLUMN: &str = "Column";
const OP_MAKE_RECORD: &str = "MakeRecord";
const OP_INSERT: &str = "Insert";
const OP_IDX_INSERT: &str = "IdxInsert";
const OP_OPEN_READ: &str = "OpenRead";
const OP_OPEN_WRITE: &str = "OpenWrite";
const OP_OPEN_EPHEMERAL: &str = "OpenEphemeral";
const OP_OPEN_AUTOINDEX: &str = "OpenAutoindex";
const OP_AGG_STEP: &str = "AggStep";
const OP_FUNCTION: &str = "Function";
const OP_MOVE: &str = "Move";
Expand All @@ -34,6 +41,7 @@ const OP_BLOB: &str = "Blob";
const OP_VARIABLE: &str = "Variable";
const OP_COUNT: &str = "Count";
const OP_ROWID: &str = "Rowid";
const OP_NEWROWID: &str = "NewRowid";
const OP_OR: &str = "Or";
const OP_AND: &str = "And";
const OP_BIT_AND: &str = "BitAnd";
Expand All @@ -48,6 +56,21 @@ const OP_REMAINDER: &str = "Remainder";
const OP_CONCAT: &str = "Concat";
const OP_RESULT_ROW: &str = "ResultRow";

#[derive(Debug, Clone, Eq, PartialEq)]
enum RegDataType {
Single(DataType),
Record(Vec<DataType>),
}

impl RegDataType {
fn map_to_datatype(self) -> DataType {
match self {
RegDataType::Single(d) => d,
RegDataType::Record(_) => DataType::Null, //If we're trying to coerce to a regular Datatype, we can assume a Record is invalid for the context
}
}
}

#[allow(clippy::wildcard_in_or_patterns)]
fn affinity_to_type(affinity: u8) -> DataType {
match affinity {
Expand All @@ -73,13 +96,19 @@ fn opcode_to_type(op: &str) -> DataType {
}
}

// Opcode Reference: https://sqlite.org/opcode.html
pub(super) async fn explain(
conn: &mut SqliteConnection,
query: &str,
) -> Result<(Vec<SqliteTypeInfo>, Vec<Option<bool>>), Error> {
let mut r = HashMap::<i64, DataType>::with_capacity(6);
// Registers
let mut r = HashMap::<i64, RegDataType>::with_capacity(6);
// Map between pointer and register
let mut r_cursor = HashMap::<i64, Vec<i64>>::with_capacity(6);
// Rows that pointers point to
let mut p = HashMap::<i64, HashMap<i64, DataType>>::with_capacity(6);

// Nullable columns
let mut n = HashMap::<i64, bool>::with_capacity(6);

let program =
Expand Down Expand Up @@ -119,15 +148,52 @@ pub(super) async fn explain(
}

OP_COLUMN => {
r_cursor.entry(p1).or_default().push(p3);
//Get the row stored at p1, or NULL; get the column stored at p2, or NULL
if let Some(record) = p.get(&p1) {
if let Some(col) = record.get(&p2) {
// insert into p3 the datatype of the col
r.insert(p3, RegDataType::Single(*col));
// map between pointer p1 and register p3
r_cursor.entry(p1).or_default().push(p3);
} else {
r.insert(p3, RegDataType::Single(DataType::Null));
}
} else {
r.insert(p3, RegDataType::Single(DataType::Null));
}
}

OP_MAKE_RECORD => {
// p3 = Record([p1 .. p1 + p2])
let mut record = Vec::with_capacity(p2 as usize);
for reg in p1..p1 + p2 {
record.push(
r.get(&reg)
.map(|d| d.clone().map_to_datatype())
.unwrap_or(DataType::Null),
);
}
r.insert(p3, RegDataType::Record(record));
}

OP_INSERT | OP_IDX_INSERT => {
if let Some(RegDataType::Record(record)) = r.get(&p2) {
if let Some(row) = p.get_mut(&p1) {
// Insert the record into wherever pointer p1 is
*row = (0..).zip(record.iter().copied()).collect();
}
}
//Noop if the register p2 isn't a record, or if pointer p1 does not exist
}

// r[p3] = <value of column>
r.insert(p3, DataType::Null);
OP_OPEN_READ | OP_OPEN_WRITE | OP_OPEN_EPHEMERAL | OP_OPEN_AUTOINDEX => {
//Create a new pointer which is referenced by p1
p.insert(p1, HashMap::with_capacity(6));
}

OP_VARIABLE => {
// r[p2] = <value of variable>
r.insert(p2, DataType::Null);
r.insert(p2, RegDataType::Single(DataType::Null));
n.insert(p3, true);
}

Expand All @@ -136,7 +202,7 @@ pub(super) async fn explain(
match from_utf8(p4).map_err(Error::protocol)? {
"last_insert_rowid(0)" => {
// last_insert_rowid() -> INTEGER
r.insert(p3, DataType::Int64);
r.insert(p3, RegDataType::Single(DataType::Int64));
n.insert(p3, n.get(&p3).copied().unwrap_or(false));
}

Expand All @@ -145,9 +211,9 @@ pub(super) async fn explain(
}

OP_NULL_ROW => {
// all values of cursor X are potentially nullable
for column in &r_cursor[&p1] {
n.insert(*column, true);
// all registers that map to cursor X are potentially nullable
for register in &r_cursor[&p1] {
n.insert(*register, true);
}
}

Expand All @@ -156,9 +222,9 @@ pub(super) async fn explain(

if p4.starts_with("count(") {
// count(_) -> INTEGER
r.insert(p3, DataType::Int64);
r.insert(p3, RegDataType::Single(DataType::Int64));
n.insert(p3, n.get(&p3).copied().unwrap_or(false));
} else if let Some(v) = r.get(&p2).copied() {
} else if let Some(v) = r.get(&p2).cloned() {
// r[p3] = AGG ( r[p2] )
r.insert(p3, v);
let val = n.get(&p2).copied().unwrap_or(true);
Expand All @@ -169,13 +235,13 @@ pub(super) async fn explain(
OP_CAST => {
// affinity(r[p1])
if let Some(v) = r.get_mut(&p1) {
*v = affinity_to_type(p2 as u8);
*v = RegDataType::Single(affinity_to_type(p2 as u8));
}
}

OP_COPY | OP_MOVE | OP_SCOPY | OP_INT_COPY => {
// r[p2] = r[p1]
if let Some(v) = r.get(&p1).copied() {
if let Some(v) = r.get(&p1).cloned() {
r.insert(p2, v);

if let Some(null) = n.get(&p1).copied() {
Expand All @@ -184,15 +250,16 @@ pub(super) async fn explain(
}
}

OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID => {
OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID
| OP_NEWROWID => {
// r[p2] = <value of constant>
r.insert(p2, opcode_to_type(&opcode));
r.insert(p2, RegDataType::Single(opcode_to_type(&opcode)));
n.insert(p2, n.get(&p2).copied().unwrap_or(false));
}

OP_NOT => {
// r[p2] = NOT r[p1]
if let Some(a) = r.get(&p1).copied() {
if let Some(a) = r.get(&p1).cloned() {
r.insert(p2, a);
let val = n.get(&p1).copied().unwrap_or(true);
n.insert(p2, val);
Expand All @@ -202,9 +269,16 @@ pub(super) async fn explain(
OP_BIT_AND | OP_BIT_OR | OP_SHIFT_LEFT | OP_SHIFT_RIGHT | OP_ADD | OP_SUBTRACT
| OP_MULTIPLY | OP_DIVIDE | OP_REMAINDER | OP_CONCAT => {
// r[p3] = r[p1] + r[p2]
match (r.get(&p1).copied(), r.get(&p2).copied()) {
match (r.get(&p1).cloned(), r.get(&p2).cloned()) {
(Some(a), Some(b)) => {
r.insert(p3, if matches!(a, DataType::Null) { b } else { a });
r.insert(
p3,
if matches!(a, RegDataType::Single(DataType::Null)) {
b
} else {
a
},
);
}

(Some(v), None) => {
Expand Down Expand Up @@ -252,7 +326,11 @@ pub(super) async fn explain(

if let Some(result) = result {
for i in result {
output.push(SqliteTypeInfo(r.remove(&i).unwrap_or(DataType::Null)));
output.push(SqliteTypeInfo(
r.remove(&i)
.map(|d| d.map_to_datatype())
.unwrap_or(DataType::Null),
));
nullable.push(n.remove(&i));
}
}
Expand Down
15 changes: 15 additions & 0 deletions tests/sqlite/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,21 @@ async fn it_describes_insert_with_read_only() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn it_describes_insert_with_returning() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;

let d = conn
.describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello') RETURNING *")
.await?;

assert_eq!(d.columns().len(), 4);
assert_eq!(d.column(0).type_info().name(), "INTEGER");
assert_eq!(d.column(1).type_info().name(), "TEXT");
abonander marked this conversation as resolved.
Show resolved Hide resolved

Ok(())
}

#[sqlx_macros::test]
async fn it_describes_bad_statement() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
Expand Down