Skip to content

Commit

Permalink
feat(expr): support user-defined CAST and TRY_CAST
Browse files Browse the repository at this point in the history
  • Loading branch information
andylokandy committed Jul 17, 2022
1 parent 4dff5ec commit c5767aa
Show file tree
Hide file tree
Showing 9 changed files with 408 additions and 52 deletions.
43 changes: 38 additions & 5 deletions common/expression/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,16 @@ impl Display for RawExpr {
match self {
RawExpr::Literal { lit, .. } => write!(f, "{lit}"),
RawExpr::ColumnRef { id, data_type, .. } => write!(f, "ColumnRef({id})::{data_type}"),
RawExpr::Cast {
expr, dest_type, ..
} => {
write!(f, "CAST({expr} AS {dest_type})")
}
RawExpr::TryCast {
expr, dest_type, ..
} => {
write!(f, "TRY_CAST({expr} AS {dest_type})")
}
RawExpr::FunctionCall {
name, args, params, ..
} => {
Expand Down Expand Up @@ -182,6 +192,16 @@ impl Display for Expr {
match self {
Expr::Literal { lit, .. } => write!(f, "{lit}"),
Expr::ColumnRef { id, .. } => write!(f, "ColumnRef({id})"),
Expr::Cast {
expr, dest_type, ..
} => {
write!(f, "CAST<dest_type={dest_type}>({expr})")
}
Expr::TryCast {
expr, dest_type, ..
} => {
write!(f, "TRY_CAST<dest_type={dest_type}>({expr})")
}
Expr::FunctionCall {
function,
args,
Expand Down Expand Up @@ -216,11 +236,24 @@ impl Display for Expr {
}
write!(f, ")")
}
Expr::Cast {
expr, dest_type, ..
} => {
write!(f, "cast<dest_type={dest_type}>({expr})")
}
}
}
}

impl<T: ValueType> Debug for Value<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Value::Scalar(s) => write!(f, "Scalar({:?})", s),
Value::Column(c) => write!(f, "Column({:?})", c),
}
}
}

impl<'a, T: ValueType> Debug for ValueRef<'a, T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ValueRef::Scalar(s) => write!(f, "Scalar({:?})", s),
ValueRef::Column(c) => write!(f, "Column({:?})", c),
}
}
}
Expand Down
217 changes: 193 additions & 24 deletions common/expression/src/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use common_arrow::arrow::bitmap;
use itertools::Itertools;

use crate::chunk::Chunk;
Expand Down Expand Up @@ -74,16 +75,24 @@ impl Evaluator {
dest_type,
} => {
let value = self.run(expr)?;
self.run_cast(value, dest_type, span.clone())
self.run_cast(span.clone(), value, dest_type)
}
Expr::TryCast {
span,
expr,
dest_type,
} => {
let value = self.run(expr)?;
Ok(self.run_try_cast(span.clone(), value, dest_type))
}
}
}

