Skip to content

Commit f239016

Browse files
committed
refactor: add dialect enum
1 parent 556eb9b commit f239016

File tree

9 files changed

+127
-35
lines changed

9 files changed

+127
-35
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-cli/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ async-trait = { workspace = true }
4040
aws-config = "1.8.6"
4141
aws-credential-types = "1.2.7"
4242
chrono = { workspace = true }
43-
clap = { version = "4.5.47", features = ["derive", "cargo"] }
43+
clap = { version = "4.5.47", features = ["cargo", "derive"] }
4444
datafusion = { workspace = true, features = [
4545
"avro",
4646
"compression",
@@ -55,6 +55,7 @@ datafusion = { workspace = true, features = [
5555
"sql",
5656
"unicode_expressions",
5757
] }
58+
datafusion-common = { workspace = true }
5859
dirs = "6.0.0"
5960
env_logger = { workspace = true }
6061
futures = { workspace = true }
@@ -65,7 +66,7 @@ parking_lot = { workspace = true }
6566
parquet = { workspace = true, default-features = false }
6667
regex = { workspace = true }
6768
rustyline = "17.0"
68-
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] }
69+
tokio = { workspace = true, features = ["macros", "parking_lot", "rt", "rt-multi-thread", "signal", "sync"] }
6970
url = { workspace = true }
7071

7172
[dev-dependencies]

datafusion-cli/src/helper.rs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter};
2424

2525
use datafusion::sql::parser::{DFParser, Statement};
2626
use datafusion::sql::sqlparser::dialect::dialect_from_str;
27+
use datafusion_common::config::Dialect;
2728

2829
use rustyline::completion::{Completer, FilenameCompleter, Pair};
2930
use rustyline::error::ReadlineError;
@@ -34,33 +35,33 @@ use rustyline::{Context, Helper, Result};
3435

3536
pub struct CliHelper {
3637
completer: FilenameCompleter,
37-
dialect: String,
38+
dialect: Dialect,
3839
highlighter: Box<dyn Highlighter>,
3940
}
4041

