Merge branch 'datafusion-contrib:main' into mysql-write
gengteng authored Oct 28, 2024
2 parents 9464fd5 + a9e1469 commit 8a46ae4
Showing 23 changed files with 1,574 additions and 151 deletions.
11 changes: 11 additions & 0 deletions .github/dependabot.yml
@@ -0,0 +1,11 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file

version: 2
updates:
- package-ecosystem: "cargo" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
9 changes: 6 additions & 3 deletions .github/workflows/pr.yaml
@@ -50,8 +50,11 @@ jobs:
- name: Build with only mysql
run: cargo check --no-default-features --features mysql

- name: Build with only flight
run: cargo check --no-default-features --features flight

integration-test:
name: Integration Test
name: Tests
runs-on: ubuntu-latest

env:
@@ -68,5 +71,5 @@ jobs:
docker pull ${{ env.PG_DOCKER_IMAGE }}
docker pull ${{ env.MYSQL_DOCKER_IMAGE }}
- name: Run integration test
run: make test-integration
- name: Run tests
run: make test
19 changes: 11 additions & 8 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "datafusion-table-providers"
version = "0.2.2"
version = "0.2.3"
readme = "README.md"
edition = "2021"
repository = "https://github.com/datafusion-contrib/datafusion-table-providers"
@@ -25,7 +25,7 @@ datafusion-expr = { version = "42.0.0", optional = true }
datafusion-physical-expr = { version = "42.0.0", optional = true }
datafusion-physical-plan = { version = "42.0.0", optional = true }
datafusion-proto = { version = "42.0.0", optional = true }
datafusion-federation = { version = "0.3.0", features = ["sql"] }
datafusion-federation = { version = "0.3.0", features = ["sql"], optional = true }
duckdb = { version = "1.1.1", features = [
"bundled",
"r2d2",
@@ -68,25 +68,26 @@ trust-dns-resolver = "0.23.2"
url = "2.5.1"
pem = { version = "3.0.4", optional = true }
tokio-rusqlite = { version = "0.5.1", optional = true }
tonic = { version = "0.12.2", optional = true }
tonic = { version = "0.12", optional = true, features = ["tls-native-roots", "tls-webpki-roots"] }
itertools = "0.13.0"
dyn-clone = { version = "1.0.17", optional = true }
geo-types = "0.7.13"
fundu = "2.0.1"

[dev-dependencies]
anyhow = "1.0.86"
bollard = "0.16.1"
bollard = "0.17.1"
rand = "0.8.5"
reqwest = "0.12.5"
secrecy = "0.8.0"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
test-log = { version = "0.2.16", features = ["trace"] }
rstest = "0.22.0"
geozero = { version = "0.13.0", features = ["with-wkb"] }
geozero = { version = "0.14.0", features = ["with-wkb"] }
tokio-stream = { version = "0.1.15", features = ["net"] }
arrow-schema = "53.1.0"
prost = { version = "0.13"}
insta = { version = "1.40.0", features = ["filters"] }

[features]
mysql = ["dep:mysql_async", "dep:async-stream"]
@@ -105,8 +106,10 @@ flight = [
"dep:serde",
"dep:tonic",
]
duckdb-federation = ["duckdb"]
sqlite-federation = ["sqlite"]
postgres-federation = ["postgres"]
federation = ["dep:datafusion-federation"]
duckdb-federation = ["duckdb", "federation"]
sqlite-federation = ["sqlite", "federation"]
postgres-federation = ["postgres", "federation"]
mysql-federation = ["mysql", "federation"]


2 changes: 1 addition & 1 deletion Makefile
@@ -11,4 +11,4 @@ lint:

.PHONY: test-integration
test-integration:
RUST_LOG=debug cargo test --test integration --no-default-features --features postgres,sqlite,mysql -- --nocapture
RUST_LOG=debug cargo test --test integration --no-default-features --features postgres,sqlite,mysql,flight -- --nocapture
3 changes: 3 additions & 0 deletions README.md
@@ -31,6 +31,7 @@ let ctx = SessionContext::with_state(state);
Run the included examples to see how to use the table providers:

### DuckDB

```bash
# Read from a table in a DuckDB file
cargo run --example duckdb --features duckdb
@@ -41,6 +42,7 @@ cargo run --example duckdb_function --features duckdb
```

### SQLite

```bash
cargo run --example sqlite --features sqlite
```
@@ -94,6 +96,7 @@ cargo run --example mysql --features mysql
```

### Flight SQL

```bash
brew install roapi
# or
16 changes: 10 additions & 6 deletions examples/sqlite.rs
@@ -1,20 +1,24 @@
use std::sync::Arc;

use datafusion::{prelude::SessionContext, sql::TableReference};
use datafusion_table_providers::{
sql::db_connection_pool::{sqlitepool::SqliteConnectionPoolFactory, Mode},
sqlite::SqliteTableFactory,
};
use std::sync::Arc;
use std::time::Duration;

/// This example demonstrates how to create a SqliteTableFactory and use it to create TableProviders
/// that can be registered with DataFusion.
#[tokio::main]
async fn main() {
let sqlite_pool = Arc::new(
SqliteConnectionPoolFactory::new("examples/sqlite_example.db", Mode::File)
.build()
.await
.expect("unable to create Sqlite connection pool"),
SqliteConnectionPoolFactory::new(
"examples/sqlite_example.db",
Mode::File,
Duration::default(),
)
.build()
.await
.expect("unable to create Sqlite connection pool"),
);

let sqlite_table_factory = SqliteTableFactory::new(sqlite_pool);
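The factory now takes a third `Duration` argument. A minimal sketch of passing a non-default value, assuming the duration acts as a connection busy timeout (its exact semantics are not confirmed by this diff):

```rust
use std::time::Duration;

use datafusion_table_providers::sql::db_connection_pool::{
    sqlitepool::SqliteConnectionPoolFactory, Mode,
};

// Hedged variant of the example above: an explicit 5-second duration instead
// of Duration::default(). Treating this as SQLite's busy timeout is an
// assumption from context.
async fn build_pool() {
    let sqlite_pool = SqliteConnectionPoolFactory::new(
        "examples/sqlite_example.db",
        Mode::File,
        Duration::from_millis(5000),
    )
    .build()
    .await
    .expect("unable to create Sqlite connection pool");
    let _ = sqlite_pool;
}
```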
113 changes: 75 additions & 38 deletions src/flight.rs
@@ -20,6 +20,7 @@

use std::any::Any;
use std::collections::HashMap;
use std::error::Error;
use std::fmt::Debug;
use std::sync::Arc;

@@ -35,12 +36,14 @@ use datafusion::datasource::TableProvider;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_expr::{CreateExternalTable, Expr, TableType};
use serde::{Deserialize, Serialize};
use tonic::transport::Channel;
use tonic::transport::{Channel, ClientTlsConfig};

pub mod codec;
mod exec;
pub mod sql;

pub use exec::enforce_schema;

/// Generic Arrow Flight data source. Requires a [FlightDriver] that allows implementors
/// to integrate any custom Flight RPC service by producing a [FlightMetadata] for some DDL.
///
@@ -80,7 +83,7 @@ pub mod sql;
/// CustomFlightDriver::default(),
/// ))),
/// );
/// _ = ctx.sql(
/// let _ = ctx.sql(
/// r#"
/// CREATE EXTERNAL TABLE custom_flight_table STORED AS CUSTOM_FLIGHT
/// LOCATION 'https://custom.flight.rpc'
@@ -107,16 +110,12 @@ impl FlightTableFactory {
options: HashMap<String, String>,
) -> datafusion::common::Result<FlightTable> {
let origin = entry_point.into();
let channel = Channel::from_shared(origin.clone())
.unwrap()
.connect()
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let channel = flight_channel(&origin).await?;
let metadata = self
.driver
.metadata(channel.clone(), &options)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
.map_err(to_df_err)?;
let num_rows = precision(metadata.info.total_records);
let total_byte_size = precision(metadata.info.total_bytes);
let logical_schema = metadata.schema;
@@ -136,14 +135,6 @@ impl FlightTableFactory {
}
}

fn precision(total: i64) -> Precision<usize> {
if total < 0 {
Precision::Absent
} else {
Precision::Exact(total as usize)
}
}

#[async_trait]
impl TableProviderFactory for FlightTableFactory {
async fn create(
@@ -177,65 +168,88 @@ pub trait FlightDriver: Sync + Send + Debug {
pub struct FlightMetadata {
/// FlightInfo object produced by the driver
info: FlightInfo,
/// Arrow schema. Can be enforced by the driver or inferred from the FlightInfo
schema: SchemaRef,
/// Various knobs that control execution
props: FlightProperties,
/// Arrow schema. Can be enforced by the driver or inferred from the FlightInfo
schema: SchemaRef,
}

impl FlightMetadata {
/// Customize everything that is in the driver's control
pub fn new(info: FlightInfo, schema: SchemaRef, props: FlightProperties) -> Self {
pub fn new(info: FlightInfo, props: FlightProperties, schema: SchemaRef) -> Self {
Self {
info,
schema,
props,
schema,
}
}

/// Customize gRPC headers
pub fn try_new(
info: FlightInfo,
grpc_headers: HashMap<String, String>,
) -> arrow_flight::error::Result<Self> {
/// Customize flight properties and try to use the FlightInfo schema
pub fn try_new(info: FlightInfo, props: FlightProperties) -> arrow_flight::error::Result<Self> {
let schema = Arc::new(info.clone().try_decode_schema()?);
let props = grpc_headers.into();
Ok(Self::new(info, schema, props))
Ok(Self::new(info, props, schema))
}
}

impl TryFrom<FlightInfo> for FlightMetadata {
type Error = FlightError;

fn try_from(info: FlightInfo) -> Result<Self, Self::Error> {
Self::try_new(info, HashMap::default())
Self::try_new(info, FlightProperties::default())
}
}

/// Meant to gradually encapsulate all sorts of knobs required
/// for controlling the protocol and query execution details.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct FlightProperties {
unbounded_stream: bool,
grpc_headers: HashMap<String, String>,
size_limits: SizeLimits,
}

impl FlightProperties {
pub fn new(unbounded_stream: bool, grpc_headers: HashMap<String, String>) -> Self {
Self {
unbounded_stream,
grpc_headers,
}
pub fn unbounded_stream(mut self, unbounded_stream: bool) -> Self {
self.unbounded_stream = unbounded_stream;
self
}

pub fn grpc_headers(mut self, grpc_headers: HashMap<String, String>) -> Self {
self.grpc_headers = grpc_headers;
self
}

pub fn size_limits(mut self, size_limits: SizeLimits) -> Self {
self.size_limits = size_limits;
self
}
}

impl From<HashMap<String, String>> for FlightProperties {
fn from(grpc_headers: HashMap<String, String>) -> Self {
Self::new(false, grpc_headers)
/// Message size limits to be passed to the underlying gRPC library.
#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct SizeLimits {
encoding: usize,
decoding: usize,
}

impl SizeLimits {
pub fn new(encoding: usize, decoding: usize) -> Self {
Self { encoding, decoding }
}
}

impl Default for SizeLimits {
fn default() -> Self {
Self {
// no limits
encoding: usize::MAX,
decoding: usize::MAX,
}
}
}
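Taken together, the new `Default` derive and chainable setters replace the removed `FlightProperties::new` and `From<HashMap<_, _>>` constructors. A minimal sketch of a driver assembling its metadata under the new API (the header value and the decoding limit are placeholder assumptions):

```rust
use std::collections::HashMap;

use arrow_flight::FlightInfo;
use datafusion_table_providers::flight::{FlightMetadata, FlightProperties, SizeLimits};

// Hedged sketch: compose FlightProperties with the builder-style setters, then
// let try_new decode the Arrow schema from the FlightInfo itself.
fn assemble(info: FlightInfo) -> arrow_flight::error::Result<FlightMetadata> {
    let props = FlightProperties::default()
        .unbounded_stream(false)
        .grpc_headers(HashMap::from([(
            "authorization".to_string(),
            "Bearer <token>".to_string(), // placeholder credential
        )]))
        // Cap decoded messages at 16 MiB; leave encoding unlimited.
        .size_limits(SizeLimits::new(usize::MAX, 16 * 1024 * 1024));
    FlightMetadata::try_new(info, props)
}
```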

/// Table provider that wraps a specific flight from an Arrow Flight service
#[derive(Debug)]
pub struct FlightTable {
driver: Arc<dyn FlightDriver>,
channel: Channel,
@@ -270,7 +284,7 @@ impl TableProvider for FlightTable {
.driver
.metadata(self.channel.clone(), &self.options)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
.map_err(to_df_err)?;
Ok(Arc::new(FlightExec::try_new(
metadata,
projection,
Expand All @@ -282,3 +296,26 @@ impl TableProvider for FlightTable {
Some(self.stats.clone())
}
}

fn to_df_err<E: Error + Send + Sync + 'static>(err: E) -> DataFusionError {
DataFusionError::External(Box::new(err))
}

async fn flight_channel(source: impl Into<String>) -> datafusion::common::Result<Channel> {
let tls_config = ClientTlsConfig::new().with_enabled_roots();
Channel::from_shared(source.into())
.map_err(to_df_err)?
.tls_config(tls_config)
.map_err(to_df_err)?
.connect()
.await
.map_err(to_df_err)
}

fn precision(total: i64) -> Precision<usize> {
if total < 0 {
Precision::Absent
} else {
Precision::Exact(total as usize)
}
}
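End to end, the factory/driver split above can be exercised with the crate's Flight SQL driver. A hedged sketch: the `open_table` method name, the `FlightSqlDriver` path, the `flight.sql.query` option key, and the endpoint URL are assumptions not shown in this diff:

```rust
use std::collections::HashMap;
use std::sync::Arc;

use datafusion::prelude::SessionContext;
use datafusion_table_providers::flight::{sql::FlightSqlDriver, FlightTableFactory};

#[tokio::main]
async fn main() -> datafusion::common::Result<()> {
    let ctx = SessionContext::new();
    // The generic factory wraps any FlightDriver; here, the Flight SQL driver.
    let factory = FlightTableFactory::new(Arc::new(FlightSqlDriver::default()));
    let table = factory
        .open_table(
            "http://localhost:32010", // assumed endpoint
            HashMap::from([(
                "flight.sql.query".to_string(), // assumed option key
                "SELECT 1".to_string(),
            )]),
        )
        .await?;
    ctx.register_table("flight_table", Arc::new(table))?;
    ctx.sql("SELECT * FROM flight_table").await?.show().await?;
    Ok(())
}
```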
7 changes: 3 additions & 4 deletions src/flight/codec.rs
@@ -20,6 +20,7 @@
use std::sync::Arc;

use crate::flight::exec::{FlightConfig, FlightExec};
use crate::flight::to_df_err;
use datafusion::common::DataFusionError;
use datafusion_expr::registry::FunctionRegistry;
use datafusion_physical_plan::ExecutionPlan;
@@ -37,8 +38,7 @@ impl PhysicalExtensionCodec for FlightPhysicalCodec {
_registry: &dyn FunctionRegistry,
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
if inputs.is_empty() {
let config: FlightConfig =
serde_json::from_slice(buf).map_err(|e| DataFusionError::External(Box::new(e)))?;
let config: FlightConfig = serde_json::from_slice(buf).map_err(to_df_err)?;
Ok(Arc::from(FlightExec::from(config)))
} else {
Err(DataFusionError::Internal(
@@ -53,8 +53,7 @@
buf: &mut Vec<u8>,
) -> datafusion::common::Result<()> {
if let Some(flight) = node.as_any().downcast_ref::<FlightExec>() {
let mut bytes = serde_json::to_vec(flight.config())
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let mut bytes = serde_json::to_vec(flight.config()).map_err(to_df_err)?;
buf.append(&mut bytes);
Ok(())
} else {
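Because the codec round-trips `FlightExec` through its JSON-serialized `FlightConfig`, a physical plan containing one can be shipped across processes with datafusion-proto's byte helpers. A hedged sketch (that `FlightPhysicalCodec` implements `Default` is an assumption):

```rust
use std::sync::Arc;

use datafusion::prelude::SessionContext;
use datafusion_physical_plan::ExecutionPlan;
use datafusion_proto::bytes::{
    physical_plan_from_bytes_with_extension_codec, physical_plan_to_bytes_with_extension_codec,
};
use datafusion_table_providers::flight::codec::FlightPhysicalCodec;

// Hedged sketch: serialize a plan (which may contain FlightExec nodes) and
// deserialize it back using the extension codec from this file.
fn round_trip(
    plan: Arc<dyn ExecutionPlan>,
    ctx: &SessionContext,
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
    let codec = FlightPhysicalCodec::default();
    let bytes = physical_plan_to_bytes_with_extension_codec(plan, &codec)?;
    physical_plan_from_bytes_with_extension_codec(&bytes, ctx, &codec)
}
```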
(The remaining 15 changed files are not shown.)