Skip to content

Commit

Permalink
feat(bindings/python): support query_iter (#280)
Browse files Browse the repository at this point in the history
  • Loading branch information
everpcpc authored Oct 27, 2023
1 parent db6b232 commit 8151d51
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 7 deletions.
1 change: 1 addition & 0 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ databend-driver = { workspace = true, features = ["rustls", "flight-sql"] }
pyo3 = { version = "0.19", features = ["abi3-py37"] }
pyo3-asyncio = { version = "0.19", features = ["tokio-runtime"] }
tokio = "1.28"
tokio-stream = "0.1"
5 changes: 5 additions & 0 deletions bindings/python/package/databend_driver/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@
class Row:
def values(self) -> tuple: ...

class RowIterator:
def __aiter__(self) -> RowIterator: ...
async def __anext__(self) -> Row: ...

# flake8: noqa
class AsyncDatabendConnection:
async def exec(self, sql: str) -> int: ...
async def query_row(self, sql: str) -> Row: ...
async def query_iter(self, sql: str) -> RowIterator: ...

# flake8: noqa
class AsyncDatabendClient:
Expand Down
37 changes: 36 additions & 1 deletion bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use pyo3::create_exception;
use pyo3::exceptions::PyException;
use pyo3::exceptions::{PyException, PyStopAsyncIteration};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyTuple};
use pyo3_asyncio::tokio::future_into_py;
use tokio::sync::Mutex;
use tokio_stream::StreamExt;

create_exception!(
databend_client,
Expand Down Expand Up @@ -74,6 +78,14 @@ impl AsyncDatabendConnection {
Ok(Row(row))
})
}

pub fn query_iter<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<&'p PyAny> {
let this = self.0.clone();
future_into_py(py, async move {
let streamer = this.query_iter(&sql).await.unwrap();
Ok(RowIterator(Arc::new(Mutex::new(streamer))))
})
}
}

#[pyclass(module = "databend_driver")]
Expand Down Expand Up @@ -169,3 +181,26 @@ impl IntoPy<PyObject> for NumberValue {
}
}
}

#[pyclass(module = "databend_driver")]
pub struct RowIterator(Arc<Mutex<databend_driver::RowIterator>>);

#[pymethods]
impl RowIterator {
fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
fn __anext__<'a>(&self, py: Python<'a>) -> PyResult<Option<PyObject>> {
let streamer = self.0.clone();
let future = future_into_py(py, async move {
match streamer.lock().await.next().await {
Some(val) => match val {
Err(e) => Err(PyException::new_err(format!("{}", e))),
Ok(ret) => Ok(Row(ret)),
},
None => Err(PyStopAsyncIteration::new_err("The iterator is exhausted")),
}
});
Ok(Some(future?.into()))
}
}
44 changes: 38 additions & 6 deletions bindings/python/tests/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,20 @@ async def _(context):
@when("Create a test table")
@async_run_until_complete
async def _(context):
# TODO:
pass
await context.conn.exec("DROP TABLE IF EXISTS test")
await context.conn.exec(
"""
CREATE TABLE test (
i64 Int64,
u64 UInt64,
f64 Float64,
s String,
s2 String,
d Date,
t DateTime
)
"""
)


@then("Select string {input} should be equal to {output}")
Expand All @@ -47,15 +59,35 @@ async def _(context, input, output):
@then("Select numbers should iterate all rows")
@async_run_until_complete
async def _(context):
# TODO:
pass
rows = await context.conn.query_iter("SELECT number FROM numbers(5)")
ret = []
async for row in rows:
ret.append(row.values()[0])
expected = [0, 1, 2, 3, 4]
assert ret == expected


@then("Insert and Select should be equal")
@async_run_until_complete
async def _(context):
# TODO:
pass
await context.conn.exec(
"""
INSERT INTO test VALUES
(-1, 1, 1.0, '1', '1', '2011-03-06', '2011-03-06 06:20:00'),
(-2, 2, 2.0, '2', '2', '2012-05-31', '2012-05-31 11:20:00'),
(-3, 3, 3.0, '3', '2', '2016-04-04', '2016-04-04 11:30:00')
"""
)
rows = await context.conn.query_iter("SELECT * FROM test")
ret = []
async for row in rows:
ret.append(row.values())
expected = [
(-1, 1, 1.0, "1", "1", "2011-03-06", "2011-03-06 06:20:00"),
(-2, 2, 2.0, "2", "2", "2012-05-31", "2012-05-31 11:20:00"),
(-3, 3, 3.0, "3", "2", "2016-04-04", "2016-04-04 11:30:00"),
]
assert ret == expected


@then("Stream load and Select should be equal")
Expand Down

0 comments on commit 8151d51

Please sign in to comment.