PyO3 Bound<'py, T> API #734

Merged · 10 commits · Jun 18, 2024
3 changes: 2 additions & 1 deletion Cargo.toml
@@ -36,7 +36,7 @@ substrait = ["dep:datafusion-substrait"]
[dependencies]
tokio = { version = "1.35", features = ["macros", "rt", "rt-multi-thread", "sync"] }
rand = "0.8"
-pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38", "gil-refs"] }
+pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38"] }
arrow = { version = "52", feature = ["pyarrow"] }
datafusion = { version = "39.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
datafusion-common = { version = "39.0.0", features = ["pyarrow"] }
@@ -67,3 +67,4 @@ crate-type = ["cdylib", "rlib"]
[profile.release]
lto = true
codegen-units = 1
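Note: dropping the `gil-refs` feature opts the crate out of PyO3 0.21's deprecated GIL Ref API entirely, so every Python handle now flows through the `Bound<'py, T>` smart pointer. A minimal sketch of the two styles (`add_one` is a hypothetical function, not part of this PR):

```rust
use pyo3::prelude::*;

// Old GIL Ref style, only available behind the now-removed `gil-refs` feature:
//     fn add_one(obj: &PyAny) -> PyResult<i64> { ... }
//
// New Bound style: the handle carries the GIL lifetime explicitly.
fn add_one(obj: &Bound<'_, PyAny>) -> PyResult<i64> {
    let value: i64 = obj.extract()?;
    Ok(value + 1)
}

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        let num = 41i64.into_py(py); // Py<PyAny>: GIL-independent handle
        assert_eq!(add_one(num.bind(py))?, 42); // bind() attaches it to this GIL scope
        Ok(())
    })
}
```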

2 changes: 1 addition & 1 deletion src/common.rs
@@ -23,7 +23,7 @@ pub mod function;
pub mod schema;

/// Initializes the `common` module to match the pattern of `datafusion-common` https://docs.rs/datafusion-common/18.0.0/datafusion_common/index.html
-pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
+pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<df_schema::PyDFSchema>()?;
m.add_class::<data_type::PyDataType>()?;
m.add_class::<data_type::DataTypeMap>()?;
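The same signature change applies to every module-init helper in the crate: the function now borrows a `Bound<'_, PyModule>`, and `add_class` comes from the `PyModuleMethods` trait in the prelude. A minimal sketch of the pattern (`PyThing` and `my_extension` are illustrative names, not items from this PR):

```rust
use pyo3::prelude::*;

#[pyclass]
struct PyThing; // stands in for PyDFSchema, PyDataType, etc.

// Init helpers now take the bound module handle instead of a GIL ref.
pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyThing>()?;
    Ok(())
}

// The #[pymodule] entry point uses the same Bound signature in 0.21.
#[pymodule]
fn my_extension(m: &Bound<'_, PyModule>) -> PyResult<()> {
    init_module(m)
}
```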
2 changes: 1 addition & 1 deletion src/config.rs
@@ -65,7 +65,7 @@ impl PyConfig {

/// Get all configuration options
pub fn get_all(&mut self, py: Python) -> PyResult<PyObject> {
-let dict = PyDict::new(py);
+let dict = PyDict::new_bound(py);
let options = self.config.to_owned();
for entry in options.entries() {
dict.set_item(entry.key, entry.value.clone().into_py(py))?;
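`PyDict::new_bound` is the Bound-API constructor: it returns `Bound<'py, PyDict>` rather than the GIL-ref `&PyDict`, with `set_item` supplied by `PyDictMethods`. A small sketch, assuming illustrative keys rather than real DataFusion config entries:

```rust
use pyo3::prelude::*;
use pyo3::types::PyDict;

fn options_dict(py: Python<'_>) -> PyResult<Py<PyDict>> {
    let dict = PyDict::new_bound(py); // Bound<'_, PyDict>
    dict.set_item("batch_size", 8192)?;
    dict.set_item("target_partitions", "4")?;
    Ok(dict.unbind()) // detach from the GIL lifetime to return or store it
}
```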
23 changes: 14 additions & 9 deletions src/context.rs
@@ -291,11 +291,11 @@ impl PySessionContext {
pub fn register_object_store(
&mut self,
scheme: &str,
-store: &PyAny,
+store: &Bound<'_, PyAny>,
host: Option<&str>,
) -> PyResult<()> {
let res: Result<(Arc<dyn ObjectStore>, String), PyErr> =
-match StorageContexts::extract(store) {
+match StorageContexts::extract_bound(store) {
Ok(store) => match store {
StorageContexts::AmazonS3(s3) => Ok((s3.inner, s3.bucket_name)),
StorageContexts::GoogleCloudStorage(gcs) => Ok((gcs.inner, gcs.bucket_name)),
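`extract_bound` is the `FromPyObject` entry point for the Bound API, so derived implementations like `StorageContexts` can be driven without GIL refs. A sketch of the dispatch shape, with simplified stand-in variants (the real enum wraps the object-store context classes):

```rust
use pyo3::prelude::*;

// Simplified stand-in for the StorageContexts enum used above.
#[derive(FromPyObject)]
enum Storage {
    Int(i64),
    Name(String),
}

fn classify(obj: &Bound<'_, PyAny>) -> PyResult<String> {
    // extract_bound tries each variant in declaration order.
    match Storage::extract_bound(obj)? {
        Storage::Int(n) => Ok(format!("int: {n}")),
        Storage::Name(s) => Ok(format!("name: {s}")),
    }
}
```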
@@ -443,8 +443,8 @@ impl PySessionContext {
) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
// Instantiate pyarrow Table object & convert to Arrow Table
-let table_class = py.import("pyarrow")?.getattr("Table")?;
-let args = PyTuple::new(py, &[data]);
+let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+let args = PyTuple::new_bound(py, &[data]);
let table = table_class.call_method1("from_pylist", args)?.into();

// Convert Arrow Table to datafusion DataFrame
@@ -463,8 +463,8 @@
) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
// Instantiate pyarrow Table object & convert to Arrow Table
-let table_class = py.import("pyarrow")?.getattr("Table")?;
-let args = PyTuple::new(py, &[data]);
+let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+let args = PyTuple::new_bound(py, &[data]);
let table = table_class.call_method1("from_pydict", args)?.into();

// Convert Arrow Table to datafusion DataFrame
@@ -507,8 +507,8 @@ impl PySessionContext {
) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
// Instantiate pyarrow Table object & convert to Arrow Table
-let table_class = py.import("pyarrow")?.getattr("Table")?;
-let args = PyTuple::new(py, &[data]);
+let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+let args = PyTuple::new_bound(py, &[data]);
let table = table_class.call_method1("from_pandas", args)?.into();

// Convert Arrow Table to datafusion DataFrame
@@ -710,7 +710,12 @@ impl PySessionContext {
}

// Registers a PyArrow.Dataset
-pub fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
+pub fn register_dataset(
+&self,
+name: &str,
+dataset: &Bound<'_, PyAny>,
+py: Python,
+) -> PyResult<()> {
let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);

self.ctx
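The recurring pattern in these hunks is the Bound-API spelling of "import a module, build an args tuple, call a method": `py.import_bound(...)` plus `PyTuple::new_bound(...)`. Condensed into one sketch (assumes pyarrow is importable at runtime):

```rust
use pyo3::prelude::*;
use pyo3::types::PyTuple;

// Build a pyarrow.Table from a Python list of dicts via the Bound API.
fn table_from_pylist(py: Python<'_>, data: PyObject) -> PyResult<PyObject> {
    let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
    let args = PyTuple::new_bound(py, &[data]);
    Ok(table_class.call_method1("from_pylist", args)?.into())
}
```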
67 changes: 41 additions & 26 deletions src/dataframe.rs
@@ -28,6 +28,7 @@ use datafusion::prelude::*;
use datafusion_common::UnnestOptions;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
+use pyo3::pybacked::PyBackedStr;
use pyo3::types::PyTuple;
use tokio::task::JoinHandle;

@@ -56,23 +57,25 @@ impl PyDataFrame {

#[pymethods]
impl PyDataFrame {
-fn __getitem__(&self, key: PyObject) -> PyResult<Self> {
-Python::with_gil(|py| {
-if let Ok(key) = key.extract::<&str>(py) {
-self.select_columns(vec![key])
-} else if let Ok(tuple) = key.extract::<&PyTuple>(py) {
-let keys = tuple
-.iter()
-.map(|item| item.extract::<&str>())
-.collect::<PyResult<Vec<&str>>>()?;
-self.select_columns(keys)
-} else if let Ok(keys) = key.extract::<Vec<&str>>(py) {
-self.select_columns(keys)
-} else {
-let message = "DataFrame can only be indexed by string index or indices";
-Err(PyTypeError::new_err(message))
-}
-})
+/// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
+fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyResult<Self> {
+if let Ok(key) = key.extract::<PyBackedStr>() {
+// df[col]
+self.select_columns(vec![key])
+} else if let Ok(tuple) = key.extract::<&PyTuple>() {
+// df[col1, col2, col3]
+let keys = tuple
+.iter()
+.map(|item| item.extract::<PyBackedStr>())
+.collect::<PyResult<Vec<PyBackedStr>>>()?;
+self.select_columns(keys)
+} else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
+// df[[col1, col2, col3]]
+self.select_columns(keys)
+} else {
+let message = "DataFrame can only be indexed by string index or indices";
+Err(PyTypeError::new_err(message))
+}
+}

fn __repr__(&self, py: Python) -> PyResult<String> {
@@ -98,7 +101,8 @@
}

#[pyo3(signature = (*args))]
-fn select_columns(&self, args: Vec<&str>) -> PyResult<Self> {
+fn select_columns(&self, args: Vec<PyBackedStr>) -> PyResult<Self> {
+let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
let df = self.df.as_ref().clone().select_columns(&args)?;
Ok(Self::new(df))
}
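`PyBackedStr` (new in 0.21) owns a reference to the Python string's UTF-8 buffer, replacing `&str` extraction whose borrow was tied to the GIL-ref lifetime. It implements `AsRef<str>`, which is what the `map(|s| s.as_ref())` conversion above relies on. A sketch under those assumptions:

```rust
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;

// Accept either one column name or a list of names, without GIL-ref borrows.
fn column_names(key: &Bound<'_, PyAny>) -> PyResult<Vec<PyBackedStr>> {
    if let Ok(one) = key.extract::<PyBackedStr>() {
        Ok(vec![one])
    } else {
        key.extract::<Vec<PyBackedStr>>()
    }
}

fn as_strs(names: &[PyBackedStr]) -> Vec<&str> {
    // PyBackedStr: AsRef<str>, so &str views are cheap to produce.
    names.iter().map(|s| s.as_ref()).collect()
}
```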
@@ -194,7 +198,7 @@
fn join(
&self,
right: PyDataFrame,
-join_keys: (Vec<&str>, Vec<&str>),
+join_keys: (Vec<PyBackedStr>, Vec<PyBackedStr>),
how: &str,
) -> PyResult<Self> {
let join_type = match how {
@@ -212,11 +216,22 @@
}
};

+let left_keys = join_keys
+.0
+.iter()
+.map(|s| s.as_ref())
+.collect::<Vec<&str>>();
+let right_keys = join_keys
+.1
+.iter()
+.map(|s| s.as_ref())
+.collect::<Vec<&str>>();
+
let df = self.df.as_ref().clone().join(
right.df.as_ref().clone(),
join_type,
-&join_keys.0,
-&join_keys.1,
+&left_keys,
+&right_keys,
None,
)?;
Ok(Self::new(df))
@@ -414,8 +429,8 @@

Python::with_gil(|py| {
// Instantiate pyarrow Table object and use its from_batches method
-let table_class = py.import("pyarrow")?.getattr("Table")?;
-let args = PyTuple::new(py, &[batches, schema]);
+let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+let args = PyTuple::new_bound(py, &[batches, schema]);
let table: PyObject = table_class.call_method1("from_batches", args)?.into();
Ok(table)
})
@@ -489,8 +504,8 @@
let table = self.to_arrow_table(py)?;

Python::with_gil(|py| {
-let dataframe = py.import("polars")?.getattr("DataFrame")?;
-let args = PyTuple::new(py, &[table]);
+let dataframe = py.import_bound("polars")?.getattr("DataFrame")?;
+let args = PyTuple::new_bound(py, &[table]);
let result: PyObject = dataframe.call1(args)?.into();
Ok(result)
})
@@ -514,7 +529,7 @@ fn print_dataframe(py: Python, df: DataFrame) -> PyResult<()> {

// Import the Python 'builtins' module to access the print function
// Note that println! does not print to the Python debug console and is not visible in notebooks for instance
-let print = py.import("builtins")?.getattr("print")?;
+let print = py.import_bound("builtins")?.getattr("print")?;
print.call1((result,))?;
Ok(())
}
13 changes: 7 additions & 6 deletions src/dataset.rs
@@ -46,13 +46,14 @@ pub(crate) struct Dataset {

impl Dataset {
// Creates a Python PyArrow.Dataset
-pub fn new(dataset: &PyAny, py: Python) -> PyResult<Self> {
+pub fn new(dataset: &Bound<'_, PyAny>, py: Python) -> PyResult<Self> {
// Ensure that we were passed an instance of pyarrow.dataset.Dataset
-let ds = PyModule::import(py, "pyarrow.dataset")?;
-let ds_type: &PyType = ds.getattr("Dataset")?.downcast()?;
+let ds = PyModule::import_bound(py, "pyarrow.dataset")?;
+let ds_attr = ds.getattr("Dataset")?;
+let ds_type = ds_attr.downcast::<PyType>()?;
if dataset.is_instance(ds_type)? {
Ok(Dataset {
-dataset: dataset.into(),
+dataset: dataset.clone().unbind(),
})
} else {
Err(PyValueError::new_err(
@@ -73,7 +74,7 @@ impl TableProvider for Dataset {
/// Get a reference to the schema for this table
fn schema(&self) -> SchemaRef {
Python::with_gil(|py| {
-let dataset = self.dataset.as_ref(py);
+let dataset = self.dataset.bind(py);
// This can panic but since we checked that self.dataset is a pyarrow.dataset.Dataset it should never
Arc::new(
dataset
@@ -108,7 +109,7 @@
) -> DFResult<Arc<dyn ExecutionPlan>> {
Python::with_gil(|py| {
let plan: Arc<dyn ExecutionPlan> = Arc::new(
-DatasetExec::new(py, self.dataset.as_ref(py), projection.cloned(), filters)
+DatasetExec::new(py, self.dataset.bind(py), projection.cloned(), filters)
.map_err(|err| DataFusionError::External(Box::new(err)))?,
);
Ok(plan)
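The dataset changes show the standard conversions between the GIL-independent `Py<T>` stored in structs and the GIL-scoped `Bound<'py, T>`: `bind(py)` to borrow a stored handle, and `clone().unbind()` to store a borrowed one. A sketch of the storage pattern (`Holder` is a hypothetical stand-in for `Dataset`):

```rust
use pyo3::prelude::*;

struct Holder {
    obj: Py<PyAny>, // GIL-independent, safe to keep across GIL scopes
}

impl Holder {
    // Store: clone the Bound handle, then detach it from the GIL lifetime.
    fn new(obj: &Bound<'_, PyAny>) -> Self {
        Holder { obj: obj.clone().unbind() }
    }

    // Borrow: re-attach to the current GIL scope to call Python methods.
    fn repr(&self) -> PyResult<String> {
        Python::with_gil(|py| self.obj.bind(py).repr()?.extract())
    }
}
```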
36 changes: 17 additions & 19 deletions src/dataset_exec.rs
@@ -53,7 +53,7 @@ impl Iterator for PyArrowBatchesAdapter {

fn next(&mut self) -> Option<Self::Item> {
Python::with_gil(|py| {
-let mut batches: &PyIterator = self.batches.as_ref(py);
+let mut batches = self.batches.clone().into_bound(py);
Some(
batches
.next()?
@@ -79,7 +79,7 @@ pub(crate) struct DatasetExec {
impl DatasetExec {
pub fn new(
py: Python,
-dataset: &PyAny,
+dataset: &Bound<'_, PyAny>,
projection: Option<Vec<usize>>,
filters: &[Expr],
) -> Result<Self, DataFusionError> {
@@ -103,15 +103,15 @@ impl DatasetExec {
})
.transpose()?;

-let kwargs = PyDict::new(py);
+let kwargs = PyDict::new_bound(py);

kwargs.set_item("columns", columns.clone())?;
kwargs.set_item(
"filter",
filter_expr.as_ref().map(|expr| expr.clone_ref(py)),
)?;

-let scanner = dataset.call_method("scanner", (), Some(kwargs))?;
+let scanner = dataset.call_method("scanner", (), Some(&kwargs))?;

let schema = Arc::new(
scanner
Expand All @@ -120,19 +120,17 @@ impl DatasetExec {
.0,
);

-let builtins = Python::import(py, "builtins")?;
+let builtins = Python::import_bound(py, "builtins")?;
let pylist = builtins.getattr("list")?;

// Get the fragments or partitions of the dataset
-let fragments_iterator: &PyAny = dataset.call_method1(
+let fragments_iterator: Bound<'_, PyAny> = dataset.call_method1(
"get_fragments",
(filter_expr.as_ref().map(|expr| expr.clone_ref(py)),),
)?;

-let fragments: &PyList = pylist
-.call1((fragments_iterator,))?
-.downcast()
-.map_err(PyErr::from)?;
+let fragments_iter = pylist.call1((fragments_iterator,))?;
+let fragments = fragments_iter.downcast::<PyList>().map_err(PyErr::from)?;

let projected_statistics = Statistics::new_unknown(&schema);
let plan_properties = datafusion::physical_plan::PlanProperties::new(
@@ -142,9 +140,9 @@
);

Ok(DatasetExec {
-dataset: dataset.into(),
+dataset: dataset.clone().unbind(),
schema,
-fragments: fragments.into(),
+fragments: fragments.clone().unbind(),
columns,
filter_expr,
projected_statistics,
@@ -183,8 +181,8 @@
) -> DFResult<SendableRecordBatchStream> {
let batch_size = context.session_config().batch_size();
Python::with_gil(|py| {
-let dataset = self.dataset.as_ref(py);
-let fragments = self.fragments.as_ref(py);
+let dataset = self.dataset.bind(py);
+let fragments = self.fragments.bind(py);
let fragment = fragments
.get_item(partition)
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
@@ -193,7 +191,7 @@
let dataset_schema = dataset
.getattr("schema")
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
-let kwargs = PyDict::new(py);
+let kwargs = PyDict::new_bound(py);
kwargs
.set_item("columns", self.columns.clone())
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
@@ -207,15 +205,15 @@
.set_item("batch_size", batch_size)
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
let scanner = fragment
.call_method("scanner", (dataset_schema,), Some(kwargs))
.call_method("scanner", (dataset_schema,), Some(&kwargs))
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
let schema: SchemaRef = Arc::new(
scanner
.getattr("projected_schema")
.and_then(|schema| Ok(schema.extract::<PyArrowType<_>>()?.0))
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?,
);
-let record_batches: &PyIterator = scanner
+let record_batches: Bound<'_, PyIterator> = scanner
.call_method0("to_batches")
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?
.iter()
@@ -264,7 +262,7 @@
impl DisplayAs for DatasetExec {
fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
Python::with_gil(|py| {
-let number_of_fragments = self.fragments.as_ref(py).len();
+let number_of_fragments = self.fragments.bind(py).len();
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
let projected_columns: Vec<String> = self
Expand All @@ -274,7 +272,7 @@ impl DisplayAs for DatasetExec {
.map(|x| x.name().to_owned())
.collect();
if let Some(filter_expr) = &self.filter_expr {
-let filter_expr = filter_expr.as_ref(py).str().or(Err(std::fmt::Error))?;
+let filter_expr = filter_expr.bind(py).str().or(Err(std::fmt::Error))?;
write!(
f,
"DatasetExec: number_of_fragments={}, filter_expr={}, projection=[{}]",
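Two Bound-API idioms from this file worth calling out: `downcast::<T>()` now runs on a `Bound<'_, PyAny>` and returns a borrowed `&Bound<'_, T>`, and because it borrows, the value being downcast needs its own `let` binding (the reason for the `ds_attr` and `fragments_iter` temporaries). A sketch using only Python builtins (no third-party imports assumed):

```rust
use pyo3::prelude::*;
use pyo3::types::{PyList, PyType};

fn demo(py: Python<'_>) -> PyResult<()> {
    // Keep the getattr result alive in its own binding: downcast borrows from it.
    let builtins = py.import_bound("builtins")?;
    let list_attr = builtins.getattr("list")?;
    let list_type = list_attr.downcast::<PyType>()?;

    // Materialize an iterable into a list, as DatasetExec does for fragments.
    let items = list_type.call1((vec![1, 2, 3],))?;
    let as_list = items.downcast::<PyList>().map_err(PyErr::from)?;
    assert_eq!(as_list.len(), 3);
    Ok(())
}
```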
2 changes: 1 addition & 1 deletion src/expr.rs
@@ -553,7 +553,7 @@ impl PyExpr {
}

/// Initializes the `expr` module to match the pattern of `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/
-pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
+pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyExpr>()?;
m.add_class::<PyColumn>()?;
m.add_class::<PyLiteral>()?;