Skip to content

Commit

Permalink
fix: Pushdown SQL Server WHERE clause (#2316)
Browse files Browse the repository at this point in the history
Fixes #2094

---------

Signed-off-by: Vaibhav <vrongmeal@gmail.com>
  • Loading branch information
vrongmeal authored Jan 11, 2024
1 parent d86f20c commit 5f94acc
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 19 deletions.
1 change: 1 addition & 0 deletions crates/datasources/src/common/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub enum Datasource {
BigQuery,
Snowflake,
Clickhouse,
SqlServer,
}

/// Returns true if the literal expression encoding should be wrapped inside
Expand Down
4 changes: 4 additions & 0 deletions crates/datasources/src/sqlserver/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ pub enum SqlServerError {
Io(#[from] std::io::Error),
#[error(transparent)]
Arrow(#[from] datafusion::arrow::error::ArrowError),
#[error(transparent)]
Fmt(#[from] std::fmt::Error),
#[error(transparent)]
DatasourceCommon(#[from] crate::common::errors::DatasourceCommonError),
}

/// Result alias whose error type defaults to [`SqlServerError`].
pub type Result<T, E = SqlServerError> = std::result::Result<T, E>;
166 changes: 147 additions & 19 deletions crates/datasources/src/sqlserver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ pub mod errors;

mod client;

use chrono::{DateTime, Utc};
use client::{Client, QueryStream};

use async_trait::async_trait;
use chrono::naive::NaiveDateTime;
use chrono::{DateTime, Utc};
use client::{Client, QueryStream};
use datafusion::arrow::datatypes::{
DataType, Field, Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, TimeUnit,
};
Expand All @@ -15,7 +14,7 @@ use datafusion::datasource::TableProvider;
use datafusion::error::{DataFusionError, Result as DatafusionResult};
use datafusion::execution::context::SessionState;
use datafusion::execution::context::TaskContext;
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType};
use datafusion::logical_expr::{BinaryExpr, Expr, TableProviderFilterPushDown, TableType};
use datafusion::physical_expr::PhysicalSortExpr;
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
use datafusion::physical_plan::metrics::MetricsSet;
Expand All @@ -28,18 +27,22 @@ use datafusion_ext::functions::VirtualLister;
use datafusion_ext::metrics::DataSourceMetricsStreamAdapter;
use errors::{Result, SqlServerError};
use futures::{future::BoxFuture, ready, stream::BoxStream, FutureExt, Stream, StreamExt};
use std::any::Any;
use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tiberius::FromSql;
use tokio::net::TcpStream;
use tokio::task::JoinHandle;
use tokio_util::compat::TokioAsyncWriteCompatExt;
use tracing::warn;

use std::any::Any;
use std::collections::HashMap;
use std::fmt::{self, Write};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use crate::common::util;

/// Timeout (5 seconds) when attempting to connect to the remote server.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);

Expand Down Expand Up @@ -115,8 +118,12 @@ impl SqlServerAccessState {
})
}

/// Get the arrow schema for a table.
async fn get_table_schema(&self, schema: &str, name: &str) -> Result<ArrowSchema> {
/// Get the arrow schema and sql server schema for a table.
async fn get_table_schema(
&self,
schema: &str,
name: &str,
) -> Result<(ArrowSchema, Vec<tiberius::Column>)> {
let mut query = self
.client
.query(format!("SELECT * FROM {schema}.{name} WHERE 1=0"))
Expand All @@ -133,6 +140,7 @@ impl SqlServerAccessState {
};

let mut fields = Vec::with_capacity(cols.len());

for col in cols {
use tiberius::ColumnType;

Expand Down Expand Up @@ -179,7 +187,7 @@ impl SqlServerAccessState {
fields.push(field);
}

Ok(ArrowSchema::new(fields))
Ok((ArrowSchema::new(fields), cols.to_vec()))
}
}

Expand Down Expand Up @@ -232,7 +240,7 @@ impl VirtualLister for SqlServerAccessState {
async fn list_columns(&self, schema: &str, table: &str) -> Result<Fields, ExtensionError> {
use ExtensionError::ListingErrBoxed;

let schema = self
let (schema, _) = self
.get_table_schema(schema, table)
.await
.map_err(|e| ListingErrBoxed(Box::new(e)))?;
Expand All @@ -252,18 +260,21 @@ pub struct SqlServerTableProvider {
table: String,
state: Arc<SqlServerAccessState>,
arrow_schema: ArrowSchemaRef,
sql_server_schema: Vec<tiberius::Column>,
}

impl SqlServerTableProvider {
pub async fn try_new(conf: SqlServerTableProviderConfig) -> Result<Self> {
let state = SqlServerAccessState::connect(conf.access.config).await?;
let arrow_schema = state.get_table_schema(&conf.schema, &conf.table).await?;
let (arrow_schema, sql_server_schema) =
state.get_table_schema(&conf.schema, &conf.table).await?;

Ok(Self {
schema: conf.schema,
table: conf.table,
state: Arc::new(state),
arrow_schema: Arc::new(arrow_schema),
sql_server_schema,
})
}
}
Expand Down Expand Up @@ -293,7 +304,7 @@ impl TableProvider for SqlServerTableProvider {
&self,
_ctx: &SessionState,
projection: Option<&Vec<usize>>,
_filters: &[Expr],
filters: &[Expr],
limit: Option<usize>,
) -> DatafusionResult<Arc<dyn ExecutionPlan>> {
// Project the schema.
Expand All @@ -312,17 +323,26 @@ impl TableProvider for SqlServerTableProvider {
.join(",");

let limit_string = match limit {
Some(limit) => format!("LIMIT {}", limit),
Some(limit) => format!("TOP {}", limit),
None => String::new(),
};

// TODO: Where/filters
let predicate_string = exprs_to_predicate_string(filters, &self.sql_server_schema)
.map_err(|e| DataFusionError::External(Box::new(e)))?;

let predicate_string = if predicate_string.is_empty() {
predicate_string
} else {
format!("WHERE {predicate_string}")
};

let query = format!(
"SELECT {projection_string} FROM {}.{} {limit_string}",
"SELECT {limit_string} {projection_string} FROM {}.{} {predicate_string}",
self.schema, self.table
);

eprintln!("query = {query:?}");

Ok(Arc::new(SqlServerExec {
query,
state: self.state.clone(),
Expand All @@ -343,6 +363,114 @@ impl TableProvider for SqlServerTableProvider {
}
}

/// Convert filtering expressions to a predicate string usable with the
/// generated SQL Server query.
///
/// Every expression in `exprs` that can be translated is rendered on its own
/// and the rendered pieces are joined with ` AND `. Expressions that cannot be
/// translated are left out entirely. The returned string is empty when no
/// expression could be translated (including when `exprs` is empty).
///
/// `sql_server_schema` supplies the column name to SQL Server column type
/// mapping that is consulted while translating.
///
/// # Errors
///
/// Propagates any error raised while translating an individual expression.
fn exprs_to_predicate_string(
    exprs: &[Expr],
    sql_server_schema: &[tiberius::Column],
) -> Result<String> {
    let dt_map: HashMap<_, _> = sql_server_schema
        .iter()
        .map(|col| (col.name(), col.column_type()))
        .collect();

    let mut predicates = Vec::with_capacity(exprs.len());

    for expr in exprs {
        // Use a fresh buffer for every expression. A translation that gives up
        // half-way may already have written part of its output; reusing that
        // buffer would glue the stale fragment in front of the next predicate.
        let mut buf = String::new();
        if try_write_expr(expr, &dt_map, &mut buf)? {
            predicates.push(buf);
        }
    }

    Ok(predicates.join(" AND "))
}

/// Try to write the expression to the string, returning true if it was written.
///
/// Handles columns, literals, `IS NULL`, `IS NOT NULL`, `IS TRUE` (written as
/// `= 1`), `IS FALSE` (written as `= 0`) and binary expressions. Any other
/// expression kind is reported as not written.
///
/// Binary expressions are wrapped in parentheses so that the grouping of the
/// expression tree survives being flattened to text (for example
/// `(a + b) * c`, or an `OR` that the caller later joins with `AND`).
///
/// When `Ok(false)` is returned, `buf` holds exactly what it held on entry:
/// anything written for a binary expression that turned out to be
/// untranslatable is removed again.
///
/// # Errors
///
/// Returns an error if formatting into `buf` fails, if encoding a literal
/// fails, or if checking a binary expression's operands fails.
fn try_write_expr(
    expr: &Expr,
    dt_map: &HashMap<&str, tiberius::ColumnType>,
    buf: &mut String,
) -> Result<bool> {
    match expr {
        Expr::Column(col) => {
            write!(buf, "{}", col)?;
        }
        Expr::Literal(val) => {
            util::encode_literal_to_text(util::Datasource::SqlServer, buf, val)?;
        }
        Expr::IsNull(expr) => {
            if !try_write_expr(expr, dt_map, buf)? {
                return Ok(false);
            }
            write!(buf, " IS NULL")?;
        }
        Expr::IsNotNull(expr) => {
            if !try_write_expr(expr, dt_map, buf)? {
                return Ok(false);
            }
            write!(buf, " IS NOT NULL")?;
        }
        Expr::IsTrue(expr) => {
            if !try_write_expr(expr, dt_map, buf)? {
                return Ok(false);
            }
            write!(buf, " = 1")?;
        }
        Expr::IsFalse(expr) => {
            if !try_write_expr(expr, dt_map, buf)? {
                return Ok(false);
            }
            write!(buf, " = 0")?;
        }
        Expr::BinaryExpr(binary) => {
            if should_skip_binary_expr(binary, dt_map)? {
                return Ok(false);
            }

            // Remember where this expression starts so that a partially
            // written expression can be removed if an operand is unsupported.
            let start = buf.len();

            // Parenthesize so operator precedence in the generated SQL cannot
            // regroup the operands differently from the expression tree.
            buf.push('(');
            let mut written = try_write_expr(binary.left.as_ref(), dt_map, buf)?;
            if written {
                write!(buf, " {} ", binary.op)?;
                written = try_write_expr(binary.right.as_ref(), dt_map, buf)?;
            }

            if !written {
                buf.truncate(start);
                return Ok(false);
            }
            buf.push(')');
        }
        _ => {
            // Unsupported.
            return Ok(false);
        }
    }

    Ok(true)
}

/// Decide whether a binary expression must be left out of the pushed-down
/// predicate.
///
/// Returns `Ok(true)` when either operand is a column whose SQL Server type is
/// `Text` or `NText`. Operands that are not plain columns never cause a skip.
///
/// # Errors
///
/// Returns `SqlServerError::String` when an operand names a column that is not
/// present in `dt_map`.
fn should_skip_binary_expr(
    expr: &BinaryExpr,
    dt_map: &HashMap<&str, tiberius::ColumnType>,
) -> Result<bool> {
    use tiberius::ColumnType;

    // Reports whether a single operand is a text-typed column.
    let is_text_operand = |operand: &Expr| -> Result<bool> {
        let column = match operand {
            Expr::Column(column) => column,
            _ => return Ok(false),
        };

        match dt_map.get(column.name.as_str()) {
            Some(ColumnType::Text | ColumnType::NText) => Ok(true),
            Some(_) => Ok(false),
            None => Err(SqlServerError::String(format!(
                "invalid column `{}`",
                column.name
            ))),
        }
    };

    // Skip if we're trying to do any kind of binary op with text column.
    // The right operand is only inspected when the left one is not text.
    if is_text_operand(expr.left.as_ref())? {
        return Ok(true);
    }
    is_text_operand(expr.right.as_ref())
}

/// Execution plan for reading from SQL Server.
struct SqlServerExec {
query: String,
Expand Down

0 comments on commit 5f94acc

Please sign in to comment.