4142
impl CliHelper {
42-
pub fn new(dialect: &str, color: bool) -> Self {
43+
pub fn new(dialect: &Dialect, color: bool) -> Self {
4344
let highlighter: Box<dyn Highlighter> = if !color {
4445
Box::new(NoSyntaxHighlighter {})
4546
} else {
4647
Box::new(SyntaxHighlighter::new(dialect))
4748
};
4849
Self {
4950
completer: FilenameCompleter::new(),
50-
dialect: dialect.into(),
51+
dialect: *dialect,
5152
highlighter,
5253
}
5354
}
5455

55-
pub fn set_dialect(&mut self, dialect: &str) {
56-
if dialect != self.dialect {
57-
self.dialect = dialect.to_string();
56+
pub fn set_dialect(&mut self, dialect: &Dialect) {
57+
if *dialect != self.dialect {
58+
self.dialect = *dialect;
5859
}
5960
}
6061

6162
fn validate_input(&self, input: &str) -> Result<ValidationResult> {
6263
if let Some(sql) = input.strip_suffix(';') {
63-
let dialect = match dialect_from_str(&self.dialect) {
64+
let dialect = match dialect_from_str(self.dialect) {
6465
Some(dialect) => dialect,
6566
None => {
6667
return Ok(ValidationResult::Invalid(Some(format!(
@@ -97,7 +98,7 @@ impl CliHelper {
9798

9899
impl Default for CliHelper {
99100
fn default() -> Self {
100-
Self::new("generic", false)
101+
Self::new(&Dialect::Generic, false)
101102
}
102103
}
103104

@@ -289,7 +290,7 @@ mod tests {
289290
);
290291

291292
// valid in postgresql dialect
292-
validator.set_dialect("postgresql");
293+
validator.set_dialect(&Dialect::PostgreSQL);
293294
let result =
294295
readline_direct(Cursor::new(r"select 1 # 2;".as_bytes()), &validator)?;
295296
assert!(matches!(result, ValidationResult::Valid(None)));

datafusion-cli/src/highlighter.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use datafusion::sql::sqlparser::{
2727
keywords::Keyword,
2828
tokenizer::{Token, Tokenizer},
2929
};
30+
use datafusion_common::config;
3031
use rustyline::highlight::{CmdKind, Highlighter};
3132

3233
/// The syntax highlighter.
@@ -36,7 +37,7 @@ pub struct SyntaxHighlighter {
3637
}
3738

3839
impl SyntaxHighlighter {
39-
pub fn new(dialect: &str) -> Self {
40+
pub fn new(dialect: &config::Dialect) -> Self {
4041
let dialect = dialect_from_str(dialect).unwrap_or(Box::new(GenericDialect {}));
4142
Self { dialect }
4243
}
@@ -93,13 +94,14 @@ impl Color {
9394

9495
#[cfg(test)]
9596
mod tests {
97+
use super::config::Dialect;
9698
use super::SyntaxHighlighter;
9799
use rustyline::highlight::Highlighter;
98100

99101
#[test]
100102
fn highlighter_valid() {
101103
let s = "SElect col_a from tab_1;";
102-
let highlighter = SyntaxHighlighter::new("generic");
104+
let highlighter = SyntaxHighlighter::new(&Dialect::Generic);
103105
let out = highlighter.highlight(s, s.len());
104106
assert_eq!(
105107
"\u{1b}[91mSElect\u{1b}[0m col_a \u{1b}[91mfrom\u{1b}[0m tab_1;",
@@ -110,7 +112,7 @@ mod tests {
110112
#[test]
111113
fn highlighter_valid_with_new_line() {
112114
let s = "SElect col_a from tab_1\n WHERE col_b = 'なにか';";
113-
let highlighter = SyntaxHighlighter::new("generic");
115+
let highlighter = SyntaxHighlighter::new(&Dialect::Generic);
114116
let out = highlighter.highlight(s, s.len());
115117
assert_eq!(
116118
"\u{1b}[91mSElect\u{1b}[0m col_a \u{1b}[91mfrom\u{1b}[0m tab_1\n \u{1b}[91mWHERE\u{1b}[0m col_b = \u{1b}[92m'なにか'\u{1b}[0m;",
@@ -121,7 +123,7 @@ mod tests {
121123
#[test]
122124
fn highlighter_invalid() {
123125
let s = "SElect col_a from tab_1 WHERE col_b = ';";
124-
let highlighter = SyntaxHighlighter::new("generic");
126+
let highlighter = SyntaxHighlighter::new(&Dialect::Generic);
125127
let out = highlighter.highlight(s, s.len());
126128
assert_eq!("SElect col_a from tab_1 WHERE col_b = ';", out);
127129
}

datafusion-examples/examples/remote_catalog.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ async fn main() -> Result<()> {
7575
let state = ctx.state();
7676

7777
// First, parse the SQL (but don't plan it / resolve any table references)
78-
let dialect = state.config().options().sql_parser.dialect.as_str();
79-
let statement = state.sql_to_statement(sql, dialect)?;
78+
let dialect = state.config().options().sql_parser.dialect;
79+
let statement = state.sql_to_statement(sql, &dialect)?;
8080

8181
// Find all `TableReferences` in the parsed queries. These correspond to the
8282
// tables referred to by the query (in this case

datafusion/common/src/config.rs

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ config_namespace! {
258258

259259
/// Configure the SQL dialect used by DataFusion's parser; supported values include: Generic,
260260
/// MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks.
261-
pub dialect: String, default = "generic".to_string()
261+
pub dialect: Dialect, default = Dialect::Generic
262262
// no need to lowercase because [`sqlparser::dialect_from_str`] is case-insensitive
263263

264264
/// If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but
@@ -292,6 +292,88 @@ config_namespace! {
292292
}
293293
}
294294

295+
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
296+
pub enum Dialect {
297+
#[default]
298+
Generic,
299+
MySQL,
300+
PostgreSQL,
301+
Hive,
302+
SQLite,
303+
Snowflake,
304+
Redshift,
305+
MsSQL,
306+
ClickHouse,
307+
BigQuery,
308+
Ansi,
309+
DuckDB,
310+
Databricks,
311+
}
312+
313+
impl AsRef<str> for Dialect {
314+
fn as_ref(&self) -> &str {
315+
match self {
316+
Self::Generic => "Generic",
317+
Self::MySQL => "MySQL",
318+
Self::PostgreSQL => "PostgreSQL",
319+
Self::Hive => "Hive",
320+
Self::SQLite => "SQLite",
321+
Self::Snowflake => "Snowflake",
322+
Self::Redshift => "Redshift",
323+
Self::MsSQL => "MsSQL",
324+
Self::ClickHouse => "ClickHouse",
325+
Self::BigQuery => "BigQuery",
326+
Self::Ansi => "Ansi",
327+
Self::DuckDB => "DuckDB",
328+
Self::Databricks => "Databricks",
329+
}
330+
}
331+
}
332+
333+
impl FromStr for Dialect {
334+
type Err = DataFusionError;
335+
336+
fn from_str(s: &str) -> Result<Self, Self::Err> {
337+
let value = match s.to_ascii_lowercase().as_str() {
338+
"generic" => Self::Generic,
339+
"mysql" => Self::MySQL,
340+
"postgresql" | "postgres" => Self::PostgreSQL,
341+
"hive" => Self::Hive,
342+
"sqlite" => Self::SQLite,
343+
"snowflake" => Self::Snowflake,
344+
"redshift" => Self::Redshift,
345+
"mssql" => Self::MsSQL,
346+
"clickhouse" => Self::ClickHouse,
347+
"bigquery" => Self::BigQuery,
348+
"ansi" => Self::Ansi,
349+
"duckdb" => Self::DuckDB,
350+
"databricks" => Self::Databricks,
351+
other => {return Err(DataFusionError::Configuration(format!(
352+
"Invalid Dialect: {other}. Expected one of: generic, mysql, postgresql, hive, sqlite, snowflake, redshift, mssql, clickhouse, bigquery, ansi, duckdb, databricks"
353+
)))}
354+
};
355+
Ok(value)
356+
}
357+
}
358+
359+
impl ConfigField for Dialect {
360+
fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
361+
v.some(key, self, description)
362+
}
363+
364+
fn set(&mut self, _: &str, value: &str) -> Result<()> {
365+
*self = Self::from_str(value)?;
366+
Ok(())
367+
}
368+
}
369+
370+
impl Display for Dialect {
371+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
372+
let str = self.as_ref();
373+
write!(f, "{str}")
374+
}
375+
}
376+
295377
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
296378
pub enum SpillCompression {
297379
Zstd,

datafusion/core/benches/sql_planner.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use criterion::Bencher;
3030
use datafusion::datasource::MemTable;
3131
use datafusion::execution::context::SessionContext;
3232
use datafusion::prelude::DataFrame;
33-
use datafusion_common::ScalarValue;
33+
use datafusion_common::{config::Dialect, ScalarValue};
3434
use datafusion_expr::Expr::Literal;
3535
use datafusion_expr::{cast, col, lit, not, try_cast, when};
3636
use datafusion_functions::expr_fn::{
@@ -288,7 +288,10 @@ fn benchmark_with_param_values_many_columns(
288288
}
289289
// SELECT max(attr0), ..., max(attrN) FROM t1.
290290
let query = format!("SELECT {aggregates} FROM t1");
291-
let statement = ctx.state().sql_to_statement(&query, "Generic").unwrap();
291+
let statement = ctx
292+
.state()
293+
.sql_to_statement(&query, &Dialect::Generic)
294+
.unwrap();
292295
let plan =
293296
rt.block_on(async { ctx.state().statement_to_plan(statement).await.unwrap() });
294297
b.iter(|| {

datafusion/core/src/execution/session_state.rs

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,14 @@ use crate::datasource::provider_as_source;
3030
use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner};
3131
use crate::execution::SessionStateDefaults;
3232
use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
33+
use arrow::datatypes::DataType;
3334
use datafusion_catalog::information_schema::{
3435
InformationSchemaProvider, INFORMATION_SCHEMA,
3536
};
36-
37-
use arrow::datatypes::DataType;
3837
use datafusion_catalog::MemoryCatalogProviderList;
3938
use datafusion_catalog::{TableFunction, TableFunctionImpl};
4039
use datafusion_common::alias::AliasGenerator;
41-
use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions};
40+
use datafusion_common::config::{ConfigExtension, ConfigOptions, Dialect, TableOptions};
4241
use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan};
4342
use datafusion_common::tree_node::TreeNode;
4443
use datafusion_common::{
@@ -374,7 +373,7 @@ impl SessionState {
374373
pub fn sql_to_statement(
375374
&self,
376375
sql: &str,
377-
dialect: &str,
376+
dialect: &Dialect,
378377
) -> datafusion_common::Result<Statement> {
379378
let dialect = dialect_from_str(dialect).ok_or_else(|| {
380379
plan_datafusion_err!(
@@ -411,7 +410,7 @@ impl SessionState {
411410
pub fn sql_to_expr(
412411
&self,
413412
sql: &str,
414-
dialect: &str,
413+
dialect: &Dialect,
415414
) -> datafusion_common::Result<SQLExpr> {
416415
self.sql_to_expr_with_alias(sql, dialect).map(|x| x.expr)
417416
}
@@ -423,7 +422,7 @@ impl SessionState {
423422
pub fn sql_to_expr_with_alias(
424423
&self,
425424
sql: &str,
426-
dialect: &str,
425+
dialect: &Dialect,
427426
) -> datafusion_common::Result<SQLExprWithAlias> {
428427
let dialect = dialect_from_str(dialect).ok_or_else(|| {
429428
plan_datafusion_err!(
@@ -527,8 +526,8 @@ impl SessionState {
527526
&self,
528527
sql: &str,
529528
) -> datafusion_common::Result<LogicalPlan> {
530-
let dialect = self.config.options().sql_parser.dialect.as_str();
531-
let statement = self.sql_to_statement(sql, dialect)?;
529+
let dialect = self.config.options().sql_parser.dialect;
530+
let statement = self.sql_to_statement(sql, &dialect)?;
532531
let plan = self.statement_to_plan(statement).await?;
533532
Ok(plan)
534533
}
@@ -542,9 +541,9 @@ impl SessionState {
542541
sql: &str,
543542
df_schema: &DFSchema,
544543
) -> datafusion_common::Result<Expr> {
545-
let dialect = self.config.options().sql_parser.dialect.as_str();
544+
let dialect = self.config.options().sql_parser.dialect;
546545

547-
let sql_expr = self.sql_to_expr_with_alias(sql, dialect)?;
546+
let sql_expr = self.sql_to_expr_with_alias(sql, &dialect)?;
548547

549548
let provider = SessionContextProvider {
550549
state: self,
@@ -2034,6 +2033,7 @@ mod tests {
20342033
use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray};
20352034
use arrow::datatypes::{DataType, Field, Schema};
20362035
use datafusion_catalog::MemoryCatalogProviderList;
2036+
use datafusion_common::config::Dialect;
20372037
use datafusion_common::DFSchema;
20382038
use datafusion_common::Result;
20392039
use datafusion_execution::config::SessionConfig;
@@ -2059,8 +2059,8 @@ mod tests {
20592059
let sql = "[1,2,3]";
20602060
let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
20612061
let df_schema = DFSchema::try_from(schema)?;
2062-
let dialect = state.config.options().sql_parser.dialect.as_str();
2063-
let sql_expr = state.sql_to_expr(sql, dialect)?;
2062+
let dialect = state.config.options().sql_parser.dialect;
2063+
let sql_expr = state.sql_to_expr(sql, &dialect)?;
20642064

20652065
let query = SqlToRel::new_with_options(&provider, state.get_parser_options());
20662066
query.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())
@@ -2218,7 +2218,8 @@ mod tests {
22182218
}
22192219

22202220
let state = &context_provider.state;
2221-
let statement = state.sql_to_statement("select count(*) from t", "mysql")?;
2221+
let statement =
2222+
state.sql_to_statement("select count(*) from t", &Dialect::MySQL)?;
22222223
let plan = SqlToRel::new(&context_provider).statement_to_plan(statement)?;
22232224
state.create_physical_plan(&plan).await
22242225
}

0 commit comments

Comments
 (0)