Skip to content

Commit

Permalink
feat(functions): support aggregate function retention
Browse files Browse the repository at this point in the history
  • Loading branch information
Kun FAN committed Apr 21, 2022
1 parent 9e20e2f commit b35f4b7
Show file tree
Hide file tree
Showing 6 changed files with 305 additions and 0 deletions.
224 changes: 224 additions & 0 deletions common/functions/src/aggregates/aggregate_retention.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
// Copyright 2022 Datafuse Labs.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::alloc::Layout;
use std::fmt;
use std::sync::Arc;

use bytes::BytesMut;
use common_datavalues::prelude::*;
use common_exception::ErrorCode;
use common_exception::Result;
use common_io::prelude::*;
use serde_json::json;
use serde_json::Value as JsonValue;

use super::aggregate_function::AggregateFunction;
use super::aggregate_function::AggregateFunctionRef;
use super::aggregate_function_factory::AggregateFunctionDescription;
use super::StateAddr;
use crate::aggregates::aggregator_common::assert_variadic_arguments;

struct AggregateRetentionState {
pub events: u32,
}

impl AggregateRetentionState {
#[inline(always)]
fn add(&mut self, event: u8) {
self.events |= 1 << event;
}

fn merge(&mut self, other: &Self) {
self.events |= other.events;
}

fn serialize(&self, writer: &mut BytesMut) -> Result<()> {
serialize_into_buf(writer, &self.events)
}

fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> {
self.events = deserialize_from_slice(reader)?;
Ok(())
}
}

#[derive(Clone)]
pub struct AggregateRetentionFunction {
display_name: String,
events_size: u8,
_arguments: Vec<DataField>,
}

impl AggregateFunction for AggregateRetentionFunction {
fn name(&self) -> &str {
"AggregateRetentionFunction"
}

fn return_type(&self) -> Result<DataTypePtr> {
Ok(JsonValue::to_data_type())
}

fn init_state(&self, place: StateAddr) {
place.write(|| AggregateRetentionState { events: 0 });
}

fn state_layout(&self) -> std::alloc::Layout {
Layout::new::<AggregateRetentionState>()
}

fn accumulate(
&self,
place: StateAddr,
columns: &[common_datavalues::ColumnRef],
_validity: Option<&common_arrow::arrow::bitmap::Bitmap>,
input_rows: usize,
) -> Result<()> {
let state = place.get::<AggregateRetentionState>();
let new_columns: Vec<&BooleanColumn> = columns
.iter()
.map(|column| Series::check_get(column).unwrap())
.collect();
for i in 0..input_rows {
for j in 0..self.events_size {
if new_columns[j as usize].get_data(i) {
state.add(j);
}
}
}
Ok(())
}

fn accumulate_row(
&self,
place: StateAddr,
columns: &[common_datavalues::ColumnRef],
row: usize,
) -> Result<()> {
let state = place.get::<AggregateRetentionState>();
let new_columns: Vec<&BooleanColumn> = columns
.iter()
.map(|column| Series::check_get(column).unwrap())
.collect();
for j in 0..self.events_size {
if new_columns[j as usize].get_data(row) {
state.add(j);
}
}
Ok(())
}

fn serialize(&self, place: StateAddr, writer: &mut BytesMut) -> Result<()> {
let state = place.get::<AggregateRetentionState>();
state.serialize(writer)
}

fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
let state = place.get::<AggregateRetentionState>();
state.deserialize(reader)
}

fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let rhs = rhs.get::<AggregateRetentionState>();
let state = place.get::<AggregateRetentionState>();
state.merge(rhs);
Ok(())
}

#[allow(unused_mut)]
fn merge_result(
&self,
place: StateAddr,
array: &mut dyn common_datavalues::MutableColumn,
) -> Result<()> {
let state = place.get::<AggregateRetentionState>();
let builder: &mut MutableObjectColumn<JsonValue> = Series::check_get_mutable_column(array)?;
let mut vec: Vec<u8> = vec![0; self.events_size as usize];
if state.events & 1 == 1 {
vec[0] = 1;
for i in 1..self.events_size {
if state.events & (1 << i) != 0 {
vec[i as usize] = 1;
}
}
}
builder.append_value(json!(vec));
Ok(())
}

fn accumulate_keys(
&self,
places: &[StateAddr],
offset: usize,
columns: &[common_datavalues::ColumnRef],
_input_rows: usize,
) -> Result<()> {
let new_columns: Vec<&BooleanColumn> = columns
.iter()
.map(|column| Series::check_get(column).unwrap())
.collect();
for (row, place) in places.iter().enumerate() {
let place = place.next(offset);
let state = place.get::<AggregateRetentionState>();
for j in 0..self.events_size {
if new_columns[j as usize].get_data(row) {
state.add(j);
}
}
}
Ok(())
}
}

impl fmt::Display for AggregateRetentionFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.display_name)
}
}

impl AggregateRetentionFunction {
pub fn try_create(
display_name: &str,
arguments: Vec<DataField>,
) -> Result<AggregateFunctionRef> {
Ok(Arc::new(Self {
display_name: display_name.to_owned(),
events_size: arguments.len() as u8,
_arguments: arguments,
}))
}
}

