Skip to content

Commit be272ec

Browse files
committed
oof
1 parent 740bb5a commit be272ec

File tree

11 files changed

+97
-65
lines changed

11 files changed

+97
-65
lines changed

datafusion/common/src/param_value.rs

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,23 @@
1616
// under the License.
1717

1818
use crate::error::{_plan_datafusion_err, _plan_err};
19+
use crate::metadata::FieldMetadata;
1920
use crate::{Result, ScalarValue};
20-
use arrow::datatypes::DataType;
21+
use arrow::datatypes::FieldRef;
2122
use std::collections::HashMap;
2223

2324
/// The parameter value corresponding to the placeholder
2425
#[derive(Debug, Clone)]
2526
pub enum ParamValues {
2627
/// For positional query parameters, like `SELECT * FROM test WHERE a > $1 AND b = $2`
27-
List(Vec<ScalarValue>),
28+
List(Vec<(ScalarValue, Option<FieldMetadata>)>),
2829
/// For named query parameters, like `SELECT * FROM test WHERE a > $foo AND b = $goo`
29-
Map(HashMap<String, ScalarValue>),
30+
Map(HashMap<String, (ScalarValue, Option<FieldMetadata>)>),
3031
}
3132

3233
impl ParamValues {
3334
/// Verify parameter list length and type
34-
pub fn verify(&self, expect: &[DataType]) -> Result<()> {
35+
pub fn verify(&self, expect: &[FieldRef]) -> Result<()> {
3536
match self {
3637
ParamValues::List(list) => {
3738
// Verify if the number of params matches the number of values
@@ -45,15 +46,28 @@ impl ParamValues {
4546

4647
// Verify if the types of the params matches the types of the values
4748
let iter = expect.iter().zip(list.iter());
48-
for (i, (param_type, value)) in iter.enumerate() {
49-
if *param_type != value.data_type() {
49+
for (i, (param_type, (value, maybe_metadata))) in iter.enumerate() {
50+
if *param_type.data_type() != value.data_type() {
5051
return _plan_err!(
5152
"Expected parameter of type {}, got {:?} at index {}",
5253
param_type,
5354
value.data_type(),
5455
i
5556
);
5657
}
58+
59+
if let Some(expected_metadata) = maybe_metadata {
60+
// Probably too strict of a comparison (this is an example of where
61+
// the concept of type equality would be useful)
62+
if &expected_metadata.to_hashmap() != param_type.metadata() {
63+
return _plan_err!(
64+
"Expected parameter with metadata {:?}, got {:?} at index {}",
65+
expected_metadata,
66+
param_type.metadata(),
67+
i
68+
);
69+
}
70+
}
5771
}
5872
Ok(())
5973
}
@@ -65,7 +79,10 @@ impl ParamValues {
6579
}
6680
}
6781

68-
pub fn get_placeholders_with_values(&self, id: &str) -> Result<ScalarValue> {
82+
pub fn get_placeholders_with_values(
83+
&self,
84+
id: &str,
85+
) -> Result<(ScalarValue, Option<FieldMetadata>)> {
6986
match self {
7087
ParamValues::List(list) => {
7188
if id.is_empty() {
@@ -99,7 +116,7 @@ impl ParamValues {
99116

100117
impl From<Vec<ScalarValue>> for ParamValues {
101118
fn from(value: Vec<ScalarValue>) -> Self {
102-
Self::List(value)
119+
Self::List(value.into_iter().map(|v| (v, None)).collect())
103120
}
104121
}
105122

@@ -108,8 +125,10 @@ where
108125
K: Into<String>,
109126
{
110127
fn from(value: Vec<(K, ScalarValue)>) -> Self {
111-
let value: HashMap<String, ScalarValue> =
112-
value.into_iter().map(|(k, v)| (k.into(), v)).collect();
128+
let value: HashMap<String, (ScalarValue, Option<FieldMetadata>)> = value
129+
.into_iter()
130+
.map(|(k, v)| (k.into(), (v, None)))
131+
.collect();
113132
Self::Map(value)
114133
}
115134
}
@@ -119,8 +138,10 @@ where
119138
K: Into<String>,
120139
{
121140
fn from(value: HashMap<K, ScalarValue>) -> Self {
122-
let value: HashMap<String, ScalarValue> =
123-
value.into_iter().map(|(k, v)| (k.into(), v)).collect();
141+
let value: HashMap<String, (ScalarValue, Option<FieldMetadata>)> = value
142+
.into_iter()
143+
.map(|(k, v)| (k.into(), (v, None)))
144+
.collect();
124145
Self::Map(value)
125146
}
126147
}

datafusion/core/src/execution/context/mod.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ use datafusion_catalog::{
6464
DynamicFileCatalog, TableFunction, TableFunctionImpl, UrlTableFactory,
6565
};
6666
use datafusion_common::config::ConfigOptions;
67+
use datafusion_common::metadata::FieldMetadata;
6768
use datafusion_common::{
6869
config::{ConfigExtension, TableOptions},
6970
exec_datafusion_err, exec_err, internal_datafusion_err, not_impl_err,
@@ -1238,10 +1239,10 @@ impl SessionContext {
12381239
})?;
12391240

12401241
// Only allow literals as parameters for now.
1241-
let mut params: Vec<ScalarValue> = parameters
1242+
let mut params: Vec<(ScalarValue, Option<FieldMetadata>)> = parameters
12421243
.into_iter()
12431244
.map(|e| match e {
1244-
Expr::Literal(scalar, _) => Ok(scalar),
1245+
Expr::Literal(scalar, metadata) => Ok((scalar, metadata)),
12451246
_ => not_impl_err!("Unsupported parameter type: {}", e),
12461247
})
12471248
.collect::<Result<_>>()?;
@@ -1259,7 +1260,11 @@ impl SessionContext {
12591260
params = params
12601261
.into_iter()
12611262
.zip(prepared.data_types.iter())
1262-
.map(|(e, dt)| e.cast_to(dt))
1263+
.map(|(e, dt)| -> Result<_> {
1264+
// This is fishy...we're casting storage without checking if an
1265+
// extension type supports the destination
1266+
Ok((e.0.cast_to(dt.data_type())?, e.1))
1267+
})
12631268
.collect::<Result<_>>()?;
12641269
}
12651270

datafusion/core/src/execution/session_state.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use crate::datasource::provider_as_source;
3030
use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner};
3131
use crate::execution::SessionStateDefaults;
3232
use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
33+
use arrow_schema::FieldRef;
3334
use datafusion_catalog::information_schema::{
3435
InformationSchemaProvider, INFORMATION_SCHEMA,
3536
};
@@ -116,11 +117,11 @@ use uuid::Uuid;
116117
/// # #[tokio::main]
117118
/// # async fn main() -> Result<()> {
118119
/// let state = SessionStateBuilder::new()
119-
/// .with_config(SessionConfig::new())
120+
/// .with_config(SessionConfig::new())
120121
/// .with_runtime_env(Arc::new(RuntimeEnv::default()))
121122
/// .with_default_features()
122123
/// .build();
123-
/// Ok(())
124+
/// Ok(())
124125
/// # }
125126
/// ```
126127
///
@@ -873,7 +874,7 @@ impl SessionState {
873874
pub(crate) fn store_prepared(
874875
&mut self,
875876
name: String,
876-
data_types: Vec<DataType>,
877+
data_types: Vec<FieldRef>,
877878
plan: Arc<LogicalPlan>,
878879
) -> datafusion_common::Result<()> {
879880
match self.prepared_plans.entry(name) {
@@ -1323,7 +1324,7 @@ impl SessionStateBuilder {
13231324
/// let url = Url::try_from("file://").unwrap();
13241325
/// let object_store = object_store::local::LocalFileSystem::new();
13251326
/// let state = SessionStateBuilder::new()
1326-
/// .with_config(SessionConfig::new())
1327+
/// .with_config(SessionConfig::new())
13271328
/// .with_object_store(&url, Arc::new(object_store))
13281329
/// .with_default_features()
13291330
/// .build();
@@ -2012,7 +2013,7 @@ impl SimplifyInfo for SessionSimplifyProvider<'_> {
20122013
#[derive(Debug)]
20132014
pub(crate) struct PreparedPlan {
20142015
/// Data types of the parameters
2015-
pub(crate) data_types: Vec<DataType>,
2016+
pub(crate) data_types: Vec<FieldRef>,
20162017
/// The prepared logical plan
20172018
pub(crate) plan: Arc<LogicalPlan>,
20182019
}

datafusion/core/tests/dataframe/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ use datafusion_catalog::TableProvider;
6666
use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
6767
use datafusion_common::{
6868
assert_contains, internal_datafusion_err, Constraint, Constraints, DFSchema,
69-
DataFusionError, ParamValues, ScalarValue, TableReference, UnnestOptions,
69+
DataFusionError, ScalarValue, TableReference, UnnestOptions,
7070
};
7171
use datafusion_common_runtime::SpawnedTask;
7272
use datafusion_datasource::file_format::format_as_file_type;
@@ -2464,7 +2464,7 @@ async fn filtered_aggr_with_param_values() -> Result<()> {
24642464
let df = ctx
24652465
.sql("select count (c2) filter (where c3 > $1) from table1")
24662466
.await?
2467-
.with_param_values(ParamValues::List(vec![ScalarValue::from(10u64)]));
2467+
.with_param_values(vec![ScalarValue::from(10u64)]);
24682468

24692469
let df_results = df?.collect().await?;
24702470
assert_snapshot!(

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ use crate::{
5050

5151
use super::dml::InsertOp;
5252
use arrow::compute::can_cast_types;
53-
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
53+
use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema, SchemaRef};
5454
use datafusion_common::display::ToStringifiedPlan;
5555
use datafusion_common::file_options::file_type::FileType;
5656
use datafusion_common::metadata::FieldMetadata;
@@ -623,7 +623,7 @@ impl LogicalPlanBuilder {
623623
}
624624

625625
/// Make a builder for a prepare logical plan from the builder's plan
626-
pub fn prepare(self, name: String, data_types: Vec<DataType>) -> Result<Self> {
626+
pub fn prepare(self, name: String, data_types: Vec<FieldRef>) -> Result<Self> {
627627
Ok(Self::new(LogicalPlan::Statement(Statement::Prepare(
628628
Prepare {
629629
name,

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1464,7 +1464,7 @@ impl LogicalPlan {
14641464
let transformed_expr = e.transform_up(|e| {
14651465
if let Expr::Placeholder(Placeholder { id, .. }) = e {
14661466
let value = param_values.get_placeholders_with_values(&id)?;
1467-
Ok(Transformed::yes(Expr::Literal(value, None)))
1467+
Ok(Transformed::yes(Expr::Literal(value.0, value.1)))
14681468
} else {
14691469
Ok(Transformed::no(e))
14701470
}

datafusion/expr/src/logical_plan/statement.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::datatypes::DataType;
18+
use arrow::datatypes::FieldRef;
1919
use datafusion_common::{DFSchema, DFSchemaRef};
2020
use itertools::Itertools as _;
2121
use std::fmt::{self, Display};
@@ -192,7 +192,7 @@ pub struct Prepare {
192192
/// The name of the statement
193193
pub name: String,
194194
/// Data types of the parameters ([`Expr::Placeholder`])
195-
pub data_types: Vec<DataType>,
195+
pub data_types: Vec<FieldRef>,
196196
/// The logical plan of the statements
197197
pub input: Arc<LogicalPlan>,
198198
}

datafusion/sql/src/expr/mod.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
287287
schema,
288288
planner_context,
289289
)?),
290-
self.convert_data_type(&data_type)?,
290+
self.convert_data_type(&data_type)?.data_type().clone(),
291291
)))
292292
}
293293

@@ -297,7 +297,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
297297
uses_odbc_syntax: _,
298298
}) => Ok(Expr::Cast(Cast::new(
299299
Box::new(lit(value.into_string().unwrap())),
300-
self.convert_data_type(&data_type)?,
300+
self.convert_data_type(&data_type)?.data_type().clone(),
301301
))),
302302

303303
SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(
@@ -974,7 +974,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
974974

975975
// numeric constants are treated as seconds (rather as nanoseconds)
976976
// to align with postgres / duckdb semantics
977-
let expr = match &dt {
977+
let expr = match dt.data_type() {
978978
DataType::Timestamp(TimeUnit::Nanosecond, tz)
979979
if expr.get_type(schema)? == DataType::Int64 =>
980980
{
@@ -986,7 +986,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
986986
_ => expr,
987987
};
988988

989-
Ok(Expr::Cast(Cast::new(Box::new(expr), dt)))
989+
Ok(Expr::Cast(Cast::new(
990+
Box::new(expr),
991+
dt.data_type().clone(),
992+
)))
990993
}
991994

992995
/// Extracts the root expression and access chain from a compound expression.

datafusion/sql/src/expr/value.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use arrow::compute::kernels::cast_utils::{
2020
parse_interval_month_day_nano_config, IntervalParseConfig, IntervalUnit,
2121
};
2222
use arrow::datatypes::{
23-
i256, DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
23+
i256, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
2424
};
2525
use bigdecimal::num_bigint::BigInt;
2626
use bigdecimal::{BigDecimal, Signed, ToPrimitive};
@@ -45,7 +45,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
4545
pub(crate) fn parse_value(
4646
&self,
4747
value: Value,
48-
param_data_types: &[DataType],
48+
param_data_types: &[FieldRef],
4949
) -> Result<Expr> {
5050
match value {
5151
Value::Number(n, _) => self.parse_sql_number(&n, false),
@@ -108,7 +108,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
108108
/// number 1, 2, ... etc. For example, `$1` is the first placeholder; $2 is the second one and so on.
109109
fn create_placeholder_expr(
110110
param: String,
111-
param_data_types: &[DataType],
111+
param_data_types: &[FieldRef],
112112
) -> Result<Expr> {
113113
// Parse the placeholder as a number because it is the only support from sqlparser and postgres
114114
let index = param[1..].parse::<usize>();
@@ -133,7 +133,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
133133
// Data type of the parameter
134134
debug!("type of param {param} param_data_types[idx]: {param_type:?}");
135135

136-
Ok(Expr::Placeholder(Placeholder::new(
136+
Ok(Expr::Placeholder(Placeholder::new_with_metadata(
137137
param,
138138
param_type.cloned(),
139139
)))

0 commit comments

Comments
 (0)