Skip to content

Commit

Permalink
Added automatic type cast support.
Browse files Browse the repository at this point in the history
Signed-off-by: Pavel Kirilin <win10@list.ru>
  • Loading branch information
s3rius committed Nov 3, 2023
1 parent c34dd8a commit 878f089
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 25 deletions.
19 changes: 19 additions & 0 deletions python/tests/test_extra_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,22 @@ class TestUDT(extra_types.ScyllaPyUDT):

res = await scylla.execute(f"SELECT * FROM {table_name}")
assert res.all() == [{"id": 1, "udt_col": asdict(udt_val)}]


@pytest.mark.parametrize(
["typ", "val"],
[
("BIGINT", 1),
("TINYINT", 1),
("SMALLINT", 1),
("INT", 1),
("FLOAT", 1.0),
("DOUBLE", 1.0),
],
)
@pytest.mark.anyio
async def test_autocast_positional(scylla: Scylla, typ: str, val: Any) -> None:
table_name = random_string(4)
await scylla.execute(f"CREATE TABLE {table_name}(id INT PRIMARY KEY, val {typ})")
prepared = await scylla.prepare(f"INSERT INTO {table_name}(id, val) VALUES (?, ?)")
await scylla.execute(prepared, [1, val])
2 changes: 1 addition & 1 deletion src/batches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl ScyllaPyInlineBatch {
self.inner.append_statement(query);
if let Some(passed_params) = values {
self.values
.push(parse_python_query_params(Some(passed_params), false)?);
.push(parse_python_query_params(Some(passed_params), false, None)?);
} else {
self.values.push(SerializedValues::new());
}
Expand Down
2 changes: 1 addition & 1 deletion src/prepared_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use scylla::prepared_statement::PreparedStatement;
#[pyclass(name = "PreparedQuery")]
#[derive(Clone, Debug)]
pub struct ScyllaPyPreparedQuery {
inner: PreparedStatement,
pub inner: PreparedStatement,
}

impl From<PreparedStatement> for ScyllaPyPreparedQuery {
Expand Down
4 changes: 2 additions & 2 deletions src/query_builder/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl Delete {
slf.where_clauses_.push(clause);
if let Some(vals) = values {
for value in vals {
slf.values_.push(py_to_value(value)?);
slf.values_.push(py_to_value(value, None)?);
}
}
Ok(slf)
Expand Down Expand Up @@ -148,7 +148,7 @@ impl Delete {
) -> ScyllaPyResult<PyRefMut<'a, Self>> {
let parsed_values = if let Some(vals) = values {
vals.iter()
.map(|item| py_to_value(item))
.map(|item| py_to_value(item, None))
.collect::<Result<Vec<_>, _>>()?
} else {
vec![]
Expand Down
2 changes: 1 addition & 1 deletion src/query_builder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl Insert {
if value.is_none() {
slf.values_.push(ScyllaPyCQLDTO::Unset);
} else {
slf.values_.push(py_to_value(value)?);
slf.values_.push(py_to_value(value, None)?);
}
Ok(slf)
}
Expand Down
2 changes: 1 addition & 1 deletion src/query_builder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ impl Select {
slf.where_clauses_.push(clause);
if let Some(vals) = values {
for value in vals {
slf.values_.push(py_to_value(value)?);
slf.values_.push(py_to_value(value, None)?);
}
}
Ok(slf)
Expand Down
10 changes: 5 additions & 5 deletions src/query_builder/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl Update {
value: &'a PyAny,
) -> ScyllaPyResult<PyRefMut<'a, Self>> {
slf.assignments_.push(UpdateAssignment::Simple(name));
slf.values_.push(py_to_value(value)?);
slf.values_.push(py_to_value(value, None)?);
Ok(slf)
}

Expand All @@ -147,7 +147,7 @@ impl Update {
) -> ScyllaPyResult<PyRefMut<'a, Self>> {
slf.assignments_
.push(UpdateAssignment::Inc(name.clone(), name));
slf.values_.push(py_to_value(value)?);
slf.values_.push(py_to_value(value, None)?);
Ok(slf)
}

Expand All @@ -164,7 +164,7 @@ impl Update {
) -> ScyllaPyResult<PyRefMut<'a, Self>> {
slf.assignments_
.push(UpdateAssignment::Dec(name.clone(), name));
slf.values_.push(py_to_value(value)?);
slf.values_.push(py_to_value(value, None)?);
Ok(slf)
}
/// Add where clause.
Expand All @@ -187,7 +187,7 @@ impl Update {
slf.where_clauses_.push(clause);
if let Some(vals) = values {
for value in vals {
slf.where_values_.push(py_to_value(value)?);
slf.where_values_.push(py_to_value(value, None)?);
}
}
Ok(slf)
Expand Down Expand Up @@ -248,7 +248,7 @@ impl Update {
) -> ScyllaPyResult<PyRefMut<'a, Self>> {
let parsed_values = if let Some(vals) = values {
vals.iter()
.map(|item| py_to_value(item))
.map(|item| py_to_value(item, None))
.collect::<Result<Vec<_>, _>>()?
} else {
vec![]
Expand Down
12 changes: 10 additions & 2 deletions src/scylla_cls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,13 @@ impl Scylla {
params: Option<&'a PyAny>,
paged: bool,
) -> ScyllaPyResult<&'a PyAny> {
let mut col_spec = None;
// We need to prepare parameter we're going to use
// in query.
let query_params = parse_python_query_params(params, true)?;
if let ExecuteInput::PreparedQuery(prepared) = &query {
col_spec = Some(prepared.inner.get_prepared_metadata().col_specs.as_ref());
}
let query_params = parse_python_query_params(params, true, col_spec)?;
// We need this clone, to safely share the session between threads.
let (query, prepared) = match query {
ExecuteInput::Text(txt) => (Some(Query::new(txt)), None),
Expand Down Expand Up @@ -322,7 +326,11 @@ impl Scylla {
let mut batch_params = Vec::new();
if let Some(passed_params) = params {
for query_params in passed_params {
batch_params.push(parse_python_query_params(Some(query_params), false)?);
batch_params.push(parse_python_query_params(
Some(query_params),
false,
None,
)?);
}
}
(batch.into(), batch_params)
Expand Down
57 changes: 45 additions & 12 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use scylla::{
},
BufMut,
};
use scylla_cql::frame::response::result::ColumnSpec;

use std::net::IpAddr;

Expand Down Expand Up @@ -150,7 +151,10 @@ impl Value for ScyllaPyCQLDTO {
/// May raise an error, if
/// value cannot be converted or unnown type was passed.
#[allow(clippy::too_many_lines)]
pub fn py_to_value(item: &PyAny) -> ScyllaPyResult<ScyllaPyCQLDTO> {
pub fn py_to_value(
item: &PyAny,
column_type: Option<&ColumnType>,
) -> ScyllaPyResult<ScyllaPyCQLDTO> {
if item.is_none() {
Ok(ScyllaPyCQLDTO::Null)
} else if item.is_instance_of::<PyString>() {
Expand All @@ -160,9 +164,20 @@ pub fn py_to_value(item: &PyAny) -> ScyllaPyResult<ScyllaPyCQLDTO> {
} else if item.is_instance_of::<PyBool>() {
Ok(ScyllaPyCQLDTO::Bool(item.extract::<bool>()?))
} else if item.is_instance_of::<PyInt>() {
Ok(ScyllaPyCQLDTO::Int(item.extract::<i32>()?))
match column_type {
Some(ColumnType::TinyInt) => Ok(ScyllaPyCQLDTO::TinyInt(item.extract::<i8>()?)),
Some(ColumnType::SmallInt) => Ok(ScyllaPyCQLDTO::SmallInt(item.extract::<i16>()?)),
Some(ColumnType::BigInt) => Ok(ScyllaPyCQLDTO::BigInt(item.extract::<i64>()?)),
Some(ColumnType::Counter) => Ok(ScyllaPyCQLDTO::Counter(item.extract::<i64>()?)),
Some(_) | None => Ok(ScyllaPyCQLDTO::Int(item.extract::<i32>()?)),
}
} else if item.is_instance_of::<PyFloat>() {
Ok(ScyllaPyCQLDTO::Float(eq_float::F32(item.extract::<f32>()?)))
match column_type {
Some(ColumnType::Double) => Ok(ScyllaPyCQLDTO::Double(eq_float::F64(
item.extract::<f64>()?,
))),
Some(_) | None => Ok(ScyllaPyCQLDTO::Float(eq_float::F32(item.extract::<f32>()?))),
}
} else if item.is_instance_of::<SmallInt>() {
Ok(ScyllaPyCQLDTO::SmallInt(
item.extract::<SmallInt>()?.get_value(),
Expand Down Expand Up @@ -198,9 +213,13 @@ pub fn py_to_value(item: &PyAny) -> ScyllaPyResult<ScyllaPyCQLDTO> {
buf.put_i32(0);
for val in dumped_py {
// Here we serialize all fields.
py_to_value(val)?.serialize(buf.as_mut()).map_err(|err| {
ScyllaPyError::BindingError(format!("Cannot serialize UDT field because of {err}"))
})?;
py_to_value(val, None)?
.serialize(buf.as_mut())
.map_err(|err| {
ScyllaPyError::BindingError(format!(
"Cannot serialize UDT field because of {err}"
))
})?;
}
// Then we calculate the size of the UDT value, cast it to i32
// and put it in the beginning of the buffer.
Expand Down Expand Up @@ -245,7 +264,7 @@ pub fn py_to_value(item: &PyAny) -> ScyllaPyResult<ScyllaPyCQLDTO> {
{
let mut items = Vec::new();
for inner in item.iter()? {
items.push(py_to_value(inner?)?);
items.push(py_to_value(inner?, column_type)?);
}
Ok(ScyllaPyCQLDTO::List(items))
} else if item.is_instance_of::<PyDict>() {
Expand All @@ -258,8 +277,8 @@ pub fn py_to_value(item: &PyAny) -> ScyllaPyResult<ScyllaPyCQLDTO> {
ScyllaPyError::BindingError(format!("Cannot cast to tuple: {err}"))
})?;
items.push((
py_to_value(item_tuple.get_item(0)?)?,
py_to_value(item_tuple.get_item(1)?)?,
py_to_value(item_tuple.get_item(0)?, column_type)?,
py_to_value(item_tuple.get_item(1)?, column_type)?,
));
}
Ok(ScyllaPyCQLDTO::Map(items))
Expand Down Expand Up @@ -553,6 +572,7 @@ pub fn cql_to_py<'a>(
pub fn parse_python_query_params(
params: Option<&PyAny>,
allow_dicts: bool,
col_spec: Option<&[ColumnSpec]>,
) -> ScyllaPyResult<SerializedValues> {
let mut values = SerializedValues::new();

Expand All @@ -564,17 +584,30 @@ pub fn parse_python_query_params(
// Otherwise it parses dict to named parameters.
if params.is_instance_of::<PyList>() || params.is_instance_of::<PyTuple>() {
let params = params.extract::<Vec<&PyAny>>()?;
for param in params {
let py_dto = py_to_value(param)?;
for (index, param) in params.iter().enumerate() {
let coltype = col_spec.and_then(|specs| specs.get(index)).map(|f| &f.typ);
let py_dto = py_to_value(param, coltype)?;
values.add_value(&py_dto)?;
}
return Ok(values);
} else if params.is_instance_of::<PyDict>() {
if allow_dicts {
let types_map = col_spec
.map(|specs| {
specs
.iter()
.map(|spec| (spec.name.as_str(), spec.typ.clone()))
.collect::<HashMap<_, _, BuildHasherDefault<rustc_hash::FxHasher>>>()
})
.unwrap_or_default();
// let map = HashMap::with_capacity_and_hasher(, hasher)
let dict = params
.extract::<HashMap<&str, &PyAny, BuildHasherDefault<rustc_hash::FxHasher>>>()?;
for (name, value) in dict {
values.add_named_value(name.to_lowercase().as_str(), &py_to_value(value)?)?;
values.add_named_value(
name.to_lowercase().as_str(),
&py_to_value(value, types_map.get(name))?,
)?;
}
return Ok(values);
}
Expand Down

0 comments on commit 878f089

Please sign in to comment.