Reuse the Tokio Runtime #341

Merged
merged 1 commit into from
Apr 24, 2023
5 changes: 2 additions & 3 deletions src/context.rs
@@ -36,7 +36,7 @@ use crate::sql::logical::PyLogicalPlan;
 use crate::store::StorageContexts;
 use crate::udaf::PyAggregateUDF;
 use crate::udf::PyScalarUDF;
-use crate::utils::wait_for_future;
+use crate::utils::{get_tokio_runtime, wait_for_future};
 use datafusion::arrow::datatypes::{DataType, Schema};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
@@ -52,7 +52,6 @@ use datafusion::prelude::{
 };
 use datafusion_common::ScalarValue;
 use pyo3::types::PyTuple;
-use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;

 /// Configuration options for a SessionContext
@@ -722,7 +721,7 @@ impl PySessionContext {
             Arc::new(RuntimeEnv::default()),
         );
         // create a Tokio runtime to run the async code
-        let rt = Runtime::new().unwrap();
+        let rt = &get_tokio_runtime(py).0;
         let plan = plan.plan.clone();
         let fut: JoinHandle<datafusion_common::Result<SendableRecordBatchStream>> =
             rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
9 changes: 9 additions & 0 deletions src/lib.rs
@@ -59,12 +59,21 @@ pub mod utils;
 #[global_allocator]
 static GLOBAL: MiMalloc = MiMalloc;

+// Used to define Tokio Runtime as a Python module attribute
+#[pyclass]
+pub(crate) struct TokioRuntime(tokio::runtime::Runtime);
+
 /// Low-level DataFusion internal package.
 ///
 /// The higher-level public API is defined in pure python files under the
 /// datafusion directory.
 #[pymodule]
 fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
+    // Register the Tokio Runtime as a module attribute so we can reuse it
+    m.add(
+        "runtime",
+        TokioRuntime(tokio::runtime::Runtime::new().unwrap()),
+    )?;
     // Register the python classes
     m.add_class::<catalog::PyCatalog>()?;
     m.add_class::<catalog::PyDatabase>()?;
11 changes: 9 additions & 2 deletions src/utils.rs
@@ -16,19 +16,26 @@
 // under the License.

 use crate::errors::DataFusionError;
+use crate::TokioRuntime;
 use datafusion_expr::Volatility;
 use pyo3::prelude::*;
 use std::future::Future;
 use tokio::runtime::Runtime;

+/// Utility to get the Tokio Runtime from Python
+pub(crate) fn get_tokio_runtime(py: Python) -> PyRef<TokioRuntime> {
+    let datafusion = py.import("datafusion._internal").unwrap();
+    datafusion.getattr("runtime").unwrap().extract().unwrap()
+}
+
 /// Utility to collect rust futures with GIL released
 pub fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
 where
     F: Send,
     F::Output: Send,
 {
-    let rt = Runtime::new().unwrap();
-    py.allow_threads(|| rt.block_on(f))
+    let runtime: &Runtime = &get_tokio_runtime(py).0;
+    py.allow_threads(|| runtime.block_on(f))
 }

 pub(crate) fn parse_volatility(value: &str) -> Result<Volatility, DataFusionError> {
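
For readers skimming the diff: the change amounts to constructing the runtime once and handing out references to that single instance, rather than calling `Runtime::new()` on every `wait_for_future` or `execute` call. The sketch below illustrates that pattern using only the standard library; it is not the code from this PR. `FakeRuntime`, `shared_runtime`, and `CREATED` are hypothetical names, and `OnceLock` stands in for the Python module attribute that the PR uses to hold `TokioRuntime`.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;

// Counts how many times the (hypothetical) runtime is constructed.
static CREATED: AtomicUsize = AtomicUsize::new(0);

// Hypothetical stand-in for an expensive resource such as tokio::runtime::Runtime.
struct FakeRuntime;

impl FakeRuntime {
    fn new() -> Self {
        CREATED.fetch_add(1, Ordering::SeqCst);
        FakeRuntime
    }

    // Stand-in for Runtime::block_on: here it simply runs the closure.
    fn block_on<T>(&self, f: impl FnOnce() -> T) -> T {
        f()
    }
}

// Builds the runtime on first use and returns the same instance afterwards.
// Assumption: a process-wide static replaces the module attribute lookup.
fn shared_runtime() -> &'static FakeRuntime {
    static RUNTIME: OnceLock<FakeRuntime> = OnceLock::new();
    RUNTIME.get_or_init(FakeRuntime::new)
}

fn main() {
    let a = shared_runtime().block_on(|| 1 + 1);
    let b = shared_runtime().block_on(|| 2 + 2);
    // Both calls share one runtime, so the constructor ran exactly once.
    println!("{a} {b} created={}", CREATED.load(Ordering::SeqCst));
}
```

The PR stores the runtime on the `datafusion._internal` module instead of in a Rust static, so its lifetime is tied to the imported Python module; the static here is only the simplest way to show the "create once, borrow many times" idea.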