From 7240d448cf9a8ce10f694f35a3e4db433a560231 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 1 Nov 2024 09:45:25 -0400 Subject: [PATCH] Rebasing and pulling in a few changes for DF43.0 --- src/context.rs | 7 ++----- src/udf.rs | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/context.rs b/src/context.rs index b92b7345..547916c4 100644 --- a/src/context.rs +++ b/src/context.rs @@ -21,7 +21,6 @@ use std::str::FromStr; use std::sync::Arc; use arrow::array::RecordBatchReader; -use arrow::ffi::FFI_ArrowSchema; use arrow::ffi_stream::ArrowArrayStreamReader; use arrow::pyarrow::FromPyArrow; use datafusion::execution::session_state::SessionStateBuilder; @@ -37,7 +36,6 @@ use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, DataFusionError}; use crate::expr::sort_expr::PySortExpr; -use crate::expr::PyExpr; use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; use crate::sql::logical::PyLogicalPlan; @@ -56,8 +54,8 @@ use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; +use datafusion::datasource::MemTable; use datafusion::datasource::TableProvider; -use datafusion::datasource::{provider, MemTable}; use datafusion::execution::context::{ DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext, }; @@ -574,7 +572,6 @@ impl PySessionContext { &mut self, name: &str, provider: Bound<'_, PyAny>, - py: Python, ) -> PyResult<()> { if provider.hasattr("__datafusion_table_provider__")? { let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; @@ -582,7 +579,7 @@ impl PySessionContext { // validate_pycapsule(capsule, "arrow_array_stream")?; let provider = unsafe { capsule.reference::() }; - let provider = ForeignTableProvider::new(provider); + let provider: ForeignTableProvider = provider.into(); let _ = self.ctx.register_table(name, Arc::new(provider))?; } diff --git a/src/udf.rs b/src/udf.rs index 21f6d269..4570e77a 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -97,7 +97,7 @@ impl PyScalarUDF { let function = create_udf( name, input_types.0, - Arc::new(return_type.0), + return_type.0, parse_volatility(volatility)?, to_scalar_function_impl(func), );