feat: use pyo3-asyncio to get a fresh tokio runtime
PengLiVectra committed Nov 29, 2023
1 parent 8ca8d65 commit f71515c
Showing 2 changed files with 40 additions and 35 deletions.
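
In outline: the bindings used to build a brand-new Tokio runtime for every blocking call (the old rt() returned an owned tokio::runtime::Runtime wrapped in PyResult), and now rt() hands back the shared runtime that pyo3-asyncio manages as a &'static tokio::runtime::Runtime, so the trailing ? disappears at every call site; the old constructor survives as rt_pyo3() for the one caller that still needs an owned runtime. Below is a minimal, self-contained sketch of the two patterns, not the deltalake code itself; the function names are illustrative, and it assumes pyo3 0.20, pyo3-asyncio 0.20 with the "tokio-runtime" feature, and tokio with its multi-thread runtime enabled, matching the Cargo.toml hunk that follows.

// Sketch only: per-call runtime construction vs. pyo3-asyncio's shared runtime.
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;

// Old pattern: build (and drop) a whole Tokio runtime on every call; this can
// fail, hence the PyResult and the trailing ? at call sites.
#[allow(dead_code)]
fn rt_per_call() -> PyResult<tokio::runtime::Runtime> {
    tokio::runtime::Runtime::new().map_err(|err| PyRuntimeError::new_err(err.to_string()))
}

// New pattern: borrow the process-wide runtime owned by pyo3-asyncio
// (available when pyo3-asyncio is built with its "tokio-runtime" feature).
fn rt_shared() -> &'static tokio::runtime::Runtime {
    pyo3_asyncio::tokio::get_runtime()
}

#[pyfunction]
fn block_on_answer() -> PyResult<u64> {
    // block_on looks the same on either handle; only ownership and the
    // fallibility of acquiring the runtime differ.
    Ok(rt_shared().block_on(async { 40 + 2 }))
}
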
4 changes: 4 additions & 0 deletions python/Cargo.toml
@@ -42,6 +42,10 @@ reqwest = { version = "*", features = ["native-tls-vendored"] }
 version = "0.20"
 features = ["extension-module", "abi3", "abi3-py38"]
 
+[dependencies.pyo3-asyncio]
+version = "0.20"
+features = ["tokio-runtime"]
+
 [dependencies.deltalake]
 path = "../crates/deltalake"
 version = "0"
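
The "tokio-runtime" feature is what turns on pyo3-asyncio's tokio module and its get_runtime() helper used in lib.rs below; the 0.20 release line matches the pyo3 0.20 dependency declared just above.
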
71 changes: 36 additions & 35 deletions python/src/lib.rs
@@ -4,6 +4,7 @@ mod error;
 mod filesystem;
 mod schema;
 mod utils;
+extern crate pyo3_asyncio;
 
 use std::collections::{HashMap, HashSet};
 use std::convert::TryFrom;
@@ -52,10 +53,15 @@ use crate::filesystem::FsConfig;
 use crate::schema::schema_to_pyobject;
 
 #[inline]
-fn rt() -> PyResult<tokio::runtime::Runtime> {
+fn rt_pyo3() -> PyResult<tokio::runtime::Runtime> {
     tokio::runtime::Runtime::new().map_err(|err| PyRuntimeError::new_err(err.to_string()))
 }
 
+#[inline]
+fn rt() -> &'static tokio::runtime::Runtime {
+    pyo3_asyncio::tokio::get_runtime()
+}
+
 #[derive(FromPyObject)]
 enum PartitionFilterValue<'a> {
     Single(&'a str),
@@ -113,7 +119,7 @@ impl RawDeltaTable {
                 .map_err(PythonError::from)?;
         }
 
-        let table = rt()?.block_on(builder.load()).map_err(PythonError::from)?;
+        let table = rt().block_on(builder.load()).map_err(PythonError::from)?;
         Ok(RawDeltaTable {
             _table: table,
             _config: FsConfig {
@@ -135,7 +141,7 @@
     ) -> PyResult<String> {
         let data_catalog = deltalake::data_catalog::get_data_catalog(data_catalog, catalog_options)
             .map_err(|e| PyValueError::new_err(format!("{}", e)))?;
-        let table_uri = rt()?
+        let table_uri = rt()
             .block_on(data_catalog.get_table_storage_location(
                 data_catalog_id,
                 database_name,
@@ -174,13 +180,13 @@
     }
 
     pub fn load_version(&mut self, version: i64) -> PyResult<()> {
-        Ok(rt()?
+        Ok(rt()
             .block_on(self._table.load_version(version))
             .map_err(PythonError::from)?)
     }
 
     pub fn get_latest_version(&mut self) -> PyResult<i64> {
-        Ok(rt()?
+        Ok(rt()
             .block_on(self._table.get_latest_version())
             .map_err(PythonError::from)?)
     }
@@ -190,7 +196,7 @@
             DateTime::<Utc>::from(DateTime::<FixedOffset>::parse_from_rfc3339(ds).map_err(
                 |err| PyValueError::new_err(format!("Failed to parse datetime string: {err}")),
             )?);
-        Ok(rt()?
+        Ok(rt()
             .block_on(self._table.load_with_datetime(datetime))
             .map_err(PythonError::from)?)
     }
@@ -280,7 +286,7 @@
         if let Some(retention_period) = retention_hours {
             cmd = cmd.with_retention_period(Duration::hours(retention_period as i64));
         }
-        let (table, metrics) = rt()?
+        let (table, metrics) = rt()
             .block_on(cmd.into_future())
             .map_err(PythonError::from)?;
         self._table.state = table.state;
@@ -333,7 +339,7 @@
             cmd = cmd.with_predicate(update_predicate);
         }
 
-        let (table, metrics) = rt()?
+        let (table, metrics) = rt()
             .block_on(cmd.into_future())
             .map_err(PythonError::from)?;
         self._table.state = table.state;
@@ -361,7 +367,7 @@
             .map_err(PythonError::from)?;
         cmd = cmd.with_filters(&converted_filters);
 
-        let (table, metrics) = rt()?
+        let (table, metrics) = rt()
             .block_on(cmd.into_future())
             .map_err(PythonError::from)?;
         self._table.state = table.state;
@@ -394,7 +400,7 @@
             .map_err(PythonError::from)?;
         cmd = cmd.with_filters(&converted_filters);
 
-        let (table, metrics) = rt()?
+        let (table, metrics) = rt()
             .block_on(cmd.into_future())
             .map_err(PythonError::from)?;
         self._table.state = table.state;
@@ -593,7 +599,7 @@
             }
         }
 
-        let (table, metrics) = rt()?
+        let (table, metrics) = rt()
             .block_on(cmd.into_future())
             .map_err(PythonError::from)?;
         self._table.state = table.state;
@@ -624,7 +630,7 @@
         }
         cmd = cmd.with_ignore_missing_files(ignore_missing_files);
         cmd = cmd.with_protocol_downgrade_allowed(protocol_downgrade_allowed);
-        let (table, metrics) = rt()?
+        let (table, metrics) = rt()
            .block_on(cmd.into_future())
            .map_err(PythonError::from)?;
         self._table.state = table.state;
@@ -633,7 +639,7 @@
 
     /// Run the History command on the Delta Table: Returns provenance information, including the operation, user, and so on, for each write to a table.
     pub fn history(&mut self, limit: Option<usize>) -> PyResult<Vec<String>> {
-        let history = rt()?
+        let history = rt()
             .block_on(self._table.history(limit))
             .map_err(PythonError::from)?;
         Ok(history
@@ -643,7 +649,7 @@
     }
 
     pub fn update_incremental(&mut self) -> PyResult<()> {
-        Ok(rt()?
+        Ok(rt()
             .block_on(self._table.update_incremental(None))
             .map_err(PythonError::from)?)
     }
@@ -825,39 +831,36 @@ impl RawDeltaTable {
         };
         let store = self._table.log_store();
 
-        rt()?
-            .block_on(commit(
-                &*store,
-                &actions,
-                operation,
-                self._table.get_state(),
-                None,
-            ))
-            .map_err(PythonError::from)?;
+        rt().block_on(commit(
+            &*store,
+            &actions,
+            operation,
+            self._table.get_state(),
+            None,
+        ))
+        .map_err(PythonError::from)?;
 
         Ok(())
     }
 
     pub fn get_py_storage_backend(&self) -> PyResult<filesystem::DeltaFileSystemHandler> {
         Ok(filesystem::DeltaFileSystemHandler {
             inner: self._table.object_store(),
-            rt: Arc::new(rt()?),
+            rt: Arc::new(rt_pyo3()?),
             config: self._config.clone(),
             known_sizes: None,
         })
     }
 
     pub fn create_checkpoint(&self) -> PyResult<()> {
-        rt()?
-            .block_on(create_checkpoint(&self._table))
+        rt().block_on(create_checkpoint(&self._table))
             .map_err(PythonError::from)?;
 
         Ok(())
     }
 
     pub fn cleanup_metadata(&self) -> PyResult<()> {
-        rt()?
-            .block_on(cleanup_metadata(&self._table))
+        rt().block_on(cleanup_metadata(&self._table))
             .map_err(PythonError::from)?;
 
         Ok(())
@@ -879,7 +882,7 @@
         if let Some(predicate) = predicate {
             cmd = cmd.with_predicate(predicate);
         }
-        let (table, metrics) = rt()?
+        let (table, metrics) = rt()
             .block_on(cmd.into_future())
             .map_err(PythonError::from)?;
         self._table.state = table.state;
@@ -893,7 +896,7 @@
         let cmd = FileSystemCheckBuilder::new(self._table.log_store(), self._table.state.clone())
             .with_dry_run(dry_run);
 
-        let (table, metrics) = rt()?
+        let (table, metrics) = rt()
             .block_on(cmd.into_future())
             .map_err(PythonError::from)?;
         self._table.state = table.state;
@@ -1080,7 +1083,7 @@ fn batch_distinct(batch: PyArrowType<RecordBatch>) -> PyResult<PyArrowType<RecordBatch>> {
     let schema = batch.0.schema();
     ctx.register_batch("batch", batch.0)
         .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
-    let batches = rt()?
+    let batches = rt()
         .block_on(async { ctx.table("batch").await?.distinct()?.collect().await })
         .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
 
@@ -1164,8 +1167,7 @@ fn write_new_deltalake(
         builder = builder.with_configuration(config);
     };
 
-    rt()?
-        .block_on(builder.into_future())
+    rt().block_on(builder.into_future())
         .map_err(PythonError::from)?;
 
     Ok(())
@@ -1217,8 +1219,7 @@ fn convert_to_deltalake(
         builder = builder.with_metadata(json_metadata);
     };
 
-    rt()?
-        .block_on(builder.into_future())
+    rt().block_on(builder.into_future())
         .map_err(PythonError::from)?;
     Ok(())
 }
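
The one caller that keeps the fallible constructor is get_py_storage_backend: the handler stores its runtime behind an Arc (rt: Arc::new(rt_pyo3()?) in the hunk above), so it needs an owned Runtime rather than the &'static borrow. A tiny sketch of that constraint, with hypothetical names standing in for the handler and assuming tokio's multi-thread runtime feature:

// Hypothetical mirror of the handler's requirement: an Arc<Runtime> field
// must be given an owned Runtime, which is why the owning constructor stays.
use std::sync::Arc;
use tokio::runtime::Runtime;

struct Handler {
    rt: Arc<Runtime>,
}

fn make_handler() -> std::io::Result<Handler> {
    // An owned runtime can be moved into the Arc; a &'static Runtime could not be.
    Ok(Handler {
        rt: Arc::new(Runtime::new()?),
    })
}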