pub fn try_create_aggregate_retention_function(
display_name: &str,
_params: Vec<DataValue>,
arguments: Vec<DataField>,
) -> Result<AggregateFunctionRef> {
assert_variadic_arguments(display_name, arguments.len(), (1, 32))?;

for argument in arguments.iter() {
let data_type = argument.data_type();
if data_type.data_type_id() != TypeID::Boolean {
return Err(ErrorCode::BadArguments(
"The arguments of AggregateRetention should be an expression which returns a Boolean result"
));
}
}

AggregateRetentionFunction::try_create(display_name, arguments)
}

pub fn aggregate_retention_function_desc() -> AggregateFunctionDescription {
AggregateFunctionDescription::creator(Box::new(try_create_aggregate_retention_function))
}
3 changes: 3 additions & 0 deletions common/functions/src/aggregates/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use super::aggregate_window_funnel::aggregate_window_funnel_function_desc;
use super::AggregateCountFunction;
use super::AggregateFunctionFactory;
use super::AggregateIfCombinator;
use crate::aggregates::aggregate_retention::aggregate_retention_function_desc;
use crate::aggregates::aggregate_sum::aggregate_sum_function_desc;

pub struct Aggregators;
Expand All @@ -50,6 +51,8 @@ impl Aggregators {

factory.register("window_funnel", aggregate_window_funnel_function_desc());
factory.register("uniq", AggregateDistinctCombinator::uniq_desc());

factory.register("retention", aggregate_retention_function_desc());
}

pub fn register_combinator(factory: &mut AggregateFunctionFactory) {
Expand Down
2 changes: 2 additions & 0 deletions common/functions/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ mod aggregate_combinator_if;
mod aggregate_covariance;
mod aggregate_min_max;
mod aggregate_null_result;
mod aggregate_retention;
mod aggregate_scalar_state;
mod aggregate_stddev_pop;
mod aggregate_window_funnel;
Expand All @@ -53,6 +54,7 @@ pub use aggregate_function_state::StateAddr;
pub use aggregate_function_state::StateAddrs;
pub use aggregate_min_max::AggregateMinMaxFunction;
pub use aggregate_null_result::AggregateNullResultFunction;
pub use aggregate_retention::AggregateRetentionFunction;
pub use aggregate_stddev_pop::AggregateStddevPopFunction;
pub use aggregate_sum::AggregateSumFunction;
pub use aggregate_window_funnel::AggregateWindowFunnelFunction;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
---
title: RETENTION
---

Aggregate function

The RETENTION() function takes as arguments a set of conditions from 1 to 32 arguments of type UInt8 that indicate whether a certain condition was met for the event.

Any condition can be specified as an argument (as in WHERE).

The conditions, except the first, apply in pairs: the result of the second will be true if the first and second are true, of the third if the first and third are true, etc.

## Syntax

```
RETENTION(cond1, cond2, ..., cond32);
```

## Arguments

| Arguments | Description |
| ----------- | ----------- |
| cond | An expression that returns a Boolean result |

## Return Type

The array of 1 or 0.

## Examples

```
CREATE TABLE retention_test(date DATE, uid INT) ENGINE = Memory;
INSERT INTO retention_test SELECT '2018-08-06', number FROM numbers(80);
INSERT INTO retention_test SELECT '2018-08-07', number FROM numbers(50);
INSERT INTO retention_test SELECT '2018-08-08', number FROM numbers(60);
```

```
SELECT sum(get(r, 0)::TINYINT) as r1, sum(get(r, 1)::TINYINT) as r2 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-07') AS r FROM retention_test WHERE date = '2018-08-06' or date = '2018-08-07' GROUP BY uid);
+------+------+
| r1 | r2 |
+------+------+
| 80 | 50 |
+------+------+
SELECT sum(get(r, 0)::TINYINT) as r1, sum(get(r, 1)::TINYINT) as r2 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-08') AS r FROM retention_test WHERE date = '2018-08-06' or date = '2018-08-08' GROUP BY uid);
+------+------+
| r1 | r2 |
+------+------+
| 80 | 60 |
+------+------+
SELECT sum(get(r, 0)::TINYINT) as r1, sum(get(r, 1)::TINYINT) as r2, sum(get(r, 2)::TINYINT) as r3 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-07', date = '2018-08-08') AS r FROM retention_test GROUP BY uid);
+------+------+------+
| r1 | r2 | r3 |
+------+------+------+
| 80 | 50 | 60 |
+------+------+------+
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
80 50
80 60
80 50 60
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
DROP TABLE IF EXISTS retention_test;

CREATE TABLE retention_test(date DATE, uid INT);

INSERT INTO retention_test SELECT '2018-08-06', number FROM numbers(80);
INSERT INTO retention_test SELECT '2018-08-07', number FROM numbers(50);
INSERT INTO retention_test SELECT '2018-08-08', number FROM numbers(60);

SELECT sum(get(r, 0)::TINYINT) as r1, sum(get(r, 1)::TINYINT) as r2 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-07') AS r FROM retention_test WHERE date = '2018-08-06' or date = '2018-08-07' GROUP BY uid);
SELECT sum(get(r, 0)::TINYINT) as r1, sum(get(r, 1)::TINYINT) as r2 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-08') AS r FROM retention_test WHERE date = '2018-08-06' or date = '2018-08-08' GROUP BY uid);
SELECT sum(get(r, 0)::TINYINT) as r1, sum(get(r, 1)::TINYINT) as r2, sum(get(r, 2)::TINYINT) as r3 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-07', date = '2018-08-08') AS r FROM retention_test GROUP BY uid);

DROP TABLE retention_test;

0 comments on commit b35f4b7

Please sign in to comment.