Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(bindings/python): support query_iter #280

Merged
merged 1 commit into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ databend-driver = { workspace = true, features = ["rustls", "flight-sql"] }
pyo3 = { version = "0.19", features = ["abi3-py37"] }
pyo3-asyncio = { version = "0.19", features = ["tokio-runtime"] }
tokio = "1.28"
tokio-stream = "0.1"
5 changes: 5 additions & 0 deletions bindings/python/package/databend_driver/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@
class Row:
def values(self) -> tuple: ...

class RowIterator:
def __aiter__(self) -> RowIterator: ...
async def __anext__(self) -> Row: ...

# flake8: noqa
class AsyncDatabendConnection:
async def exec(self, sql: str) -> int: ...
async def query_row(self, sql: str) -> Row: ...
async def query_iter(self, sql: str) -> RowIterator: ...

# flake8: noqa
class AsyncDatabendClient:
Expand Down
37 changes: 36 additions & 1 deletion bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use pyo3::create_exception;
use pyo3::exceptions::PyException;
use pyo3::exceptions::{PyException, PyStopAsyncIteration};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyTuple};
use pyo3_asyncio::tokio::future_into_py;
use tokio::sync::Mutex;
use tokio_stream::StreamExt;

create_exception!(
databend_client,
Expand Down Expand Up @@ -74,6 +78,14 @@ impl AsyncDatabendConnection {
Ok(Row(row))
})
}

pub fn query_iter<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<&'p PyAny> {
let this = self.0.clone();
future_into_py(py, async move {
let streamer = this.query_iter(&sql).await.unwrap();
Ok(RowIterator(Arc::new(Mutex::new(streamer))))
})
}
}

#[pyclass(module = "databend_driver")]
Expand Down Expand Up @@ -169,3 +181,26 @@ impl IntoPy<PyObject> for NumberValue {
}
}
}

#[pyclass(module = "databend_driver")]
pub struct RowIterator(Arc<Mutex<databend_driver::RowIterator>>);

#[pymethods]
impl RowIterator {
fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
fn __anext__<'a>(&self, py: Python<'a>) -> PyResult<Option<PyObject>> {
let streamer = self.0.clone();
let future = future_into_py(py, async move {
match streamer.lock().await.next().await {
Some(val) => match val {
Err(e) => Err(PyException::new_err(format!("{}", e))),
Ok(ret) => Ok(Row(ret)),
},
None => Err(PyStopAsyncIteration::new_err("The iterator is exhausted")),
}
});
Ok(Some(future?.into()))
}
}
44 changes: 38 additions & 6 deletions bindings/python/tests/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,20 @@ async def _(context):
@when("Create a test table")
@async_run_until_complete
async def _(context):
# TODO:
pass
await context.conn.exec("DROP TABLE IF EXISTS test")
await context.conn.exec(
"""
CREATE TABLE test (
i64 Int64,
u64 UInt64,
f64 Float64,
s String,
s2 String,
d Date,
t DateTime
)
"""
)


@then("Select string {input} should be equal to {output}")
Expand All @@ -47,15 +59,35 @@ async def _(context, input, output):
@then("Select numbers should iterate all rows")
@async_run_until_complete
async def _(context):
# TODO:
pass
rows = await context.conn.query_iter("SELECT number FROM numbers(5)")
ret = []
async for row in rows:
ret.append(row.values()[0])
expected = [0, 1, 2, 3, 4]
assert ret == expected


@then("Insert and Select should be equal")
@async_run_until_complete
async def _(context):
# TODO:
pass
await context.conn.exec(
"""
INSERT INTO test VALUES
(-1, 1, 1.0, '1', '1', '2011-03-06', '2011-03-06 06:20:00'),
(-2, 2, 2.0, '2', '2', '2012-05-31', '2012-05-31 11:20:00'),
(-3, 3, 3.0, '3', '2', '2016-04-04', '2016-04-04 11:30:00')
"""
)
rows = await context.conn.query_iter("SELECT * FROM test")
ret = []
async for row in rows:
ret.append(row.values())
expected = [
(-1, 1, 1.0, "1", "1", "2011-03-06", "2011-03-06 06:20:00"),
(-2, 2, 2.0, "2", "2", "2012-05-31", "2012-05-31 11:20:00"),
(-3, 3, 3.0, "3", "2", "2016-04-04", "2016-04-04 11:30:00"),
]
assert ret == expected


@then("Stream load and Select should be equal")
Expand Down