pub fn run_cast(
&self,
span: Span,
input: Value<AnyType>,
dest_type: &DataType,
span: Span,
) -> Result<Value<AnyType>> {
match input {
Value::Scalar(scalar) => match (scalar, dest_type) {
Expand All @@ -93,11 +102,11 @@ impl Evaluator {
Ok(Value::Scalar(Scalar::Array(column)))
}
(scalar, DataType::Nullable(dest_ty)) => {
self.run_cast(Value::Scalar(scalar), dest_ty, span)
self.run_cast(span, Value::Scalar(scalar), dest_ty)
}
(Scalar::Array(array), DataType::Array(dest_ty)) => {
let array = self
.run_cast(Value::Column(array), dest_ty, span)?
.run_cast(span, Value::Column(array), dest_ty)?
.into_column()
.ok()
.unwrap();
Expand Down Expand Up @@ -132,15 +141,16 @@ impl Evaluator {
}
Ok(Value::Column(builder.build()))
}
(Column::EmptyArray { len }, DataType::Array(dest_ty)) => {
Ok(Value::Column(Column::Array {
array: Box::new(ColumnBuilder::with_capacity(dest_ty, 0).build()),
offsets: vec![0; len + 1].into(),
}))
(Column::EmptyArray { len }, DataType::Array(_)) => {
let mut builder = ColumnBuilder::with_capacity(dest_type, len);
for _ in 0..len {
builder.push_default();
}
Ok(Value::Column(builder.build()))
}
(Column::Nullable { column, validity }, DataType::Nullable(dest_ty)) => {
let column = self
.run_cast(Value::Column(*column), dest_ty, span)?
.run_cast(span, Value::Column(*column), dest_ty)?
.into_column()
.ok()
.unwrap();
Expand All @@ -151,7 +161,7 @@ impl Evaluator {
}
(col, DataType::Nullable(dest_ty)) => {
let column = self
.run_cast(Value::Column(col), dest_ty, span)?
.run_cast(span, Value::Column(col), dest_ty)?
.into_column()
.ok()
.unwrap();
Expand All @@ -162,7 +172,7 @@ impl Evaluator {
}
(Column::Array { array, offsets }, DataType::Array(dest_ty)) => {
let array = self
.run_cast(Value::Column(*array), dest_ty, span)?
.run_cast(span, Value::Column(*array), dest_ty)?
.into_column()
.ok()
.unwrap();
Expand All @@ -171,14 +181,14 @@ impl Evaluator {
offsets,
}))
}
(Column::UInt8(column), DataType::UInt16) => Ok(Value::Column(Column::UInt16(
column.iter().map(|v| *v as u16).collect(),
(Column::UInt8(col), DataType::UInt16) => Ok(Value::Column(Column::UInt16(
col.iter().map(|v| *v as u16).collect(),
))),
(Column::Int8(column), DataType::Int16) => Ok(Value::Column(Column::Int16(
column.iter().map(|v| *v as i16).collect(),
(Column::Int8(col), DataType::Int16) => Ok(Value::Column(Column::Int16(
col.iter().map(|v| *v as i16).collect(),
))),
(Column::UInt8(column), DataType::Int16) => Ok(Value::Column(Column::Int16(
column.iter().map(|v| *v as i16).collect(),
(Column::UInt8(col), DataType::Int16) => Ok(Value::Column(Column::Int16(
col.iter().map(|v| *v as i16).collect(),
))),
(col @ Column::Boolean(_), DataType::Boolean)
| (col @ Column::String { .. }, DataType::String)
Expand All @@ -192,6 +202,85 @@ impl Evaluator {
}
}

pub fn run_try_cast(
&self,
span: Span,
input: Value<AnyType>,
dest_type: &DataType,
) -> Value<AnyType> {
let inner_type: &DataType = dest_type.as_nullable().unwrap();
match input {
Value::Scalar(scalar) => self
.run_cast(span, Value::Scalar(scalar), inner_type)
.unwrap_or(Value::Scalar(Scalar::Null)),
Value::Column(col) => match (col, inner_type) {
(_, DataType::Null | DataType::Nullable(_)) => {
unreachable!("inner type can not be nullable")
}
(Column::Null { len }, _) => {
let mut builder = ColumnBuilder::with_capacity(dest_type, len);
for _ in 0..len {
builder.push_default();
}
Value::Column(builder.build())
}
(Column::EmptyArray { len }, DataType::Array(_)) => {
let mut builder = ColumnBuilder::with_capacity(dest_type, len);
for _ in 0..len {
builder.push_default();
}
Value::Column(builder.build())
}
(Column::Nullable { column, validity }, _) => {
let (new_col, new_validity) = self
.run_try_cast(span, Value::Column(*column), dest_type)
.into_column()
.unwrap()
.into_nullable()
.unwrap();
Value::Column(Column::Nullable {
column: new_col,
validity: bitmap::or(&validity, &new_validity),
})
}
(Column::Array { array, offsets }, DataType::Array(dest_ty)) => {
todo!()
}
(Column::UInt8(col), DataType::UInt16) => Value::Column(Column::Nullable {
column: Box::new(Column::UInt16(col.iter().map(|v| *v as u16).collect())),
validity: constant_bitmap(true, col.len()).into(),
}),
(Column::Int8(col), DataType::Int16) => Value::Column(Column::Nullable {
column: Box::new(Column::Int16(col.iter().map(|v| *v as i16).collect())),
validity: constant_bitmap(true, col.len()).into(),
}),
(Column::UInt8(col), DataType::Int16) => Value::Column(Column::Nullable {
column: Box::new(Column::Int16(col.iter().map(|v| *v as i16).collect())),
validity: constant_bitmap(true, col.len()).into(),
}),
(col @ Column::Boolean(_), DataType::Boolean)
| (col @ Column::String { .. }, DataType::String)
| (col @ Column::UInt8(_), DataType::UInt8)
| (col @ Column::Int8(_), DataType::Int8)
| (col @ Column::Int16(_), DataType::Int16)
| (col @ Column::EmptyArray { .. }, DataType::EmptyArray) => {
Value::Column(Column::Nullable {
validity: constant_bitmap(true, col.len()).into(),
column: Box::new(col),
})
}
(col, _) => {
let len = col.len();
let mut builder = ColumnBuilder::with_capacity(dest_type, len);
for _ in 0..len {
builder.push_default();
}
Value::Column(builder.build())
}
},
}
}

pub fn run_lit(&self, lit: &Literal) -> Scalar {
match lit {
Literal::Null => Scalar::Null,
Expand Down Expand Up @@ -220,7 +309,15 @@ impl DomainCalculator {
dest_type,
} => {
let domain = self.calculate(expr)?;
self.calculate_cast(&domain, dest_type, span.clone())
self.calculate_cast(span.clone(), &domain, dest_type)
}
Expr::TryCast {
span,
expr,
dest_type,
} => {
let domain = self.calculate(expr)?;
Ok(self.calculate_try_cast(span.clone(), &domain, dest_type))
}
Expr::FunctionCall {
function,
Expand Down Expand Up @@ -276,16 +373,18 @@ impl DomainCalculator {

pub fn calculate_cast(
&self,
span: Span,
input: &Domain,
dest_type: &DataType,
span: Span,
) -> Result<Domain> {
match (input, dest_type) {
(
Domain::Nullable(NullableDomain { value: None, .. }),
DataType::Null | DataType::Nullable(_),
) => Ok(input.clone()),
(Domain::Array(None), DataType::EmptyArray | DataType::Array(_)) => Ok(input.clone()),
(Domain::Array(None), DataType::EmptyArray | DataType::Array(_)) => {
Ok(Domain::Array(None))
}
(
Domain::Nullable(NullableDomain {
has_null,
Expand All @@ -294,14 +393,14 @@ impl DomainCalculator {
DataType::Nullable(ty),
) => Ok(Domain::Nullable(NullableDomain {
has_null: *has_null,
value: Some(Box::new(self.calculate_cast(value, ty, span)?)),
value: Some(Box::new(self.calculate_cast(span, value, ty)?)),
})),
(domain, DataType::Nullable(ty)) => Ok(Domain::Nullable(NullableDomain {
has_null: false,
value: Some(Box::new(self.calculate_cast(domain, ty, span)?)),
value: Some(Box::new(self.calculate_cast(span, domain, ty)?)),
})),
(Domain::Array(Some(domain)), DataType::Array(ty)) => Ok(Domain::Array(Some(
Box::new(self.calculate_cast(domain, ty, span)?),
Box::new(self.calculate_cast(span, domain, ty)?),
))),
(Domain::UInt(UIntDomain { min, max }), DataType::UInt16) => {
Ok(Domain::UInt(UIntDomain {
Expand All @@ -326,4 +425,74 @@ impl DomainCalculator {
(domain, dest_ty) => Err((span, (format!("unable to cast {domain} to {dest_ty}",)))),
}
}

pub fn calculate_try_cast(&self, span: Span, input: &Domain, dest_type: &DataType) -> Domain {
let inner_type: &DataType = dest_type.as_nullable().unwrap();
match (input, inner_type) {
(_, DataType::Null | DataType::Nullable(_)) => {
unreachable!("inner type cannot be nullable")
}
(Domain::Array(None), DataType::EmptyArray | DataType::Array(_)) => {
Domain::Nullable(NullableDomain {
has_null: false,
value: Some(Box::new(Domain::Array(None))),
})
}
(
Domain::Nullable(NullableDomain {
has_null,
value: Some(value),
}),
_,
) => {
let inner_domain: NullableDomain<AnyType> = self
.calculate_try_cast(span, value, dest_type)
.into_nullable()
.unwrap();
Domain::Nullable(NullableDomain {
has_null: *has_null || inner_domain.has_null,
value: inner_domain.value,
})
}
(Domain::Array(Some(domain)), DataType::Array(ty)) => todo!(),
(Domain::UInt(UIntDomain { min, max }), DataType::UInt16) => {
Domain::Nullable(NullableDomain {
has_null: false,
value: Some(Box::new(Domain::UInt(UIntDomain {
min: (*min).min(u16::MAX as u64),
max: (*max).min(u16::MAX as u64),
}))),
})
}
(Domain::Int(IntDomain { min, max }), DataType::Int16) => {
Domain::Nullable(NullableDomain {
has_null: false,
value: Some(Box::new(Domain::Int(IntDomain {
min: (*min).max(i16::MIN as i64).min(i16::MAX as i64),
max: (*max).max(i16::MIN as i64).min(i16::MAX as i64),
}))),
})
}
(Domain::UInt(UIntDomain { min, max }), DataType::Int16) => {
Domain::Nullable(NullableDomain {
has_null: false,
value: Some(Box::new(Domain::Int(IntDomain {
min: (*min).min(i16::MAX as u64) as i64,
max: (*max).min(i16::MAX as u64) as i64,
}))),
})
}
(Domain::Boolean(_), DataType::Boolean)
| (Domain::String(_), DataType::String)
| (Domain::UInt(_), DataType::UInt8)
| (Domain::Int(_), DataType::Int8) => Domain::Nullable(NullableDomain {
has_null: false,
value: Some(Box::new(input.clone())),
}),
_ => Domain::Nullable(NullableDomain {
has_null: true,
value: None,
}),
}
}
}
Loading

0 comments on commit c5767aa

Please sign in to comment.