Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Pushdown SQL Server WHERE clause #2316

Merged
merged 6 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/datasources/src/common/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub enum Datasource {
BigQuery,
Snowflake,
Clickhouse,
SqlServer,
}

/// Returns true if the literal expression encoding should be wrapped inside
Expand Down
4 changes: 4 additions & 0 deletions crates/datasources/src/sqlserver/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ pub enum SqlServerError {
Io(#[from] std::io::Error),
#[error(transparent)]
Arrow(#[from] datafusion::arrow::error::ArrowError),
#[error(transparent)]
Fmt(#[from] std::fmt::Error),
#[error(transparent)]
DatasourceCommon(#[from] crate::common::errors::DatasourceCommonError),
}

pub type Result<T, E = SqlServerError> = std::result::Result<T, E>;
166 changes: 147 additions & 19 deletions crates/datasources/src/sqlserver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ pub mod errors;

mod client;

use chrono::{DateTime, Utc};
use client::{Client, QueryStream};

use async_trait::async_trait;
use chrono::naive::NaiveDateTime;
use chrono::{DateTime, Utc};
use client::{Client, QueryStream};
use datafusion::arrow::datatypes::{
DataType, Field, Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, TimeUnit,
};
Expand All @@ -15,7 +14,7 @@ use datafusion::datasource::TableProvider;
use datafusion::error::{DataFusionError, Result as DatafusionResult};
use datafusion::execution::context::SessionState;
use datafusion::execution::context::TaskContext;
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType};
use datafusion::logical_expr::{BinaryExpr, Expr, TableProviderFilterPushDown, TableType};
use datafusion::physical_expr::PhysicalSortExpr;
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
use datafusion::physical_plan::metrics::MetricsSet;
Expand All @@ -28,18 +27,22 @@ use datafusion_ext::functions::VirtualLister;
use datafusion_ext::metrics::DataSourceMetricsStreamAdapter;
use errors::{Result, SqlServerError};
use futures::{future::BoxFuture, ready, stream::BoxStream, FutureExt, Stream, StreamExt};
use std::any::Any;
use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tiberius::FromSql;
use tokio::net::TcpStream;
use tokio::task::JoinHandle;
use tokio_util::compat::TokioAsyncWriteCompatExt;
use tracing::warn;

use std::any::Any;
use std::collections::HashMap;
use std::fmt::{self, Write};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use crate::common::util;

/// Timeout when attempting to connect to the remote server.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);

Expand Down Expand Up @@ -115,8 +118,12 @@ impl SqlServerAccessState {
})
}

/// Get the arrow schema for a table.
async fn get_table_schema(&self, schema: &str, name: &str) -> Result<ArrowSchema> {
/// Get the arrow schema and sql server schema for a table.
async fn get_table_schema(
&self,
schema: &str,
name: &str,
) -> Result<(ArrowSchema, Vec<tiberius::Column>)> {
let mut query = self
.client
.query(format!("SELECT * FROM {schema}.{name} WHERE 1=0"))
Expand All @@ -133,6 +140,7 @@ impl SqlServerAccessState {
};

let mut fields = Vec::with_capacity(cols.len());

for col in cols {
use tiberius::ColumnType;

Expand Down Expand Up @@ -179,7 +187,7 @@ impl SqlServerAccessState {
fields.push(field);
}

Ok(ArrowSchema::new(fields))
Ok((ArrowSchema::new(fields), cols.to_vec()))
}
}

Expand Down Expand Up @@ -232,7 +240,7 @@ impl VirtualLister for SqlServerAccessState {
async fn list_columns(&self, schema: &str, table: &str) -> Result<Fields, ExtensionError> {
use ExtensionError::ListingErrBoxed;

let schema = self
let (schema, _) = self
.get_table_schema(schema, table)
.await
.map_err(|e| ListingErrBoxed(Box::new(e)))?;
Expand All @@ -252,18 +260,21 @@ pub struct SqlServerTableProvider {
table: String,
state: Arc<SqlServerAccessState>,
arrow_schema: ArrowSchemaRef,
sql_server_schema: Vec<tiberius::Column>,
}

impl SqlServerTableProvider {
pub async fn try_new(conf: SqlServerTableProviderConfig) -> Result<Self> {
let state = SqlServerAccessState::connect(conf.access.config).await?;
let arrow_schema = state.get_table_schema(&conf.schema, &conf.table).await?;
let (arrow_schema, sql_server_schema) =
state.get_table_schema(&conf.schema, &conf.table).await?;

Ok(Self {
schema: conf.schema,
table: conf.table,
state: Arc::new(state),
arrow_schema: Arc::new(arrow_schema),
sql_server_schema,
})
}
}
Expand Down Expand Up @@ -293,7 +304,7 @@ impl TableProvider for SqlServerTableProvider {
&self,
_ctx: &SessionState,
projection: Option<&Vec<usize>>,
_filters: &[Expr],
filters: &[Expr],
limit: Option<usize>,
) -> DatafusionResult<Arc<dyn ExecutionPlan>> {
// Project the schema.
Expand All @@ -312,17 +323,26 @@ impl TableProvider for SqlServerTableProvider {
.join(",");

let limit_string = match limit {
Some(limit) => format!("LIMIT {}", limit),
Some(limit) => format!("TOP {}", limit),
None => String::new(),
};

// TODO: Where/filters
let predicate_string = exprs_to_predicate_string(filters, &self.sql_server_schema)
.map_err(|e| DataFusionError::External(Box::new(e)))?;

let predicate_string = if predicate_string.is_empty() {
predicate_string
} else {
format!("WHERE {predicate_string}")
};

let query = format!(
"SELECT {projection_string} FROM {}.{} {limit_string}",
"SELECT {limit_string} {projection_string} FROM {}.{} {predicate_string}",
self.schema, self.table
);

eprintln!("query = {query:?}");

Ok(Arc::new(SqlServerExec {
query,
state: self.state.clone(),
Expand All @@ -343,6 +363,114 @@ impl TableProvider for SqlServerTableProvider {
}
}

/// Convert filtering expressions to a predicate string usable with the
/// generated SQL Server query.
///
/// Each expression in `exprs` is rendered independently. Expressions that
/// cannot be rendered (see `try_write_expr`) are silently dropped, and the
/// remaining ones are joined with ` AND `. An empty string is returned when
/// nothing could be rendered.
///
/// Every rendered predicate is wrapped in parentheses so that a filter
/// containing a lower-precedence operator (e.g. `a = 1 OR b = 2`) keeps its
/// meaning once it is combined with the other filters.
///
/// # Errors
///
/// Propagates any error returned by `try_write_expr`.
fn exprs_to_predicate_string(
    exprs: &[Expr],
    sql_server_schema: &[tiberius::Column],
) -> Result<String> {
    // Column name -> SQL Server column type, used to decide which binary
    // expressions must not be pushed down.
    let dt_map: HashMap<_, _> = sql_server_schema
        .iter()
        .map(|col| (col.name(), col.column_type()))
        .collect();

    let mut predicates = Vec::with_capacity(exprs.len());
    let mut buf = String::new();

    for expr in exprs {
        // Always start from an empty buffer: a failed attempt may have left a
        // partially written fragment behind, which must never leak into the
        // next predicate.
        buf.clear();
        if try_write_expr(expr, &dt_map, &mut buf)? {
            predicates.push(format!("({buf})"));
        }
    }

    Ok(predicates.join(" AND "))
}

/// Try to write the expression to the string, returning true if it was written.
///
/// Supported expressions are columns, literals, `IS NULL`, `IS NOT NULL`,
/// `IS TRUE` (rendered as `= 1`), `IS FALSE` (rendered as `= 0`) and binary
/// expressions whose operands are themselves supported.
///
/// When `Ok(false)` is returned, `buf` is restored to the length it had on
/// entry, so callers never observe a half-written fragment (for example the
/// left-hand side of a binary expression whose right-hand side turned out to
/// be unsupported).
///
/// Binary expressions are wrapped in parentheses so that nested expressions
/// keep the grouping of the expression tree instead of depending on SQL
/// operator precedence.
///
/// # Errors
///
/// Propagates formatting errors, literal encoding errors, and errors from
/// `should_skip_binary_expr`. On error the contents of `buf` are unspecified.
fn try_write_expr(
    expr: &Expr,
    dt_map: &HashMap<&str, tiberius::ColumnType>,
    buf: &mut String,
) -> Result<bool> {
    // Writes `inner` followed by a fixed postfix such as " IS NULL".
    fn write_postfix(
        inner: &Expr,
        suffix: &str,
        dt_map: &HashMap<&str, tiberius::ColumnType>,
        buf: &mut String,
    ) -> Result<bool> {
        if !try_write_expr(inner, dt_map, buf)? {
            return Ok(false);
        }
        buf.push_str(suffix);
        Ok(true)
    }

    // Writes `expr` without cleaning up after itself; the outer function is
    // responsible for rolling `buf` back when this returns `Ok(false)`.
    fn write_unguarded(
        expr: &Expr,
        dt_map: &HashMap<&str, tiberius::ColumnType>,
        buf: &mut String,
    ) -> Result<bool> {
        match expr {
            Expr::Column(col) => {
                write!(buf, "{}", col)?;
            }
            Expr::Literal(val) => {
                util::encode_literal_to_text(util::Datasource::SqlServer, buf, val)?;
            }
            Expr::IsNull(inner) => return write_postfix(inner, " IS NULL", dt_map, buf),
            Expr::IsNotNull(inner) => return write_postfix(inner, " IS NOT NULL", dt_map, buf),
            Expr::IsTrue(inner) => return write_postfix(inner, " = 1", dt_map, buf),
            Expr::IsFalse(inner) => return write_postfix(inner, " = 0", dt_map, buf),
            Expr::BinaryExpr(binary) => {
                if should_skip_binary_expr(binary, dt_map)? {
                    return Ok(false);
                }

                // Parenthesize so that e.g. `(a + b) * c` is not flattened
                // into `a + b * c`.
                buf.push('(');
                if !try_write_expr(binary.left.as_ref(), dt_map, buf)? {
                    return Ok(false);
                }
                write!(buf, " {} ", binary.op)?;
                if !try_write_expr(binary.right.as_ref(), dt_map, buf)? {
                    return Ok(false);
                }
                buf.push(')');
            }
            _ => {
                // Unsupported.
                return Ok(false);
            }
        }

        Ok(true)
    }

    // Remember where this expression starts so a failed attempt can be undone.
    let start = buf.len();
    let written = write_unguarded(expr, dt_map, buf)?;
    if !written {
        buf.truncate(start);
    }
    Ok(written)
}

/// Decide whether a binary expression must be left out of the pushed-down
/// predicate.
///
/// Returns `Ok(true)` when either direct operand is a column whose SQL Server
/// type is `Text` or `NText`. Operands that are not plain columns never cause
/// a skip.
///
/// NOTE(review): presumably this exists because SQL Server rejects most
/// operators on the legacy `text`/`ntext` types — confirm.
///
/// # Errors
///
/// Returns `SqlServerError::String` if an operand is a column that is not
/// present in `dt_map`.
fn should_skip_binary_expr(
    expr: &BinaryExpr,
    dt_map: &HashMap<&str, tiberius::ColumnType>,
) -> Result<bool> {
    use tiberius::ColumnType;

    // True when `operand` is a column reference of a text-like type.
    let operand_is_text = |operand: &Expr| -> Result<bool> {
        let col = match operand {
            Expr::Column(col) => col,
            _ => return Ok(false),
        };

        match dt_map.get(col.name.as_str()) {
            Some(ColumnType::Text | ColumnType::NText) => Ok(true),
            Some(_) => Ok(false),
            None => Err(SqlServerError::String(format!(
                "invalid column `{}`",
                col.name
            ))),
        }
    };

    // The left operand is checked first; the right one is only inspected
    // (and can only raise an error) when the left one is not a text column.
    if operand_is_text(&*expr.left)? {
        return Ok(true);
    }
    operand_is_text(&*expr.right)
}

/// Execution plan for reading from SQL Server.
struct SqlServerExec {
query: String,
Expand Down