[flight] Coerce flight data into target schema if needed
ccciudatu committed Sep 16, 2024
1 parent 58531df commit 2c2ad24
Showing 4 changed files with 121 additions and 34 deletions.
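The core of the change is a new `enforce_schema` helper in `src/flight/exec.rs` that uses `arrow_cast::cast` to coerce each column returned by a flight endpoint into the type declared by the target schema, instead of failing when the schemas differ in compatible ways. The following is a minimal, self-contained sketch (not part of the commit, assuming the same arrow 52.x crates listed in Cargo.toml) of the cast behaviour the helper relies on, using the same type pairs exercised by the new unit test:

use arrow_array::{Array, BooleanArray, Int32Array};
use arrow_cast::cast;
use arrow_schema::DataType;

fn main() -> arrow::error::Result<()> {
    let ints = Int32Array::from(vec![10, 20]);
    let bools = BooleanArray::from(vec![true, false]);

    // Compatible coercion: Int32 -> Float32 succeeds.
    let floats = cast(&ints, &DataType::Float32)?;
    assert_eq!(floats.data_type(), &DataType::Float32);

    // Compatible coercion: Boolean -> Utf8 yields the strings "true" / "false".
    let strings = cast(&bools, &DataType::Utf8)?;
    assert_eq!(strings.data_type(), &DataType::Utf8);

    // Incompatible coercion: Boolean -> Date32 is rejected with a cast error.
    assert!(cast(&bools, &DataType::Date32).is_err());

    Ok(())
}

In `enforce_schema` the same cast is applied per field of the target schema, and a column that is missing from the flight response is reported as a schema error rather than a panic.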
2 changes: 2 additions & 0 deletions Cargo.toml
@@ -10,6 +10,7 @@ description = "Extend the capabilities of DataFusion to support additional data
[dependencies]
arrow = "52.2.0"
arrow-array = { version = "52.2.0", optional = true }
arrow-cast = { version = "52.2.0", optional = true }
arrow-flight = { version = "52.2.0", optional = true, features = ["flight-sql-experimental", "tls"] }
arrow-schema = { version = "52.2.0", optional = true, features = ["serde"] }
arrow-json = "52.2.0"
@@ -83,6 +84,7 @@ sqlite = ["dep:rusqlite", "dep:tokio-rusqlite"]
duckdb = ["dep:duckdb", "dep:r2d2", "dep:uuid"]
flight = [
"dep:arrow-array",
"dep:arrow-cast",
"dep:arrow-flight",
"dep:arrow-schema",
"dep:base64",
142 changes: 115 additions & 27 deletions src/flight/exec.rs
@@ -27,7 +27,7 @@ use crate::flight::{FlightMetadata, FlightProperties};
use arrow_array::RecordBatch;
use arrow_flight::error::FlightError;
use arrow_flight::{FlightClient, FlightEndpoint, Ticket};
use arrow_schema::SchemaRef;
use arrow_schema::{ArrowError, SchemaRef};
use datafusion::arrow::datatypes::ToByteSlice;
use datafusion::common::Result;
use datafusion::common::{project_schema, DataFusionError};
@@ -206,35 +206,40 @@ async fn try_fetch_stream(
.map_err(|e| FlightError::ExternalError(Box::new(e)))?;
let mut client = FlightClient::new(channel);
client.metadata_mut().clone_from(grpc_headers.as_ref());
let stream = client.do_get(ticket).await?;
let stream = client
.do_get(ticket)
.await?
.map_err(|e| DataFusionError::External(Box::new(e)));
Ok(Box::pin(RecordBatchStreamAdapter::new(
schema.clone(),
stream.map(move |rb| {
let schema = schema.clone();
rb.map(move |rb| {
if schema.fields.is_empty() || rb.schema() == schema {
rb
} else if schema.contains(rb.schema_ref()) {
rb.with_schema(schema.clone()).unwrap()
} else {
let columns = schema
.fields
.iter()
.map(|field| {
rb.column_by_name(field.name())
.expect("missing fields in record batch")
.clone()
})
.collect();
RecordBatch::try_new(schema.clone(), columns)
.expect("cannot impose desired schema on record batch")
}
})
.map_err(|e| DataFusionError::External(Box::new(e)))
}),
stream.map(move |item| item.and_then(|rb| enforce_schema(rb, &schema).map_err(Into::into))),
)))
}

fn enforce_schema(rb: RecordBatch, target_schema: &SchemaRef) -> arrow::error::Result<RecordBatch> {
if target_schema.fields.is_empty() || rb.schema() == *target_schema {
Ok(rb)
} else if target_schema.contains(rb.schema_ref()) {
rb.with_schema(target_schema.clone())
} else {
let columns = target_schema
.fields
.iter()
.map(|field| {
rb.column_by_name(field.name())
.ok_or(ArrowError::SchemaError(format!(
"Required field `{}` is missing from the flight response",
field.name()
)))
.and_then(|original_array| {
arrow_cast::cast(original_array.as_ref(), field.data_type())
})
})
.collect::<Result<_, _>>()?;
RecordBatch::try_new(target_schema.clone(), columns)
}
}

impl DisplayAs for FlightExec {
fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
match t {
@@ -297,9 +302,12 @@ impl ExecutionPlan for FlightExec {

#[cfg(test)]
mod tests {
use crate::flight::exec::{FlightConfig, FlightPartition, FlightTicket};
use crate::flight::exec::{enforce_schema, FlightConfig, FlightPartition, FlightTicket};
use crate::flight::FlightProperties;
use arrow_schema::{DataType, Field, Schema};
use arrow_array::{
BooleanArray, Float32Array, Int32Array, RecordBatch, StringArray, StructArray,
};
use arrow_schema::{DataType, Field, Fields, Schema};
use std::collections::HashMap;
use std::sync::Arc;

@@ -334,4 +342,84 @@ mod tests {
let restored = serde_json::from_slice(json.as_slice()).expect("cannot decode json config");
assert_eq!(config, restored);
}

#[test]
fn test_schema_enforcement() -> arrow::error::Result<()> {
let data = StructArray::new(
Fields::from(vec![
Arc::new(Field::new("f_int", DataType::Int32, true)),
Arc::new(Field::new("f_bool", DataType::Boolean, false)),
]),
vec![
Arc::new(Int32Array::from(vec![10, 20])),
Arc::new(BooleanArray::from(vec![true, false])),
],
None,
);
let input_rb = RecordBatch::from(data);

let empty_schema = Arc::new(Schema::empty());
let same_rb = enforce_schema(input_rb.clone(), &empty_schema)?;
assert_eq!(input_rb, same_rb);

let coerced_rb = enforce_schema(
input_rb.clone(),
&Arc::new(Schema::new(vec![
// compatible yet different types with flipped nullability
Arc::new(Field::new("f_int", DataType::Float32, false)),
Arc::new(Field::new("f_bool", DataType::Utf8, true)),
])),
)?;
assert_ne!(input_rb, coerced_rb);
assert_eq!(coerced_rb.num_columns(), 2);
assert_eq!(coerced_rb.num_rows(), 2);
assert_eq!(
coerced_rb.column(0).as_ref(),
&Float32Array::from(vec![10.0, 20.0])
);
assert_eq!(
coerced_rb.column(1).as_ref(),
&StringArray::from(vec!["true", "false"])
);

let projection_rb = enforce_schema(
input_rb.clone(),
&Arc::new(Schema::new(vec![
// keep only the first column and make it non-nullable int16
Arc::new(Field::new("f_int", DataType::Int16, false)),
])),
)?;
assert_eq!(projection_rb.num_columns(), 1);
assert_eq!(projection_rb.num_rows(), 2);
assert_eq!(projection_rb.schema().fields().len(), 1);
assert_eq!(projection_rb.schema().fields()[0].name(), "f_int");

let incompatible_schema_attempt = enforce_schema(
input_rb.clone(),
&Arc::new(Schema::new(vec![
Arc::new(Field::new("f_int", DataType::Float32, true)),
Arc::new(Field::new("f_bool", DataType::Date32, false)),
])),
);
assert!(incompatible_schema_attempt.is_err());
assert_eq!(
incompatible_schema_attempt.unwrap_err().to_string(),
"Cast error: Casting from Boolean to Date32 not supported"
);

let broader_schema_attempt = enforce_schema(
input_rb.clone(),
&Arc::new(Schema::new(vec![
Arc::new(Field::new("f_int", DataType::Int32, true)),
Arc::new(Field::new("f_bool", DataType::Boolean, false)),
Arc::new(Field::new("f_extra", DataType::Utf8, true)),
])),
);
assert!(broader_schema_attempt.is_err());
assert_eq!(
broader_schema_attempt.unwrap_err().to_string(),
"Schema error: Required field `f_extra` is missing from the flight response"
);
Ok(())
}
}
4 changes: 2 additions & 2 deletions src/flight/sql.rs
@@ -65,8 +65,8 @@ impl FlightDriver for FlightSqlDriver {
key.strip_prefix(HEADER_PREFIX)
.map(|header_name| (header_name, value))
});
for header in headers {
client.set_header(header.0, header.1)
for (name, value) in headers {
client.set_header(name, value)
}
if let Some(username) = options.get(USERNAME) {
let default_password = "".to_string();
7 changes: 2 additions & 5 deletions tests/flight/mod.rs
@@ -161,8 +161,7 @@ async fn test_flight_sql_data_source() -> datafusion::common::Result<()> {
Arc::new(Float32Array::from(vec![0.0, 0.1, 0.2, 0.3])),
Arc::new(Int8Array::from(vec![10, 20, 30, 40])),
],
)
.unwrap();
)?;
let rows_per_partition = partition_data.num_rows();

let query = "SELECT * FROM some_table";
@@ -174,9 +173,7 @@ async fn test_flight_sql_data_source() -> datafusion::common::Result<()> {
endpoint_archetype,
];
let num_partitions = endpoints.len();
let flight_info = FlightInfo::default()
.try_with_schema(partition_data.schema().as_ref())
.unwrap();
let flight_info = FlightInfo::default().try_with_schema(partition_data.schema().as_ref())?;
let flight_info = endpoints
.into_iter()
.fold(flight_info, |fi, e| fi.with_endpoint(e));
