diff --git a/extension/src/aggregate_utils.rs b/extension/src/aggregate_utils.rs index 2eb51b24a..602cc77d8 100644 --- a/extension/src/aggregate_utils.rs +++ b/extension/src/aggregate_utils.rs @@ -11,6 +11,14 @@ pub unsafe fn get_collation(fcinfo: pg_sys::FunctionCallInfo) -> Option Option { + if fcinfo.is_null() { + Some(100) // TODO: default OID, there should be a constant for this + } else { + unsafe { get_collation(fcinfo) } + } +} + pub unsafe fn in_aggregate_context T>( fcinfo: pg_sys::FunctionCallInfo, f: F, diff --git a/extension/src/datum_utils.rs b/extension/src/datum_utils.rs index 4eba663dd..f445b4c9f 100644 --- a/extension/src/datum_utils.rs +++ b/extension/src/datum_utils.rs @@ -170,11 +170,7 @@ impl DatumHashBuilder { impl Clone for DatumHashBuilder { fn clone(&self) -> Self { - Self { - info: self.info, - type_id: self.type_id, - collation: self.collation, - } + unsafe { DatumHashBuilder::from_type_id(self.type_id, Some(self.collation)) } } } diff --git a/extension/src/frequency.rs b/extension/src/frequency.rs index 4e49669e5..85b70f67a 100644 --- a/extension/src/frequency.rs +++ b/extension/src/frequency.rs @@ -16,7 +16,7 @@ use serde::{ }; use crate::{ - aggregate_utils::{get_collation, in_aggregate_context}, + aggregate_utils::{get_collation_or_default, in_aggregate_context}, build, datum_utils::{ deep_copy_datum, DatumFromSerializedTextReader, DatumHashBuilder, DatumStore, @@ -233,6 +233,49 @@ impl SpaceSavingTransState { } } + fn ingest_aggregate_data( + &mut self, + val_count: u64, + values: &DatumStore, + counts: &[u64], + overcounts: &[u64], + ) { + assert_eq!(self.total_vals, 0); // This should only be called on an empty aggregate + self.total_vals = val_count; + + for (idx, datum) in values.iter().enumerate() { + self.entries.push(SpaceSavingEntry { + value: unsafe { deep_copy_datum(datum, self.indices.typoid()) }, + count: counts[idx], + overcount: overcounts[idx], + }); + self.indices + .insert((self.entries[idx].value, self.type_oid()).into(), idx); + } + } + + fn ingest_aggregate_ints( + &mut self, + val_count: u64, + values: &[i64], + counts: &[u64], + overcounts: &[u64], + ) { + assert_eq!(self.total_vals, 0); // This should only be called on an empty aggregate + assert_eq!(self.type_oid(), pg_sys::INT8OID); + self.total_vals = val_count; + + for (idx, val) in values.iter().enumerate() { + self.entries.push(SpaceSavingEntry { + value: Datum::from(*val), + count: counts[idx], + overcount: overcounts[idx], + }); + self.indices + .insert((self.entries[idx].value, self.type_oid()).into(), idx); + } + } + fn type_oid(&self) -> Oid { self.indices.typoid() } @@ -414,6 +457,36 @@ pub mod toolkit_experimental { } } + impl<'input> From<(&SpaceSavingAggregate<'input>, &pg_sys::FunctionCallInfo)> + for SpaceSavingTransState + { + fn from(data_in: (&SpaceSavingAggregate<'input>, &pg_sys::FunctionCallInfo)) -> Self { + let (agg, fcinfo) = data_in; + let collation = get_collation_or_default(*fcinfo); + let mut trans = if agg.topn == 0 { + SpaceSavingTransState::freq_agg_from_type_id( + agg.freq_param, + agg.type_oid, + collation, + ) + } else { + SpaceSavingTransState::topn_agg_from_type_id( + agg.freq_param, + agg.topn as u32, + agg.type_oid, + collation, + ) + }; + trans.ingest_aggregate_data( + agg.values_seen, + &agg.datums, + agg.counts.as_slice(), + agg.overcounts.as_slice(), + ); + trans + } + } + ron_inout_funcs!(SpaceSavingAggregate); pg_type! { @@ -457,6 +530,44 @@ pub mod toolkit_experimental { } } + impl<'input> + From<( + &SpaceSavingBigIntAggregate<'input>, + &pg_sys::FunctionCallInfo, + )> for SpaceSavingTransState + { + fn from( + data_in: ( + &SpaceSavingBigIntAggregate<'input>, + &pg_sys::FunctionCallInfo, + ), + ) -> Self { + let (agg, fcinfo) = data_in; + let collation = get_collation_or_default(*fcinfo); + let mut trans = if agg.topn == 0 { + SpaceSavingTransState::freq_agg_from_type_id( + agg.freq_param, + pg_sys::INT8OID, + collation, + ) + } else { + SpaceSavingTransState::topn_agg_from_type_id( + agg.freq_param, + agg.topn as u32, + pg_sys::INT8OID, + collation, + ) + }; + trans.ingest_aggregate_ints( + agg.values_seen, + agg.datums.as_slice(), + agg.counts.as_slice(), + agg.overcounts.as_slice(), + ); + trans + } + } + ron_inout_funcs!(SpaceSavingBigIntAggregate); pg_type! { @@ -500,6 +611,36 @@ pub mod toolkit_experimental { } } + impl<'input> From<(&SpaceSavingTextAggregate<'input>, &pg_sys::FunctionCallInfo)> + for SpaceSavingTransState + { + fn from(data_in: (&SpaceSavingTextAggregate<'input>, &pg_sys::FunctionCallInfo)) -> Self { + let (agg, fcinfo) = data_in; + let collation = get_collation_or_default(*fcinfo); + let mut trans = if agg.topn == 0 { + SpaceSavingTransState::freq_agg_from_type_id( + agg.freq_param, + pg_sys::TEXTOID, + collation, + ) + } else { + SpaceSavingTransState::topn_agg_from_type_id( + agg.freq_param, + agg.topn, + pg_sys::TEXTOID, + collation, + ) + }; + trans.ingest_aggregate_data( + agg.values_seen, + &agg.datums, + agg.counts.as_slice(), + agg.overcounts.as_slice(), + ); + trans + } + } + ron_inout_funcs!(SpaceSavingTextAggregate); } @@ -676,11 +817,7 @@ where let mut state = match state { None => { let typ = value.oid(); - let collation = if fcinfo.is_null() { - Some(100) // TODO: default OID, there should be a constant for this - } else { - get_collation(fcinfo) - }; + let collation = get_collation_or_default(fcinfo); make_trans_state(typ, collation).into() } Some(state) => state, @@ -692,6 +829,96 @@ where } } +#[pg_extern(schema = "toolkit_experimental", immutable, parallel_safe)] +pub fn rollup_agg_trans<'input>( + state: Internal, + value: Option>, + fcinfo: pg_sys::FunctionCallInfo, +) -> Option { + let value = match value { + None => return Some(state), + Some(v) => v, + }; + rollup_agg_trans_inner(unsafe { state.to_inner() }, value, fcinfo).internal() +} + +pub fn rollup_agg_trans_inner( + state: Option>, + value: SpaceSavingAggregate, + fcinfo: pg_sys::FunctionCallInfo, +) -> Option> { + unsafe { + in_aggregate_context(fcinfo, || { + let trans = (&value, &fcinfo).into(); + if let Some(state) = state { + Some(SpaceSavingTransState::combine(&*state, &trans).into()) + } else { + Some(trans.into()) + } + }) + } +} + +#[pg_extern(schema = "toolkit_experimental", immutable, parallel_safe)] +pub fn rollup_agg_bigint_trans<'input>( + state: Internal, + value: Option>, + fcinfo: pg_sys::FunctionCallInfo, +) -> Option { + let value = match value { + None => return Some(state), + Some(v) => v, + }; + rollup_agg_bigint_trans_inner(unsafe { state.to_inner() }, value, fcinfo).internal() +} + +pub fn rollup_agg_bigint_trans_inner( + state: Option>, + value: SpaceSavingBigIntAggregate, + fcinfo: pg_sys::FunctionCallInfo, +) -> Option> { + unsafe { + in_aggregate_context(fcinfo, || { + let trans = (&value, &fcinfo).into(); + if let Some(state) = state { + Some(SpaceSavingTransState::combine(&*state, &trans).into()) + } else { + Some(trans.into()) + } + }) + } +} + +#[pg_extern(schema = "toolkit_experimental", immutable, parallel_safe)] +pub fn rollup_agg_text_trans<'input>( + state: Internal, + value: Option>, + fcinfo: pg_sys::FunctionCallInfo, +) -> Option { + let value = match value { + None => return Some(state), + Some(v) => v, + }; + rollup_agg_text_trans_inner(unsafe { state.to_inner() }, value, fcinfo).internal() +} + +pub fn rollup_agg_text_trans_inner( + state: Option>, + value: SpaceSavingTextAggregate, + fcinfo: pg_sys::FunctionCallInfo, +) -> Option> { + unsafe { + in_aggregate_context(fcinfo, || { + let trans = (&value, &fcinfo).into(); + if let Some(state) = state { + Some(SpaceSavingTransState::combine(&*state, &trans).into()) + } else { + Some(trans.into()) + } + }) + } +} + #[pg_extern(schema = "toolkit_experimental", immutable, parallel_safe)] pub fn space_saving_combine( state1: Internal, @@ -970,6 +1197,78 @@ extension_sql!( ], ); +extension_sql!( + "\n\ + CREATE AGGREGATE toolkit_experimental.rollup(\n\ + agg toolkit_experimental.SpaceSavingAggregate\n\ + ) (\n\ + sfunc = toolkit_experimental.rollup_agg_trans,\n\ + stype = internal,\n\ + finalfunc = toolkit_experimental.space_saving_final,\n\ + combinefunc = toolkit_experimental.space_saving_combine,\n\ + serialfunc = toolkit_experimental.space_saving_serialize,\n\ + deserialfunc = toolkit_experimental.space_saving_deserialize,\n\ + parallel = safe\n\ + );\n\ +", + name = "freq_agg_rollup", + requires = [ + rollup_agg_trans, + space_saving_final, + space_saving_combine, + space_saving_serialize, + space_saving_deserialize + ], +); + +extension_sql!( + "\n\ + CREATE AGGREGATE toolkit_experimental.rollup(\n\ + agg toolkit_experimental.SpaceSavingBigIntAggregate\n\ + ) (\n\ + sfunc = toolkit_experimental.rollup_agg_bigint_trans,\n\ + stype = internal,\n\ + finalfunc = toolkit_experimental.space_saving_bigint_final,\n\ + combinefunc = toolkit_experimental.space_saving_combine,\n\ + serialfunc = toolkit_experimental.space_saving_serialize,\n\ + deserialfunc = toolkit_experimental.space_saving_deserialize,\n\ + parallel = safe\n\ + );\n\ +", + name = "freq_agg_bigint_rollup", + requires = [ + rollup_agg_bigint_trans, + space_saving_bigint_final, + space_saving_combine, + space_saving_serialize, + space_saving_deserialize + ], +); + +extension_sql!( + "\n\ + CREATE AGGREGATE toolkit_experimental.rollup(\n\ + agg toolkit_experimental.SpaceSavingTextAggregate\n\ + ) (\n\ + sfunc = toolkit_experimental.rollup_agg_text_trans,\n\ + stype = internal,\n\ + finalfunc = toolkit_experimental.space_saving_text_final,\n\ + combinefunc = toolkit_experimental.space_saving_combine,\n\ + serialfunc = toolkit_experimental.space_saving_serialize,\n\ + deserialfunc = toolkit_experimental.space_saving_deserialize,\n\ + parallel = safe\n\ + );\n\ +", + name = "freq_agg_text_rollup", + requires = [ + rollup_agg_text_trans, + space_saving_text_final, + space_saving_combine, + space_saving_serialize, + space_saving_deserialize + ], +); + #[pg_extern( immutable, parallel_safe, @@ -1089,8 +1388,13 @@ fn validate_topn_for_topn_agg( } #[pg_extern(immutable, parallel_safe, schema = "toolkit_experimental")] -pub fn topn(agg: SpaceSavingAggregate<'_>, n: i32, ty: AnyElement) -> SetOfIterator { - if ty.oid() != agg.type_oid { +pub fn topn( + agg: SpaceSavingAggregate<'_>, + n: i32, + ty: Option, +) -> SetOfIterator { + // If called with a NULL, assume type matches + if ty.is_some() && ty.unwrap().oid() != agg.type_oid { pgx::error!("mischatched types") } @@ -1125,7 +1429,10 @@ pub fn topn(agg: SpaceSavingAggregate<'_>, n: i32, ty: AnyElement) -> SetOfItera name = "topn", schema = "toolkit_experimental" )] -pub fn default_topn(agg: SpaceSavingAggregate<'_>, ty: AnyElement) -> SetOfIterator { +pub fn default_topn( + agg: SpaceSavingAggregate<'_>, + ty: Option, +) -> SetOfIterator { if agg.topn == 0 { pgx::error!("frequency aggregates require a N parameter to topn") } @@ -1367,6 +1674,8 @@ mod tests { use super::*; use pgx_macros::pg_test; use rand::distributions::{Distribution, Uniform}; + use rand::prelude::SliceRandom; + use rand::thread_rng; use rand::RngCore; use rand_distr::Zeta; @@ -1860,6 +2169,80 @@ mod tests { }); } + #[pg_test] + fn test_rollups() { + Spi::execute(|client| { + client.select( + "CREATE TABLE test (raw_data DOUBLE PRECISION, int_data INTEGER, text_data TEXT, bucket INTEGER)", + None, + None, + ); + + // Generate an array of 1000 values by taking the probability curve for a + // zeta curve with an s of 1.1 for the top 5 values, then adding smaller + // amounts of the next 5 most common values, and finally filling with unique values. + let mut vals = vec![1; 95]; + vals.append(&mut vec![2; 45]); + vals.append(&mut vec![3; 39]); + vals.append(&mut vec![4; 21]); + vals.append(&mut vec![5; 17]); + for v in 6..=10 { + vals.append(&mut vec![v, 10]); + } + for v in 0..(1000 - 95 - 45 - 39 - 21 - 17 - (5 * 10)) { + vals.push(11 + v); + } + vals.shuffle(&mut thread_rng()); + + // Probably not the most efficient way of populating this table... + for v in vals { + let cmd = format!( + "INSERT INTO test SELECT {}, {}::INT, {}::TEXT, FLOOR(RANDOM() * 10)", + v, v, v + ); + client.select(&cmd, None, None); + } + + // No matter how the values are batched into subaggregates, we should always + // see the same top 5 values + let mut result = client.select( + "WITH aggs AS (SELECT bucket, toolkit_experimental.raw_topn_agg(5, raw_data) as raw_agg FROM test GROUP BY bucket) + SELECT toolkit_experimental.topn(toolkit_experimental.rollup(raw_agg), NULL::DOUBLE PRECISION)::TEXT from aggs", + None, None + ); + assert_eq!(result.next().unwrap()[1].value(), Some("1")); + assert_eq!(result.next().unwrap()[1].value(), Some("2")); + assert_eq!(result.next().unwrap()[1].value(), Some("3")); + assert_eq!(result.next().unwrap()[1].value(), Some("4")); + assert_eq!(result.next().unwrap()[1].value(), Some("5")); + assert!(result.next().is_none()); + + let mut result = client.select( + "WITH aggs AS (SELECT bucket, toolkit_experimental.topn_agg(5, int_data) as int_agg FROM test GROUP BY bucket) + SELECT toolkit_experimental.topn(toolkit_experimental.rollup(int_agg))::TEXT from aggs", + None, None + ); + assert_eq!(result.next().unwrap()[1].value(), Some("1")); + assert_eq!(result.next().unwrap()[1].value(), Some("2")); + assert_eq!(result.next().unwrap()[1].value(), Some("3")); + assert_eq!(result.next().unwrap()[1].value(), Some("4")); + assert_eq!(result.next().unwrap()[1].value(), Some("5")); + assert!(result.next().is_none()); + + let mut result = client.select( + "WITH aggs AS (SELECT bucket, toolkit_experimental.topn_agg(5, text_data) as text_agg FROM test GROUP BY bucket) + SELECT toolkit_experimental.topn(toolkit_experimental.rollup(text_agg))::TEXT from aggs", + None, None + ); + assert_eq!(result.next().unwrap()[1].value(), Some("1")); + assert_eq!(result.next().unwrap()[1].value(), Some("2")); + assert_eq!(result.next().unwrap()[1].value(), Some("3")); + assert_eq!(result.next().unwrap()[1].value(), Some("4")); + assert_eq!(result.next().unwrap()[1].value(), Some("5")); + assert!(result.next().is_none()); + }); + } + #[pg_test] fn test_freq_agg_invariant() { // The frequency agg invariant is that any element with frequency >= f will appear in the freq_agg(f) @@ -1895,6 +2278,55 @@ mod tests { } } + #[pg_test] + fn test_freq_agg_rollup_maintains_invariant() { + // The frequency agg invariant is that any element with frequency >= f will appear in the freq_agg(f) + + // This test will randomly generate 200 values in the uniform range [0, 99] and check to see any value + // that shows up at least 3 times appears in a frequency aggregate created with freq = 0.015 + let rand100 = Uniform::new_inclusive(0, 99); + let mut rng = rand::thread_rng(); + + let mut counts = [0; 100]; + + let freq = 0.015; + let fcinfo = std::ptr::null_mut(); // dummy value, will use default collation + + let mut aggs = vec![]; + for _ in 0..4 { + let mut state = None.into(); + for _ in 0..50 { + let v = rand100.sample(&mut rng); + let value = unsafe { + AnyElement::from_polymorphic_datum( + pg_sys::Datum::from(v), + false, + pg_sys::INT4OID, + ) + }; + state = super::freq_agg_trans(state, freq, value, fcinfo).unwrap(); + counts[v] += 1; + } + aggs.push(space_saving_final(state, fcinfo).unwrap()); + } + + let state = { + let mut state = None.into(); + for agg in aggs { + state = super::rollup_agg_trans(state, Some(agg), fcinfo).unwrap(); + } + space_saving_final(state, fcinfo).unwrap() + }; + let vals: std::collections::HashSet = + state.datums.iter().map(|datum| datum.value()).collect(); + + for (val, &count) in counts.iter().enumerate() { + if count >= 3 { + assert!(vals.contains(&val)); + } + } + } + #[pg_test] fn test_topn_agg_invariant() { // The ton agg invariant is that we'll be able to track the top n values for any data @@ -1934,7 +2366,7 @@ mod tests { let state = space_saving_final(state, fcinfo).unwrap(); let value = unsafe { AnyElement::from_polymorphic_datum(Datum::from(0), false, pg_sys::INT4OID) }; - let t: Vec = default_topn(state, value.unwrap()).collect(); + let t: Vec = default_topn(state, Some(value.unwrap())).collect(); let agg_topn: Vec = t.iter().map(|x| x.datum().value()).collect(); let mut temp: Vec<(usize, &usize)> = counts.iter().enumerate().collect();