From 529182aa9ff7c6c5624faa8f3df6351760034ec5 Mon Sep 17 00:00:00 2001
From: Bei Chu <914745487@qq.com>
Date: Sat, 23 Mar 2024 22:26:32 +0800
Subject: [PATCH] fix: Clickhouse conversion of `Null`, `Text` and `Decimal` (#2466)

---
 dozer-sink-clickhouse/src/client.rs |   8 +-
 dozer-sink-clickhouse/src/errors.rs |  18 +-
 dozer-sink-clickhouse/src/schema.rs |   2 +-
 dozer-sink-clickhouse/src/sink.rs   |  11 +-
 dozer-sink-clickhouse/src/types.rs  | 323 +++++++++++++---------------
 5 files changed, 178 insertions(+), 184 deletions(-)

diff --git a/dozer-sink-clickhouse/src/client.rs b/dozer-sink-clickhouse/src/client.rs
index 3e2ae5ae3e..3df34f4c3a 100644
--- a/dozer-sink-clickhouse/src/client.rs
+++ b/dozer-sink-clickhouse/src/client.rs
@@ -123,21 +123,21 @@ impl ClickhouseClient {
         &self,
         table_name: &str,
         fields: &[FieldDefinition],
-        values: &[Field],
+        values: Vec<Field>,
         query_id: Option<String>,
     ) -> Result<(), QueryError> {
         let client = self.pool.get_handle().await?;
-        insert_multi(client, table_name, fields, &[values.to_vec()], query_id).await
+        insert_multi(client, table_name, fields, vec![values], query_id).await
     }
 
     pub async fn insert_multi(
         &self,
         table_name: &str,
         fields: &[FieldDefinition],
-        values: &[Vec<Field>],
+        rows: Vec<Vec<Field>>,
         query_id: Option<String>,
     ) -> Result<(), QueryError> {
         let client = self.pool.get_handle().await?;
-        insert_multi(client, table_name, fields, values, query_id).await
+        insert_multi(client, table_name, fields, rows, query_id).await
     }
 }
diff --git a/dozer-sink-clickhouse/src/errors.rs b/dozer-sink-clickhouse/src/errors.rs
index d3fc752c5c..afa0201f8d 100644
--- a/dozer-sink-clickhouse/src/errors.rs
+++ b/dozer-sink-clickhouse/src/errors.rs
@@ -1,4 +1,7 @@
-use dozer_types::thiserror::{self, Error};
+use dozer_types::{
+    thiserror::{self, Error},
+    types::FieldType,
+};
 
 #[derive(Error, Debug)]
 pub enum ClickhouseSinkError {
@@ -32,8 +35,17 @@ pub enum QueryError {
     #[error("Clickhouse error: {0:?}")]
     DataFetchError(#[from] clickhouse_rs::errors::Error),
 
-    #[error("Unexpected field type for {0:?}, expected {0}")]
-    TypeMismatch(String, String),
+    #[error("Unexpected field type for {field_name:?}, expected {field_type:?}")]
+    TypeMismatch {
+        field_name: String,
+        field_type: FieldType,
+    },
+
+    #[error("Decimal overflow")]
+    DecimalOverflow,
+
+    #[error("Unsupported field type {0:?}")]
+    UnsupportedFieldType(FieldType),
 
     #[error("{0:?}")]
     CustomError(String),
diff --git a/dozer-sink-clickhouse/src/schema.rs b/dozer-sink-clickhouse/src/schema.rs
index ca920e8562..2c45cc11b4 100644
--- a/dozer-sink-clickhouse/src/schema.rs
+++ b/dozer-sink-clickhouse/src/schema.rs
@@ -1,6 +1,5 @@
 use crate::client::ClickhouseClient;
 use crate::errors::ClickhouseSinkError::{self, SinkTableDoesNotExist};
-use crate::types::DECIMAL_SCALE;
 use clickhouse_rs::types::Complex;
 use clickhouse_rs::{Block, ClientHandle};
 use dozer_types::log::warn;
@@ -145,6 +144,7 @@ impl ClickhouseSchema {
 }
 
 pub fn map_field_to_type(field: &FieldDefinition) -> String {
+    const DECIMAL_SCALE: u8 = 4;
     let decimal = format!("Decimal(10, {})", DECIMAL_SCALE);
     let typ: &str = match field.typ {
         FieldType::UInt => "UInt64",
diff --git a/dozer-sink-clickhouse/src/sink.rs b/dozer-sink-clickhouse/src/sink.rs
index f09be7ff29..4b0c70b4bb 100644
--- a/dozer-sink-clickhouse/src/sink.rs
+++ b/dozer-sink-clickhouse/src/sink.rs
@@ -178,7 +178,7 @@ impl ClickhouseSink {
             .insert(
                 REPLICA_METADATA_TABLE,
                 &self.metadata.schema.fields,
-                &[
+                vec![
                     Field::String(self.sink_table_name.clone()),
                     Field::UInt(txid),
                 ],
@@ -196,22 +196,17 @@ impl ClickhouseSink {
     }
 
     fn commit_batch(&mut self) -> Result<(), BoxedError> {
+        let batch = std::mem::take(&mut self.batch);
         self.runtime.block_on(async {
             //Insert batch
             self.client
-                .insert_multi(
-                    &self.sink_table_name,
-                    &self.schema.fields,
-                    &self.batch,
-                    None,
-                )
+                .insert_multi(&self.sink_table_name, &self.schema.fields, batch, None)
                 .await?;
 
             self.insert_metadata().await?;
 
             Ok::<(), BoxedError>(())
         })?;
 
-        self.batch.clear();
         Ok(())
     }
diff --git a/dozer-sink-clickhouse/src/types.rs b/dozer-sink-clickhouse/src/types.rs
index 6ed83a1015..6b13a0055a 100644
--- a/dozer-sink-clickhouse/src/types.rs
+++ b/dozer-sink-clickhouse/src/types.rs
@@ -2,11 +2,10 @@ use crate::errors::QueryError;
 use chrono_tz::{Tz, UTC};
+use clickhouse_rs::types::column::ColumnFrom;
 use clickhouse_rs::{Block, ClientHandle};
-use dozer_types::chrono::{DateTime, FixedOffset, NaiveDate, Offset, TimeZone};
-use dozer_types::json_types::JsonValue;
+use dozer_types::chrono::{DateTime, FixedOffset, Offset, TimeZone};
 use dozer_types::ordered_float::OrderedFloat;
-use dozer_types::rust_decimal::prelude::ToPrimitive;
 use dozer_types::rust_decimal::{self};
 use dozer_types::serde_json;
 use dozer_types::types::{Field, FieldDefinition, FieldType};
@@ -14,7 +13,6 @@ use either::Either;
 
 use clickhouse_rs::types::{FromSql, Query, Value, ValueRef};
 
-pub const DECIMAL_SCALE: u8 = 4;
 pub struct ValueWrapper(pub Value);
 
 impl<'a> FromSql<'a> for ValueWrapper {
@@ -139,186 +137,148 @@ fn convert_to_fixed_offset(datetime_tz: DateTime<Tz>) -> Option<DateTime<FixedOffset>> {
-fn type_mismatch_error(expected_type: &str, field_name: &str) -> QueryError {
-    QueryError::TypeMismatch(expected_type.to_string(), field_name.to_string())
+fn extract_last_column<T, Mapper>(
+    rows: &mut [Vec<Field>],
+    mut mapper: Mapper,
+) -> Result<Vec<T>, QueryError>
+where
+    Mapper: FnMut(Field) -> Result<T, QueryError>,
+{
+    rows.iter_mut()
+        .map(|row| mapper(row.pop().expect("must still have column")))
+        .collect()
 }
 
-macro_rules! handle_type {
-    ($nullable: expr, $b: expr, $field_type:ident, $rust_type:ty, $column_values:expr, $n:expr) => {{
-        if $nullable {
-            let column_data: Vec<Option<&Field>> = $column_values.iter().map(Some).collect();
-            let mut col: Vec<Option<$rust_type>> = vec![];
-            for f in column_data {
-                let v = match f {
-                    Some(Field::$field_type(v)) => Ok(Some(*v)),
-                    None => Ok(None),
-                    _ => Err(type_mismatch_error(stringify!($field_type), $n)),
-                }?;
-                col.push(v);
-            }
-            $b = $b.column($n, col);
+fn make_nullable_mapper<T, Mapper>(
+    mut mapper: Mapper,
+) -> impl FnMut(Field) -> Result<Option<T>, QueryError>
+where
+    Mapper: FnMut(Field) -> Result<T, QueryError>,
+{
+    move |field| {
+        if matches!(field, Field::Null) {
+            Ok(None)
         } else {
-            let mut col: Vec<$rust_type> = vec![];
-            for f in $column_values {
-                let v = match f {
-                    Field::$field_type(v) => Ok(*v),
-                    _ => Err(type_mismatch_error(stringify!($field_type), $n)),
-                }?;
-                col.push(v);
-            }
-            $b = $b.column($n, col);
+            mapper(field).map(Some)
         }
-    }};
+    }
 }
 
-macro_rules! handle_complex_type {
-    ($nullable: expr, $b: expr, $field_type:ident, $rust_type:ty, $column_values:expr, $n:expr, $complex_expr:expr) => {{
-        if $nullable {
-            let column_data: Vec<Option<&Field>> = $column_values.iter().map(Some).collect();
-            let mut col: Vec<Option<$rust_type>> = vec![];
-            for f in column_data {
-                let v = match f {
-                    Some(Field::$field_type(v)) => {
-                        let v = $complex_expr(v);
-                        Ok(v)
-                    }
-                    None => Ok(None),
-                    _ => Err(type_mismatch_error(stringify!($field_type), $n)),
-                }?;
-                col.push(v);
-            }
-            $b = $b.column($n, col);
+/// This is a closure that takes a generic parameter,
+/// like C++'s templated lambda, which Rust doesn't support.
+///
+/// Saves 4 parameters at every call site.
+struct AddLastColumn<'a> {
+    block: Block,
+    name: &'a str,
+    rows: &'a mut [Vec<Field>],
+    nullable: bool,
+}
+
+impl<'a> AddLastColumn<'a> {
+    fn call<T, Mapper>(
+        self,
+        mapper: Mapper,
+    ) -> Result<Block<Simple>, QueryError>
+    where
+        Vec<T>: ColumnFrom,
+        Vec<Option<T>>: ColumnFrom,
+        Mapper: FnMut(Field) -> Result<T, QueryError>,
+    {
+        Ok(if self.nullable {
+            self.block.column(
+                self.name,
+                extract_last_column(self.rows, make_nullable_mapper(mapper))?,
+            )
         } else {
-            let mut col: Vec<$rust_type> = vec![];
-            for f in $column_values {
-                let v = match f {
-                    Field::$field_type(v) => {
-                        let v = $complex_expr(v);
-                        match v {
-                            Some(v) => Ok(v),
-                            None => Err(type_mismatch_error(stringify!($field_type), $n)),
-                        }
-                    }
-                    _ => Err(type_mismatch_error(stringify!($field_type), $n)),
-                }?;
-                col.push(v);
+            self.block
+                .column(self.name, extract_last_column(self.rows, mapper)?)
+        })
+    }
+}
+
+fn add_last_column_to_block(
+    block: Block,
+    name: &str,
+    rows: &mut [Vec<Field>],
+    field_type: FieldType,
+    nullable: bool,
+) -> Result<Block<Simple>, QueryError> {
+    let make_error = || QueryError::TypeMismatch {
+        field_name: name.to_string(),
+        field_type,
+    };
+
+    macro_rules! trivial_mapper {
+        ($field_type:path) => {
+            |field| match field {
+                $field_type(value) => Ok(value),
+                _ => Err(make_error()),
             }
-            $b = $b.column($n, col);
-        }
-    }};
+        };
+    }
+
+    let add_last_column = AddLastColumn {
+        block,
+        name,
+        rows,
+        nullable,
+    };
+
+    match field_type {
+        FieldType::UInt => add_last_column.call(trivial_mapper!(Field::UInt)),
+        FieldType::U128 => add_last_column.call(trivial_mapper!(Field::U128)),
+        FieldType::Int => add_last_column.call(trivial_mapper!(Field::Int)),
+        FieldType::I128 => add_last_column.call(trivial_mapper!(Field::I128)),
+        FieldType::Boolean => add_last_column.call(trivial_mapper!(Field::Boolean)),
+        FieldType::Float => add_last_column.call(|field| match field {
+            Field::Float(value) => Ok(value.0),
+            _ => Err(make_error()),
+        }),
+        FieldType::String => add_last_column.call(trivial_mapper!(Field::String)),
+        FieldType::Text => add_last_column.call(trivial_mapper!(Field::Text)),
+        FieldType::Binary => add_last_column.call(trivial_mapper!(Field::Binary)),
+        FieldType::Decimal => add_last_column.call(|field| match field {
+            Field::Decimal(value) => {
+                // This is hardcoded in `clickhouse-rs`.
+                if value.scale() > 18 {
+                    return Err(QueryError::DecimalOverflow);
+                }
+                let mantissa: i64 = value
+                    .mantissa()
+                    .try_into()
+                    .map_err(|_| QueryError::DecimalOverflow)?;
+                Ok(clickhouse_rs::types::Decimal::new(
+                    mantissa,
+                    value.scale() as u8,
+                ))
+            }
+            _ => Err(make_error()),
+        }),
+        FieldType::Timestamp => add_last_column.call(|field| match field {
+            Field::Timestamp(value) => Ok(value.with_timezone(&UTC)),
+            _ => Err(make_error()),
+        }),
+        FieldType::Date => add_last_column.call(trivial_mapper!(Field::Date)),
+        FieldType::Json => add_last_column.call(|field| match field {
+            Field::Json(value) => Ok(dozer_types::json_types::json_to_bytes(&value)),
+            _ => Err(make_error()),
+        }),
+        other => Err(QueryError::UnsupportedFieldType(other)),
+    }
 }
 
 pub async fn insert_multi(
     mut client: ClientHandle,
     table_name: &str,
     fields: &[FieldDefinition],
-    values: &[Vec<Field>],
+    mut rows: Vec<Vec<Field>>,
     query_id: Option<String>,
 ) -> Result<(), QueryError> {
-    let mut b = Block::<Simple>::new();
-
-    for (field_index, fd) in fields.iter().enumerate() {
-        let column_values: Vec<_> = values.iter().map(|row| &row[field_index]).collect();
+    let mut block = Block::<Simple>::new();
 
-        let n = &fd.name;
-        let nullable = fd.nullable;
-        match fd.typ {
-            FieldType::UInt => handle_type!(nullable, b, UInt, u64, column_values, n),
-            FieldType::U128 => handle_type!(nullable, b, U128, u128, column_values, n),
-            FieldType::Int => handle_type!(nullable, b, Int, i64, column_values, n),
-            FieldType::I128 => handle_type!(nullable, b, I128, i128, column_values, n),
-            FieldType::Boolean => handle_type!(nullable, b, Boolean, bool, column_values, n),
-            FieldType::Float => {
-                handle_complex_type!(
-                    nullable,
-                    b,
-                    Float,
-                    f64,
-                    column_values,
-                    n,
-                    |f: &OrderedFloat<f64>| -> Option<f64> { f.to_f64() }
-                )
-            }
-            FieldType::String | FieldType::Text => {
-                handle_complex_type!(
-                    nullable,
-                    b,
-                    String,
-                    String,
-                    column_values,
-                    n,
-                    |f: &String| -> Option<String> { Some(f.to_string()) }
-                )
-            }
-            FieldType::Binary => {
-                handle_complex_type!(
-                    nullable,
-                    b,
-                    Binary,
-                    Vec<u8>,
-                    column_values,
-                    n,
-                    |f: &Vec<u8>| -> Option<Vec<u8>> { Some(f.clone()) }
-                )
-            }
-            FieldType::Decimal => {
-                handle_complex_type!(
-                    nullable,
-                    b,
-                    Decimal,
-                    clickhouse_rs::types::Decimal,
-                    column_values,
-                    n,
-                    |f: &rust_decimal::Decimal| -> Option<clickhouse_rs::types::Decimal> {
-                        f.to_f64()
-                            .map(|f| clickhouse_rs::types::Decimal::of(f, DECIMAL_SCALE))
-                    }
-                )
-            }
-            FieldType::Timestamp => {
-                handle_complex_type!(
-                    nullable,
-                    b,
-                    Timestamp,
-                    DateTime<Tz>,
-                    column_values,
-                    n,
-                    |dt: &DateTime<FixedOffset>| -> Option<DateTime<Tz>> {
-                        Some(dt.with_timezone(&UTC))
-                    }
-                )
-            }
-            FieldType::Date => {
-                handle_complex_type!(
-                    nullable,
-                    b,
-                    Date,
-                    NaiveDate,
-                    column_values,
-                    n,
-                    |f: &NaiveDate| -> Option<NaiveDate> { Some(*f) }
-                )
-            }
-            FieldType::Json => {
-                handle_complex_type!(
-                    nullable,
-                    b,
-                    Json,
-                    Vec<u8>,
-                    column_values,
-                    n,
-                    |f: &JsonValue| -> Option<Vec<u8>> {
-                        Some(dozer_types::json_types::json_to_bytes(f))
-                    }
-                )
-            }
-            ft => {
-                return Err(QueryError::CustomError(format!(
-                    "Unsupported field_type {} for {}",
-                    ft, n
-                )));
-            }
-        }
+    for field in fields.iter().rev() {
+        block = add_last_column_to_block(block, &field.name, &mut rows, field.typ, field.nullable)?;
     }
 
     let query_id = query_id.unwrap_or("".to_string());
@@ -326,7 +286,34 @@ pub async fn insert_multi(
     let table = Query::new(table_name).id(query_id);
 
     // Insert the block into the table
-    client.insert(table, b).await?;
+    client.insert(table, block).await?;
 
     Ok(())
 }
+
+mod tests {
+    #[test]
+    fn test_add_last_column_to_block() {
+        use super::*;
+        use dozer_types::rust_decimal::prelude::ToPrimitive;
+        let dozer_decimal = dozer_types::rust_decimal::Decimal::new(123, 10);
+        let mut rows = vec![vec![
+            Field::Null,
+            Field::Text("text".to_string()),
+            Field::Decimal(dozer_decimal),
+        ]];
+        let mut block = Block::<Simple>::new();
+        block = add_last_column_to_block(block, "decimal", &mut rows, FieldType::Decimal, false)
+            .unwrap();
+        block = add_last_column_to_block(block, "text", &mut rows, FieldType::Text, false).unwrap();
+        block = add_last_column_to_block(block, "null", &mut rows, FieldType::UInt, true).unwrap();
+        let decimal = block
+            .get_column("decimal")
+            .unwrap()
+            .iter::<clickhouse_rs::types::Decimal>()
+            .unwrap()
+            .next()
+            .unwrap();
+        assert_eq!(Into::<f64>::into(decimal), dozer_decimal.to_f64().unwrap());
+    }
+}
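
Note for reviewers: after this patch, `insert_multi` takes ownership of the rows, walks `fields` in reverse, and `extract_last_column` pops the last value off every row for each column, so each row must list its values in exactly the same order as `fields`. The sketch below is illustrative only; it uses plain standard-library types instead of dozer's `Field` and `Block` to show the same pop-from-the-back pattern and why the reverse iteration avoids cloning values.

// Illustrative sketch only: mirrors how the patched insert_multi consumes
// `rows` column-by-column from the back while visiting `fields` in reverse.
fn main() {
    let fields = ["id", "name", "score"];
    // One Vec per record; values ordered exactly like `fields`.
    let mut rows: Vec<Vec<String>> = vec![
        vec!["1".into(), "alice".into(), "3.5".into()],
        vec!["2".into(), "bob".into(), "4.0".into()],
    ];

    // Visiting fields in reverse lets each pass pop the last value of every
    // row, producing one owned column per field without cloning.
    for field in fields.iter().rev() {
        let column: Vec<String> = rows
            .iter_mut()
            .map(|row| row.pop().expect("row shorter than field list"))
            .collect();
        println!("{field}: {column:?}");
    }
}

A row shorter than `fields` would hit the `expect`, which is the same invariant the patch states with its "must still have column" message.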