diff --git a/python/tests/test_extra_types.py b/python/tests/test_extra_types.py index 28ed073..c01e305 100644 --- a/python/tests/test_extra_types.py +++ b/python/tests/test_extra_types.py @@ -69,3 +69,22 @@ async def test_unset(scylla: Scylla) -> None: f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [1, extra_types.Unset()], ) + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ["typ", "val"], + [ + ("BIGINT", 1), + ("TINYINT", 1), + ("SMALLINT", 1), + ("INT", 1), + ("FLOAT", 1.0), + ("DOUBLE", 1.0), + ], +) +async def test_autocast_positional(scylla: Scylla, typ: str, val: Any) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name}(id INT PRIMARY KEY, val {typ})") + prepared = await scylla.prepare(f"INSERT INTO {table_name}(id, val) VALUES (?, ?)") + await scylla.execute(prepared, [1, val]) diff --git a/src/batches.rs b/src/batches.rs index 7fdc42d..76e7b6e 100644 --- a/src/batches.rs +++ b/src/batches.rs @@ -121,7 +121,7 @@ impl ScyllaPyInlineBatch { self.inner.append_statement(query); if let Some(passed_params) = values { self.values - .push(parse_python_query_params(Some(passed_params), false)?); + .push(parse_python_query_params(Some(passed_params), false, None)?); } else { self.values.push(SerializedValues::new()); } diff --git a/src/prepared_queries.rs b/src/prepared_queries.rs index 64623b1..65ee4f9 100644 --- a/src/prepared_queries.rs +++ b/src/prepared_queries.rs @@ -4,7 +4,7 @@ use scylla::prepared_statement::PreparedStatement; #[pyclass(name = "PreparedQuery")] #[derive(Clone, Debug)] pub struct ScyllaPyPreparedQuery { - inner: PreparedStatement, + pub inner: PreparedStatement, } impl From for ScyllaPyPreparedQuery { diff --git a/src/query_builder/delete.rs b/src/query_builder/delete.rs index 14df316..79a4c43 100644 --- a/src/query_builder/delete.rs +++ b/src/query_builder/delete.rs @@ -110,7 +110,7 @@ impl Delete { slf.where_clauses_.push(clause); if let Some(vals) = values { for value in vals { - slf.values_.push(py_to_value(value)?); + slf.values_.push(py_to_value(value, None)?); } } Ok(slf) @@ -148,7 +148,7 @@ impl Delete { ) -> ScyllaPyResult> { let parsed_values = if let Some(vals) = values { vals.iter() - .map(|item| py_to_value(item)) + .map(|item| py_to_value(item, None)) .collect::, _>>()? } else { vec![] diff --git a/src/query_builder/insert.rs b/src/query_builder/insert.rs index fb00826..ed916ac 100644 --- a/src/query_builder/insert.rs +++ b/src/query_builder/insert.rs @@ -114,7 +114,7 @@ impl Insert { if value.is_none() { slf.values_.push(ScyllaPyCQLDTO::Unset); } else { - slf.values_.push(py_to_value(value)?); + slf.values_.push(py_to_value(value, None)?); } Ok(slf) } diff --git a/src/query_builder/select.rs b/src/query_builder/select.rs index 875544b..f1683d2 100644 --- a/src/query_builder/select.rs +++ b/src/query_builder/select.rs @@ -153,7 +153,7 @@ impl Select { slf.where_clauses_.push(clause); if let Some(vals) = values { for value in vals { - slf.values_.push(py_to_value(value)?); + slf.values_.push(py_to_value(value, None)?); } } Ok(slf) diff --git a/src/query_builder/update.rs b/src/query_builder/update.rs index 43d678a..0a9120f 100644 --- a/src/query_builder/update.rs +++ b/src/query_builder/update.rs @@ -130,7 +130,7 @@ impl Update { value: &'a PyAny, ) -> ScyllaPyResult> { slf.assignments_.push(UpdateAssignment::Simple(name)); - slf.values_.push(py_to_value(value)?); + slf.values_.push(py_to_value(value, None)?); Ok(slf) } @@ -147,7 +147,7 @@ impl Update { ) -> ScyllaPyResult> { slf.assignments_ .push(UpdateAssignment::Inc(name.clone(), name)); - slf.values_.push(py_to_value(value)?); + slf.values_.push(py_to_value(value, None)?); Ok(slf) } @@ -164,7 +164,7 @@ impl Update { ) -> ScyllaPyResult> { slf.assignments_ .push(UpdateAssignment::Dec(name.clone(), name)); - slf.values_.push(py_to_value(value)?); + slf.values_.push(py_to_value(value, None)?); Ok(slf) } /// Add where clause. @@ -187,7 +187,7 @@ impl Update { slf.where_clauses_.push(clause); if let Some(vals) = values { for value in vals { - slf.where_values_.push(py_to_value(value)?); + slf.where_values_.push(py_to_value(value, None)?); } } Ok(slf) @@ -248,7 +248,7 @@ impl Update { ) -> ScyllaPyResult> { let parsed_values = if let Some(vals) = values { vals.iter() - .map(|item| py_to_value(item)) + .map(|item| py_to_value(item, None)) .collect::, _>>()? } else { vec![] diff --git a/src/scylla_cls.rs b/src/scylla_cls.rs index 320d23d..11352e7 100644 --- a/src/scylla_cls.rs +++ b/src/scylla_cls.rs @@ -290,9 +290,13 @@ impl Scylla { params: Option<&'a PyAny>, paged: bool, ) -> ScyllaPyResult<&'a PyAny> { + let mut col_spec = None; // We need to prepare parameter we're going to use // in query. - let query_params = parse_python_query_params(params, true)?; + if let ExecuteInput::PreparedQuery(prepared) = &query { + col_spec = Some(prepared.inner.get_prepared_metadata().col_specs.as_ref()); + } + let query_params = parse_python_query_params(params, true, col_spec)?; // We need this clone, to safely share the session between threads. let (query, prepared) = match query { ExecuteInput::Text(txt) => (Some(Query::new(txt)), None), @@ -322,7 +326,11 @@ impl Scylla { let mut batch_params = Vec::new(); if let Some(passed_params) = params { for query_params in passed_params { - batch_params.push(parse_python_query_params(Some(query_params), false)?); + batch_params.push(parse_python_query_params( + Some(query_params), + false, + None, + )?); } } (batch.into(), batch_params) diff --git a/src/utils.rs b/src/utils.rs index 5b4a960..3f00eea 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -9,6 +9,7 @@ use scylla::frame::{ response::result::{ColumnType, CqlValue}, value::{SerializedValues, Value}, }; +use scylla_cql::frame::response::result::ColumnSpec; use std::net::IpAddr; @@ -140,7 +141,11 @@ impl Value for ScyllaPyCQLDTO { /// /// May raise an error, if /// value cannot be converted or unnown type was passed. -pub fn py_to_value(item: &PyAny) -> ScyllaPyResult { +#[allow(clippy::too_many_lines)] +pub fn py_to_value( + item: &PyAny, + column_type: Option<&ColumnType>, +) -> ScyllaPyResult { if item.is_none() { Ok(ScyllaPyCQLDTO::Null) } else if item.is_instance_of::() { @@ -150,9 +155,20 @@ pub fn py_to_value(item: &PyAny) -> ScyllaPyResult { } else if item.is_instance_of::() { Ok(ScyllaPyCQLDTO::Bool(item.extract::()?)) } else if item.is_instance_of::() { - Ok(ScyllaPyCQLDTO::Int(item.extract::()?)) + match column_type { + Some(ColumnType::TinyInt) => Ok(ScyllaPyCQLDTO::TinyInt(item.extract::()?)), + Some(ColumnType::SmallInt) => Ok(ScyllaPyCQLDTO::SmallInt(item.extract::()?)), + Some(ColumnType::BigInt) => Ok(ScyllaPyCQLDTO::BigInt(item.extract::()?)), + Some(ColumnType::Counter) => Ok(ScyllaPyCQLDTO::Counter(item.extract::()?)), + Some(_) | None => Ok(ScyllaPyCQLDTO::Int(item.extract::()?)), + } } else if item.is_instance_of::() { - Ok(ScyllaPyCQLDTO::Float(eq_float::F32(item.extract::()?))) + match column_type { + Some(ColumnType::Double) => Ok(ScyllaPyCQLDTO::Double(eq_float::F64( + item.extract::()?, + ))), + Some(_) | None => Ok(ScyllaPyCQLDTO::Float(eq_float::F32(item.extract::()?))), + } } else if item.is_instance_of::() { Ok(ScyllaPyCQLDTO::SmallInt( item.extract::()?.get_value(), @@ -209,7 +225,7 @@ pub fn py_to_value(item: &PyAny) -> ScyllaPyResult { { let mut items = Vec::new(); for inner in item.iter()? { - items.push(py_to_value(inner?)?); + items.push(py_to_value(inner?, column_type)?); } Ok(ScyllaPyCQLDTO::List(items)) } else if item.is_instance_of::() { @@ -222,8 +238,8 @@ pub fn py_to_value(item: &PyAny) -> ScyllaPyResult { ScyllaPyError::BindingError(format!("Cannot cast to tuple: {err}")) })?; items.push(( - py_to_value(item_tuple.get_item(0)?)?, - py_to_value(item_tuple.get_item(1)?)?, + py_to_value(item_tuple.get_item(0)?, column_type)?, + py_to_value(item_tuple.get_item(1)?, column_type)?, )); } Ok(ScyllaPyCQLDTO::Map(items)) @@ -484,6 +500,7 @@ pub fn cql_to_py<'a>( pub fn parse_python_query_params( params: Option<&PyAny>, allow_dicts: bool, + col_spec: Option<&[ColumnSpec]>, ) -> ScyllaPyResult { let mut values = SerializedValues::new(); @@ -495,17 +512,30 @@ pub fn parse_python_query_params( // Otherwise it parses dict to named parameters. if params.is_instance_of::() || params.is_instance_of::() { let params = params.extract::>()?; - for param in params { - let py_dto = py_to_value(param)?; + for (index, param) in params.iter().enumerate() { + let coltype = col_spec.and_then(|specs| specs.get(index)).map(|f| &f.typ); + let py_dto = py_to_value(param, coltype)?; values.add_value(&py_dto)?; } return Ok(values); } else if params.is_instance_of::() { if allow_dicts { + let types_map = col_spec + .map(|specs| { + specs + .iter() + .map(|spec| (spec.name.as_str(), spec.typ.clone())) + .collect::>>() + }) + .unwrap_or_default(); + // let map = HashMap::with_capacity_and_hasher(, hasher) let dict = params .extract::>>()?; for (name, value) in dict { - values.add_named_value(name.to_lowercase().as_str(), &py_to_value(value)?)?; + values.add_named_value( + name.to_lowercase().as_str(), + &py_to_value(value, types_map.get(name))?, + )?; } return Ok(values); }