Skip to content

Commit

Permalink
interim commit: adding tests but failing
Browse files Browse the repository at this point in the history
  • Loading branch information
kwannoel committed Aug 1, 2024
1 parent 1d3f41e commit c5a8387
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 15 deletions.
22 changes: 22 additions & 0 deletions e2e_test/streaming/aggregate/two_phase_approx_percentile.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Single phase approx percentile
statement ok
create table t(p_col double, grp_col int);

# statement ok
# insert into t select a, 1 from generate_series(0, 10) t(a);

# statement ok
# insert into t values(0, 1);

statement ok
flush;

statement ok
create materialized view m1 as select
approx_percentile(0.01, 0.01) within group (order by p_col) as p01
from t;

query I
select * from m1;
----
10
5 changes: 5 additions & 0 deletions src/common/src/util/stream_graph_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,11 @@ pub fn visit_stream_node_tables_inner<F>(
NodeBody::Materialize(node) if !internal_tables_only => {
always!(node.table, "Materialize")
}

NodeBody::GlobalApproxPercentile(node) => {
always!(node.bucket_state_table, "GlobalApproxPercentileBucketState");
always!(node.count_state_table, "GlobalApproxPercentileCountState");
}
_ => {}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use pretty_xmlish::{Pretty, XmlNode};
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::types::DataType;
use risingwave_common::util::sort_util::OrderType;
use risingwave_pb::catalog::Table;
use risingwave_pb::stream_plan::stream_node::PbNodeBody;
use risingwave_pb::stream_plan::GlobalApproxPercentileNode;

Expand Down Expand Up @@ -98,7 +97,7 @@ impl PlanTreeNodeUnary for StreamGlobalApproxPercentile {
impl_plan_tree_node_for_unary! {StreamGlobalApproxPercentile}

impl StreamNode for StreamGlobalApproxPercentile {
fn to_stream_prost_body(&self, _state: &mut BuildFragmentGraphState) -> PbNodeBody {
fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> PbNodeBody {
let relative_error = self.relative_error.get_data().as_ref().unwrap();
let relative_error = relative_error.as_float64().into_inner();
let base = (1.0 + relative_error) / (1.0 - relative_error);
Expand All @@ -116,11 +115,13 @@ impl StreamNode for StreamGlobalApproxPercentile {
bucket_state_table: Some(
bucket_table_builder
.build(vec![], 0)
.with_id(state.gen_table_id_wrapped())
.to_internal_table_prost(),
),
count_state_table: Some(
count_table_builder
.build(vec![], 0)
.with_id(state.gen_table_id_wrapped())
.to_internal_table_prost(),
),
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ pub struct StreamLocalApproxPercentile {
impl StreamLocalApproxPercentile {
pub fn new(input: PlanRef, approx_percentile_agg_call: &PlanAggCall) -> Self {
let schema = Schema::new(vec![
Field::with_name(DataType::Int64, "bucket_id"),
Field::with_name(DataType::Int64, "count"),
Field::with_name(DataType::Int32, "bucket_id"),
Field::with_name(DataType::Int32, "count"),
]);
// FIXME(kwannoel): How does watermark work with FixedBitSet
let watermark_columns = FixedBitSet::with_capacity(2);
Expand Down
3 changes: 3 additions & 0 deletions src/stream/src/common/table/state_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,9 @@ where

/// Update a row. The old and new value should have the same pk.
pub fn update(&mut self, old_value: impl Row, new_value: impl Row) {
println!("pk: {:?}", self.pk_indices());
println!("old_row: {:?}", old_value);
println!("new_row: {:?}", new_value);
let old_pk = (&old_value).project(self.pk_indices());
let new_pk = (&new_value).project(self.pk_indices());
debug_assert!(
Expand Down
35 changes: 24 additions & 11 deletions src/stream/src/executor/approx_percentile/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use core::ops::Bound;

use risingwave_common::bail;
use risingwave_common::array::Op;
use risingwave_common::row::RowExt;
use risingwave_storage::store::PrefetchOptions;
Expand Down Expand Up @@ -58,34 +59,45 @@ impl<S: StateStore> GlobalApproxPercentileExecutor<S> {
async fn execute_inner(self) {
let mut bucket_state_table = self.bucket_state_table;
let mut count_state_table = self.count_state_table;
let mut input_stream = self.input.execute();
let first_barrier = expect_first_barrier(&mut input_stream).await?;
bucket_state_table.init_epoch(first_barrier.epoch);
count_state_table.init_epoch(first_barrier.epoch);
yield Message::Barrier(first_barrier);
let mut old_row_count = count_state_table.get_row(&[Datum::None; 0]).await?;
let mut row_count = if let Some(row) = old_row_count.as_ref() {
row.datum_at(0).unwrap().into_int64()
} else {
0
};
#[for_await]
for message in self.input.execute() {
for message in input_stream {
match message? {
Message::Chunk(chunk) => {
for (_, row) in chunk.rows() {
let pk_datum = row.datum_at(0);
let pk = row.project(&[0]);
println!("row: {:?}", row);
let delta_datum = row.datum_at(1);
let delta: i32 = delta_datum.unwrap().into_int32();
row_count = row_count.checked_add(delta as i64).unwrap();

let pk_datum = row.datum_at(0);
let pk = row.project(&[0]);

let old_row = bucket_state_table.get_row(pk).await?;
let old_value: i32 = if let Some(row) = old_row.as_ref() {
row.datum_at(0).unwrap().into_int32()
let old_bucket_row_count: i64 = if let Some(row) = old_row.as_ref() {
row.datum_at(1).unwrap().into_int64()
} else {
0
};

let new_value = old_value + delta;
let new_value_datum = Datum::from(ScalarImpl::Int32(new_value));
let new_value = old_bucket_row_count.checked_add(delta as i64).unwrap();
let new_value_datum = Datum::from(ScalarImpl::Int64(new_value));
let new_row = &[pk_datum.map(|d| d.into()), new_value_datum];
bucket_state_table.update(old_row, new_row);
if old_row.is_none() {
bucket_state_table.insert(new_row);
} else {
bucket_state_table.update(old_row, new_row);
}
}
}
Message::Barrier(barrier) => {
Expand All @@ -100,7 +112,7 @@ impl<S: StateStore> GlobalApproxPercentileExecutor<S> {
.await?
{
let row = keyed_row?.into_owned_row();
let count = row.datum_at(1).unwrap().into_int32();
let count = row.datum_at(1).unwrap().into_int64();
acc_count += count as u64;
if acc_count >= quantile_count {
let bucket_id = row.datum_at(0).unwrap().into_int32();
Expand All @@ -118,14 +130,15 @@ impl<S: StateStore> GlobalApproxPercentileExecutor<S> {
}
}
let row_count_to_persist = &[Datum::from(ScalarImpl::Int64(row_count))];
if let Some(old_row_count) = old_row_count {
if let Some(old_row_count) = old_row_count && row_count_to_persist.into_owned_row() != old_row_count {
count_state_table.update(old_row_count, row_count_to_persist);
} else {
count_state_table.insert(row_count_to_persist);
}
old_row_count = Some(row_count_to_persist.into_owned_row());
count_state_table.commit(barrier.epoch).await?;
bucket_state_table.commit(barrier.epoch).await?;

old_row_count = Some(row_count_to_persist.into_owned_row());
yield Message::Barrier(barrier);
}
m => yield m,
Expand Down

0 comments on commit c5a8387

Please sign in to comment.