Skip to content

Commit

Permalink
Flattening in Row derive macro + documentation of attributes
Browse files Browse the repository at this point in the history
Following issue #30, this implements the serde-like `flatten` attribute on the `Row` derive macro, which allows rows to be composed as follows:
```
struct Row {
  #[klickhouse(flatten)]
  user: User,
  credits: u32
}
let users: Vec<Row> = ch.query_collect("SELECT age, name, credits FROM ...").await?;
```

The commit also documents the serde- and clickhouse-specific attributes supported by the derive macro.
  • Loading branch information
cpg314 committed Dec 25, 2023
1 parent d21f52e commit d58c633
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 30 deletions.
3 changes: 3 additions & 0 deletions klickhouse/src/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ impl Iterator for BlockRowIntoIter {

fn next(&mut self) -> Option<Self::Item> {
let mut out = IndexMap::new();
if self.column_data.is_empty() {
return None;
}
for (name, value) in self.column_data.iter_mut() {
out.insert(name.clone(), value.pop_front()?);
}
Expand Down
25 changes: 25 additions & 0 deletions klickhouse/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,31 @@ pub use manager::ConnectionManager;
pub use uuid::Uuid;

#[cfg(feature = "derive")]
/// Derive macro for the [Row] trait.
///
/// This is similar in usage and implementation to the [serde::Serialize] and [serde::Deserialize] derive macros.
///
/// ## serde attributes
/// The following [serde attributes](https://serde.rs/attributes.html) are supported, using `#[klickhouse(...)]` instead of `#[serde(...)]`:
/// - `with`
/// - `from` and `into`
/// - `try_from`
/// - `skip`
/// - `default`
/// - `deny_unknown_fields`
/// - `rename`
/// - `rename_all`
/// - `serialize_with`, `deserialize_with`
/// - `skip_deserializing`, `skip_serializing`
/// - `flatten`
/// - Index-based matching is disabled (the column names must match exactly).
/// - Due to the current interface of the [Row] trait, performance might not be optimal, as a value map must be reconstituted for each flattened subfield.
///
/// ## Clickhouse-specific attributes
/// - The `nested` attribute allows handling [Clickhouse nested data structures](https://clickhouse.com/docs/en/sql-reference/data-types/nested-data-structures/nested). See an example in the `tests` folder.
///
/// ## Known issues
/// - For serialization, the ordering of fields in the struct declaration must match the order in the `INSERT` statement, respectively in the table declaration. See issue [#34](https://github.com/Protryon/klickhouse/issues/34).
pub use klickhouse_derive::Row;

pub use client::*;
Expand Down
1 change: 1 addition & 0 deletions klickhouse/tests/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod test;
pub mod test_bytes;
pub mod test_decimal;
pub mod test_flatten;
pub mod test_lock;
pub mod test_nested;
pub mod test_raw_string;
Expand Down
59 changes: 59 additions & 0 deletions klickhouse/tests/test_flatten.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use klickhouse::Row;

/// Row type used by the flatten test: two plain columns with a flattened
/// [`SubRow`] declared between them.
#[derive(klickhouse::Row, Debug, Default, PartialEq, Clone)]
pub struct TestRow {
field: u32,
// The columns of `SubRow` (`a`, `b`) are exposed as if they were declared directly on `TestRow`.
#[klickhouse(flatten)]
subrow: SubRow,
field2: u32,
}

/// Inner row that is embedded into [`TestRow`] through `#[klickhouse(flatten)]`.
#[derive(klickhouse::Row, Debug, Default, PartialEq, Clone)]
pub struct SubRow {
a: u32,
b: f32,
}

/// Round-trips a [`TestRow`] (which contains a flattened [`SubRow`]) through
/// a ClickHouse table and checks that the derived column names and the
/// deserialized row match what was written.
#[tokio::test]
async fn test_client() {
    // Other test modules are compiled into the same test binary (see
    // `tests/main.rs`). `init()` panics when a global logger is already
    // installed, so use `try_init()` and ignore the "already set" error.
    let _ = env_logger::builder()
        .filter_level(log::LevelFilter::Info)
        .try_init();

    // Compare the complete list instead of zipping against a prefix: `zip`
    // stops at the shorter side, so it would neither check `field2` nor notice
    // a list that is too short (or empty).
    let column_names: Vec<String> = TestRow::column_names()
        .unwrap()
        .into_iter()
        .map(|name| name.to_string())
        .collect();
    assert_eq!(column_names, ["field", "a", "b", "field2"]);

    let client = super::get_client().await;

    super::prepare_table(
        "test_flatten",
        "field UInt32,
field2 UInt32,
a UInt32,
b Float32",
        &client,
    )
    .await;

    let row = TestRow {
        field: 1,
        field2: 4,
        subrow: SubRow { a: 2, b: 3.0 },
    };

    client
        .insert_native_block("INSERT INTO test_flatten FORMAT Native", vec![row.clone()])
        .await
        .unwrap();

    // NOTE(review): presumably gives the server time to make the inserted row
    // visible before it is read back — confirm whether this delay is required.
    tokio::time::sleep(std::time::Duration::from_secs(1)).await;

    let row2 = client
        .query_one::<TestRow>("SELECT * FROM test_flatten")
        .await
        .unwrap();
    assert_eq!(row, row2);
}
12 changes: 12 additions & 0 deletions klickhouse_derive/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ pub struct Field {
deserialize_with: Option<syn::ExprPath>,
bound: Option<Vec<syn::WherePredicate>>,
nested: bool,
flatten: bool,
}

#[allow(clippy::enum_variant_names)]
Expand All @@ -360,6 +361,7 @@ impl Field {
let mut nested = BoolAttr::none(cx, NESTED);
let mut skip_serializing = BoolAttr::none(cx, SKIP_SERIALIZING);
let mut skip_deserializing = BoolAttr::none(cx, SKIP_DESERIALIZING);
let mut flatten = BoolAttr::none(cx, FLATTEN);
let mut default = Attr::none(cx, DEFAULT);
let mut serialize_with = Attr::none(cx, SERIALIZE_WITH);
let mut deserialize_with = Attr::none(cx, DESERIALIZE_WITH);
Expand Down Expand Up @@ -406,6 +408,11 @@ impl Field {
nested.set_true(word);
}

// Parse `#[klickhouse(flatten)]`
Meta(Path(word)) if word == FLATTEN => {
flatten.set_true(word);
}

// Parse `#[klickhouse(skip_deserializing)]`
Meta(Path(word)) if word == SKIP_DESERIALIZING => {
skip_deserializing.set_true(word);
Expand Down Expand Up @@ -492,6 +499,7 @@ impl Field {
deserialize_with: deserialize_with.get(),
bound: bound.get(),
nested: nested.get(),
flatten: flatten.get(),
}
}

Expand All @@ -505,6 +513,10 @@ impl Field {
}
}

/// Returns `true` when the field was marked with `#[klickhouse(flatten)]`.
pub fn flatten(&self) -> bool {
self.flatten
}

/// Returns the value of the field's `nested` attribute flag.
pub fn nested(&self) -> bool {
self.nested
}
Expand Down
131 changes: 101 additions & 30 deletions klickhouse_derive/src/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,19 @@ pub fn expand_derive_serialize(
};
ctxt.check()?;

let flatten = cont.data.iter().any(|f| f.attrs.flatten());

let ident = &cont.ident;
let params = Parameters::new(&cont);
let (impl_generics, ty_generics, where_clause) = params.generics.split_for_impl();
let serialize_body = Stmts(serialize_body(&cont, &params));
let deserialize_body = Stmts(deserialize_body(&cont, &params));
let serialize_length_body = Stmts(serialize_length_body(&cont, &params));
let column_names_body = Stmts(column_names_body(&cont, &params));
let serialize_body = Stmts(serialize_body(&cont, &params));
let serialize_length_body = if flatten {
Stmts(Fragment::Block(quote! { None }))
} else {
Stmts(serialize_length_body(&cont, &params))
};
let const_column_count_fn = format_ident!("__{ident}_column_count_klickhouse");

let impl_block = quote! {
Expand All @@ -111,6 +117,7 @@ pub fn expand_derive_serialize(
#[automatically_derived]
#[allow(clippy)]
#[allow(non_snake_case)]
#[allow(clippy::absurd_extreme_comparisons)]
impl #impl_generics ::klickhouse::Row for #ident #ty_generics #where_clause {
const COLUMN_COUNT: ::std::option::Option<usize> = #const_column_count_fn();

Expand Down Expand Up @@ -199,9 +206,12 @@ fn column_names_body(cont: &Container, _params: &Parameters) -> Fragment {
let name_sources = cont.data.iter().filter(|&field| !field.attrs.skip_serializing())
.map(|field| {
let name = field.attrs.name().name();
let ty = field.ty;
if field.attrs.nested() {
let field_ty = unwrap_vec_type(field.ty).expect("invalid non-Vec nested type");
quote! { out.extend(<#field_ty as ::klickhouse::Row>::column_names()?.into_iter().map(|x| ::std::borrow::Cow::Owned(format!("{}.{}", #name, x)))); }
} else if field.attrs.flatten(){
quote! { out.extend(#ty::column_names()?); }
} else {
quote! { out.push(::std::borrow::Cow::Borrowed(#name)); }
}
Expand Down Expand Up @@ -299,7 +309,14 @@ fn serialize_struct_visitor(fields: &[Field], params: &Parameters) -> Vec<TokenS
}
}
}
} else {
} else if field.attrs.flatten() {
quote! {
let inner_length = #field_ty::column_names().expect("column_names required for flattened struct serialization").len();
out.extend(#field_expr.serialize_row(&type_hints[field_index..field_index + inner_length])?);
field_index += inner_length;
}
}
else {
quote! {
out.push((::std::borrow::Cow::Borrowed(#key_expr), <#field_ty as ::klickhouse::ToSql>::to_sql(#field_expr, type_hints.get(field_index).copied())?));
field_index += 1;
Expand Down Expand Up @@ -379,16 +396,19 @@ fn deserialize_map(
.map(|(i, field)| (field, field_i(i)))
.collect();

let skip = |field: &Field| field.attrs.skip_deserializing() || field.attrs.flatten();

// Declare each field that will be deserialized.
let let_values = fields_names
.iter()
.filter(|&&(field, _)| !field.attrs.skip_deserializing())
.map(|(field, name)| {
let field_ty = field.ty;
quote! {
let mut #name: ::std::option::Option<#field_ty> = ::std::option::Option::None;
}
});
let let_values =
fields_names
.iter()
.filter(|&&(field, _)| !skip(field))
.map(|(field, name)| {
let field_ty = field.ty;
quote! {
let mut #name: ::std::option::Option<#field_ty> = ::std::option::Option::None;
}
});

// Match arms to extract a value for a field.
let mut name_match_arms = Vec::with_capacity(fields_names.len());
Expand All @@ -400,7 +420,7 @@ fn deserialize_map(
let mut current_index = quote! { 0usize };
fields_names
.iter()
.filter(|&&(field, _)| !field.attrs.skip_deserializing())
.filter(|&&(field, _)| !skip(field))
.for_each(|(field, name)| {
let deser_name = field.attrs.name().name();
let local_index = current_index.clone();
Expand Down Expand Up @@ -510,13 +530,45 @@ fn deserialize_map(
}
};

let index_match_arm = quote! {
match _field_index {
#(#index_match_arms)*
#ignored_arm
let index_match_arm = if fields.iter().any(|f| f.attrs.flatten()) {
// Disable index-based matching with flattening
quote! { {} }
} else {
quote! {
match _field_index {
#(#index_match_arms)*
#ignored_arm
}
}
};

// Extract values for flattened fields, before we move `map`.
let mut pull_flatten: Vec<TokenStream> = vec![quote! {
let mut map = map;
let mut map_flattened_fields = std::collections::HashMap::<&str, (&::klickhouse::Type, ::klickhouse::Value)>::default();
}];
for (f, _) in fields_names.iter() {
if !f.attrs.flatten() {
continue;
}
let ty = f.ty;
let name = f.original.ident.as_ref().unwrap();
let missing_names_error =
format!("Flattened field {} should provide Row::column_names", name);
// TODO: To give the actual field, we would need to change the type of
// KlickhouseError::MissingField from &'static str to Cow.
let missing_col_error = format!("Flattened field {} has missing column", name);
pull_flatten.push(quote! {
for c in #ty::column_names()
.ok_or_else(|| ::klickhouse::KlickhouseError::DeserializeError(#missing_names_error.into()))? {
let idx = map.iter().enumerate().find(|(_, (c2,_,_))| c2 == &c)
.ok_or(::klickhouse::KlickhouseError::MissingField(#missing_col_error))?.0;
let (col, ty, val) = map.swap_remove(idx);
map_flattened_fields.insert(col, (ty, val));
}
});
}

let match_keys = quote! {
#[allow(unused_comparisons)]
for (_field_index, (_name, _type_, _value)) in map.into_iter().enumerate() {
Expand All @@ -527,25 +579,42 @@ fn deserialize_map(
}
};

let extract_values = fields_names
.iter()
.filter(|&&(field, _)| !field.attrs.skip_deserializing())
.map(|(field, name)| {
let missing_expr = Match(expr_is_missing(field, cattrs));

quote! {
let #name = match #name {
::std::option::Option::Some(#name) => #name,
::std::option::Option::None => #missing_expr
};
}
});
let extract_values =
fields_names
.iter()
.filter(|&&(field, _)| !skip(field))
.map(|(field, name)| {
let missing_expr = Match(expr_is_missing(field, cattrs));

quote! {
let #name = match #name {
::std::option::Option::Some(#name) => #name,
::std::option::Option::None => #missing_expr
};
}
});

let result = fields_names.iter().map(|(field, name)| {
let member = &field.member;
if field.attrs.skip_deserializing() {
let value = Expr(expr_is_missing(field, cattrs));
quote!(#member: #value)
} else if field.attrs.flatten() {
let ty = field.ty;
quote! {
#member: {
// Recreate the map based on the subfield column names and recurse to deserialize it.
// The unwraps would have produced an error earlier.
// The map is guaranteed to contain values for all fields.
let mut map2 = vec![];
for c in #ty::column_names().unwrap() {
use std::borrow::Borrow;
let c: &str = c.borrow();
let (c, (ty, val)) = map_flattened_fields.remove_entry(c).unwrap();
map2.push((c, ty, val));
}
klickhouse::Row::deserialize_row(map2)? }
}
} else {
quote!(#member: #name)
}
Expand All @@ -571,6 +640,8 @@ fn deserialize_map(
#(#let_values)*
#(#nested_temp_decls)*

#(#pull_flatten)*

#match_keys

#(#nested_rectify)*
Expand Down
1 change: 1 addition & 0 deletions klickhouse_derive/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub const BOUND: Symbol = Symbol("bound");
pub const DEFAULT: Symbol = Symbol("default");
pub const DENY_UNKNOWN_FIELDS: Symbol = Symbol("deny_unknown_fields");
pub const NESTED: Symbol = Symbol("nested");
pub const FLATTEN: Symbol = Symbol("flatten");
pub const DESERIALIZE_WITH: Symbol = Symbol("deserialize_with");
pub const FROM: Symbol = Symbol("from");
pub const INTO: Symbol = Symbol("into");
Expand Down

0 comments on commit d58c633

Please sign in to comment.