Skip to content

Commit

Permalink
Showing 3 changed files with 20 additions and 5 deletions.
5 changes: 2 additions & 3 deletions src/context.rs
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@ use crate::sql::logical::PyLogicalPlan;
use crate::store::StorageContexts;
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::utils::wait_for_future;
use crate::utils::{get_tokio_runtime, wait_for_future};
use datafusion::arrow::datatypes::{DataType, Schema};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
@@ -52,7 +52,6 @@ use datafusion::prelude::{
};
use datafusion_common::ScalarValue;
use pyo3::types::PyTuple;
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;

/// Configuration options for a SessionContext
@@ -722,7 +721,7 @@ impl PySessionContext {
Arc::new(RuntimeEnv::default()),
);
// create a Tokio runtime to run the async code
let rt = Runtime::new().unwrap();
let rt = &get_tokio_runtime(py).0;
let plan = plan.plan.clone();
let fut: JoinHandle<datafusion_common::Result<SendableRecordBatchStream>> =
rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
9 changes: 9 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -59,12 +59,21 @@ pub mod utils;
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;

// Used to define Tokio Runtime as a Python module attribute
#[pyclass]
pub(crate) struct TokioRuntime(tokio::runtime::Runtime);

/// Low-level DataFusion internal package.
///
/// The higher-level public API is defined in pure python files under the
/// datafusion directory.
#[pymodule]
fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
// Register the Tokio Runtime as a module attribute so we can reuse it
m.add(
"runtime",
TokioRuntime(tokio::runtime::Runtime::new().unwrap()),
)?;
// Register the python classes
m.add_class::<catalog::PyCatalog>()?;
m.add_class::<catalog::PyDatabase>()?;
11 changes: 9 additions & 2 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -16,19 +16,26 @@
// under the License.

use crate::errors::DataFusionError;
use crate::TokioRuntime;
use datafusion_expr::Volatility;
use pyo3::prelude::*;
use std::future::Future;
use tokio::runtime::Runtime;

/// Utility to get the Tokio Runtime from Python
pub(crate) fn get_tokio_runtime(py: Python) -> PyRef<TokioRuntime> {
let datafusion = py.import("datafusion._internal").unwrap();
datafusion.getattr("runtime").unwrap().extract().unwrap()
}

/// Utility to collect rust futures with GIL released
pub fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
where
F: Send,
F::Output: Send,
{
let rt = Runtime::new().unwrap();
py.allow_threads(|| rt.block_on(f))
let runtime: &Runtime = &get_tokio_runtime(py).0;
py.allow_threads(|| runtime.block_on(f))
}

pub(crate) fn parse_volatility(value: &str) -> Result<Volatility, DataFusionError> {

0 comments on commit 545e93e

Please sign in to comment.