Skip to content

Commit

Permalink
simplify the macro rules
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal committed Sep 20, 2024
1 parent aa80959 commit a8ea6cf
Showing 1 changed file with 28 additions and 53 deletions.
81 changes: 28 additions & 53 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1811,6 +1811,22 @@ mod test {
Ok(())
}

macro_rules! test_case_expression {
($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
let case = Case {
expr: $expr.map(|e| Box::new(col(e))),
when_then_expr: $when_then,
else_expr: None,
};

let expected =
cast_helper(case.clone(), &$case_when_type, &$then_else_type, &$schema);

let actual = coerce_case_expression(case, &$schema)?;
assert_eq!(expected, actual);
};
}

#[test]
fn tes_case_when_list() -> Result<()> {
let inner_field = Arc::new(Field::new("item", DataType::Int64, true));
Expand All @@ -1832,70 +1848,48 @@ mod test {
std::collections::HashMap::new(),
)?);

macro_rules! test_case_expression {
($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
let case = Case {
expr: Some(Box::new(col($expr))),
when_then_expr: $when_then,
else_expr: None,
};

let case_when_common_type = $case_when_type;
let then_else_common_type = $then_else_type;
let expected = cast_helper(
case.clone(),
&case_when_common_type,
&then_else_common_type,
&$schema,
);

let actual = coerce_case_expression(case, &$schema)?;
assert_eq!(expected, actual);
};
}

test_case_expression!(
"list",
Some("list"),
vec![(Box::new(col("large_list")), Box::new(lit("1")))],
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
"large_list",
Some("large_list"),
vec![(Box::new(col("list")), Box::new(lit("1")))],
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
"list",
Some("list"),
vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
"fixed_list",
Some("fixed_list"),
vec![(Box::new(col("list")), Box::new(lit("1")))],
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
"fixed_list",
Some("fixed_list"),
vec![(Box::new(col("large_list")), Box::new(lit("1")))],
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
"large_list",
Some("large_list"),
vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
Expand Down Expand Up @@ -1925,29 +1919,10 @@ mod test {
.into(),
std::collections::HashMap::new(),
)?);
macro_rules! test_case_expression {
($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
let case = Case {
expr: None,
when_then_expr: $when_then,
else_expr: None,
};

let expected = cast_helper(
case.clone(),
&$case_when_type,
&$then_else_type,
&$schema,
);

let actual = coerce_case_expression(case, &$schema)?;
assert_eq!(expected, actual);
};
}

// large list and list
test_case_expression!(
"boolean",
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("large_list"))),
(Box::new(col("boolean")), Box::new(col("list")))
Expand All @@ -1958,7 +1933,7 @@ mod test {
);

test_case_expression!(
"boolean",
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("list"))),
(Box::new(col("boolean")), Box::new(col("large_list")))
Expand All @@ -1970,7 +1945,7 @@ mod test {

// fixed list and list
test_case_expression!(
"boolean",
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("fixed_list"))),
(Box::new(col("boolean")), Box::new(col("list")))
Expand All @@ -1981,7 +1956,7 @@ mod test {
);

test_case_expression!(
"boolean",
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("list"))),
(Box::new(col("boolean")), Box::new(col("fixed_list")))
Expand All @@ -1993,7 +1968,7 @@ mod test {

// fixed list and large list
test_case_expression!(
"boolean",
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("fixed_list"))),
(Box::new(col("boolean")), Box::new(col("large_list")))
Expand All @@ -2004,7 +1979,7 @@ mod test {
);

test_case_expression!(
"boolean",
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("large_list"))),
(Box::new(col("boolean")), Box::new(col("fixed_list")))
Expand Down

0 comments on commit a8ea6cf

Please sign in to comment.