diff --git a/e2e_test/streaming/asof_join.slt b/e2e_test/streaming/asof_join.slt new file mode 100644 index 0000000000000..6e35d5aa7d40b --- /dev/null +++ b/e2e_test/streaming/asof_join.slt @@ -0,0 +1,143 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +# asof inner join + +statement ok +create table t1 (v1 int, v2 int, v3 int primary key); + +statement ok +create table t2 (v1 int, v2 int, v3 int primary key); + +statement ok +create materialized view mv1 as SELECT t1.v1 t1_v1, t1.v2 t1_v2, t1.v3 t1_v3, t2.v1 t2_v1, t2.v2 t2_v2, t2.v3 t2_v3 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 and t1.v2 <= t2.v2; + +statement ok +insert into t1 values (1, 2, 3); + +statement ok +insert into t2 values (1, 3, 4); + +query III +select * from mv1; +---- +1 2 3 1 3 4 + +statement ok +insert into t2 values (1, 2, 3); + +query III +select * from mv1; +---- +1 2 3 1 2 3 + +statement ok +delete from t1 where v3 = 3; + +query III +select * from mv1; +---- + + +statement ok +insert into t1 values (2, 3, 4); + +statement ok +insert into t2 values (2, 3, 6), (2, 3, 7), (2, 3, 5); + +query III +select * from mv1; +---- +2 3 4 2 3 5 + +statement ok +insert into t2 values (2, 3, 1), (2, 3, 2); + +query III +select * from mv1; +---- +2 3 4 2 3 1 + +statement ok +drop materialized view mv1; + +statement ok +drop table t1; + +statement ok +drop table t2; + + +# asof left join + +statement ok +create table t1 (v1 int, v2 int, v3 int primary key); + +statement ok +create table t2 (v1 int, v2 int, v3 int primary key); + +statement ok +create materialized view mv1 as SELECT t1.v1 t1_v1, t1.v2 t1_v2, t1.v3 t1_v3, t2.v1 t2_v1, t2.v2 t2_v2, t2.v3 t2_v3 FROM t1 ASOF LEFT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 > t2.v2; + +statement ok +insert into t1 values (1, 2, 3); + +statement ok +insert into t2 values (1, 2, 4); + +query III +select * from mv1; +---- +1 2 3 NULL NULL NULL + +statement ok +insert into t2 values (1, 1, 3); + +query III +select * from mv1; +---- +1 2 3 1 1 3 + +statement ok +delete from t1 where v3 = 3; + +query III +select * from mv1; +---- + + +statement ok +insert into t1 values (2, 3, 4); + +statement ok +insert into t2 values (2, 2, 6), (2, 2, 7), (2, 2, 5); + +query III +select * from mv1; +---- +2 3 4 2 2 5 + +statement ok +insert into t2 values (2, 2, 1), (2, 2, 2); + +query III +select * from mv1; +---- +2 3 4 2 2 1 + +statement ok +delete from t2 where v1 = 2; + +query III +select * from mv1; +---- +2 3 4 NULL NULL NULL + +statement ok +drop materialized view mv1; + +statement ok +drop table t1; + +statement ok +drop table t2; diff --git a/proto/plan_common.proto b/proto/plan_common.proto index a552c9f0a5fae..f561ee427ea46 100644 --- a/proto/plan_common.proto +++ b/proto/plan_common.proto @@ -146,6 +146,8 @@ enum JoinType { JOIN_TYPE_LEFT_ANTI = 6; JOIN_TYPE_RIGHT_SEMI = 7; JOIN_TYPE_RIGHT_ANTI = 8; + JOIN_TYPE_ASOF_INNER = 9; + JOIN_TYPE_ASOF_LEFT_OUTER = 10; } enum AsOfJoinType { diff --git a/proto/stream_plan.proto b/proto/stream_plan.proto index 7fe63054c565c..1703069c4047e 100644 --- a/proto/stream_plan.proto +++ b/proto/stream_plan.proto @@ -885,6 +885,7 @@ message StreamNode { LocalApproxPercentileNode local_approx_percentile = 144; GlobalApproxPercentileNode global_approx_percentile = 145; RowMergeNode row_merge = 146; + AsOfJoinNode as_of_join = 147; } // The id for the operator. This is local per mview. // TODO: should better be a uint32. diff --git a/src/batch/src/executor/join/mod.rs b/src/batch/src/executor/join/mod.rs index cf2388314d8f6..4ac630489a552 100644 --- a/src/batch/src/executor/join/mod.rs +++ b/src/batch/src/executor/join/mod.rs @@ -62,7 +62,9 @@ impl JoinType { PbJoinType::RightSemi => JoinType::RightSemi, PbJoinType::RightAnti => JoinType::RightAnti, PbJoinType::FullOuter => JoinType::FullOuter, - PbJoinType::Unspecified => unreachable!(), + PbJoinType::AsofInner | PbJoinType::AsofLeftOuter | PbJoinType::Unspecified => { + unreachable!() + } } } } diff --git a/src/common/src/util/stream_graph_visitor.rs b/src/common/src/util/stream_graph_visitor.rs index 04e0e42a1a7f1..8eae1492985be 100644 --- a/src/common/src/util/stream_graph_visitor.rs +++ b/src/common/src/util/stream_graph_visitor.rs @@ -269,6 +269,12 @@ pub fn visit_stream_node_tables_inner( always!(node.bucket_state_table, "GlobalApproxPercentileBucketState"); always!(node.count_state_table, "GlobalApproxPercentileCountState"); } + + // AsOf join + NodeBody::AsOfJoin(node) => { + always!(node.left_table, "AsOfJoinLeft"); + always!(node.right_table, "AsOfJoinRight"); + } _ => {} } }; diff --git a/src/frontend/planner_test/tests/testdata/input/asof_join.yaml b/src/frontend/planner_test/tests/testdata/input/asof_join.yaml new file mode 100644 index 0000000000000..f6ca65716c2ea --- /dev/null +++ b/src/frontend/planner_test/tests/testdata/input/asof_join.yaml @@ -0,0 +1,35 @@ +- sql: + CREATE TABLE t1(v1 varchar, v2 int, v3 int); + CREATE TABLE t2(v1 varchar, v2 int, v3 int); + SELECT * FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1; + expected_outputs: + - stream_error + +- sql: + CREATE TABLE t1(v1 varchar, v2 int, v3 int); + CREATE TABLE t2(v1 varchar, v2 int, v3 int); + SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 || 'a' and t1.v2 > t2.v2; + expected_outputs: + - batch_error + - stream_plan + +- sql: + CREATE TABLE t1(v1 varchar, v2 int, v3 int); + CREATE TABLE t2(v1 varchar, v2 int, v3 int); + SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF LEFT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 *2 < t2.v2; + expected_outputs: + - stream_plan + +- sql: + CREATE TABLE t1(v1 varchar, v2 int, v3 int); + CREATE TABLE t2(v1 varchar, v2 int, v3 int); + SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 and t1.v2 < t2.v2 and t1.v3 < t2.v3; + expected_outputs: + - stream_error + +- sql: + CREATE TABLE t1(v1 varchar, v2 int, v3 int); + CREATE TABLE t2(v1 varchar, v2 int, v3 int); + SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v2 < t2.v2; + expected_outputs: + - stream_error diff --git a/src/frontend/planner_test/tests/testdata/output/asof_join.yaml b/src/frontend/planner_test/tests/testdata/output/asof_join.yaml new file mode 100644 index 0000000000000..508c9de04f18d --- /dev/null +++ b/src/frontend/planner_test/tests/testdata/output/asof_join.yaml @@ -0,0 +1,28 @@ +# This file is automatically generated. See `src/frontend/planner_test/README.md` for more information. +- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT * FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1; + stream_error: 'Invalid input syntax: AsOf join requires exactly 1 ineuquality condition' +- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 || 'a' and t1.v2 > t2.v2; + stream_plan: |- + StreamMaterialize { columns: [t1_v1, t1_v2, t2_v1, t2_v2, t1._row_id(hidden), t2._row_id(hidden)], stream_key: [t1._row_id, t2._row_id, t1_v1], pk_columns: [t1._row_id, t2._row_id, t1_v1], pk_conflict: NoCheck } + └─StreamAsOfJoin { type: AsofInner, predicate: t1.v1 = $expr1 AND (t1.v2 > t2.v2), output: [t1.v1, t1.v2, t2.v1, t2.v2, t1._row_id, t2._row_id] } + ├─StreamExchange { dist: HashShard(t1.v1) } + │ └─StreamTableScan { table: t1, columns: [t1.v1, t1.v2, t1._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t1._row_id], pk: [_row_id], dist: UpstreamHashShard(t1._row_id) } + └─StreamExchange { dist: HashShard($expr1) } + └─StreamProject { exprs: [t2.v1, t2.v2, ConcatOp(t2.v1, 'a':Varchar) as $expr1, t2._row_id] } + └─StreamTableScan { table: t2, columns: [t2.v1, t2.v2, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) } + batch_error: |- + Not supported: AsOf join in batch query + HINT: AsOf join is only supported in streaming query +- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF LEFT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 *2 < t2.v2; + stream_plan: |- + StreamMaterialize { columns: [t1_v1, t1_v2, t2_v1, t2_v2, t1._row_id(hidden), t2._row_id(hidden)], stream_key: [t1._row_id, t2._row_id, t1_v1], pk_columns: [t1._row_id, t2._row_id, t1_v1], pk_conflict: NoCheck } + └─StreamAsOfJoin { type: AsofLeftOuter, predicate: t1.v1 = t2.v1 AND ($expr1 < t2.v2), output: [t1.v1, t1.v2, t2.v1, t2.v2, t1._row_id, t2._row_id] } + ├─StreamExchange { dist: HashShard(t1.v1) } + │ └─StreamProject { exprs: [t1.v1, t1.v2, (t1.v2 * 2:Int32) as $expr1, t1._row_id] } + │ └─StreamTableScan { table: t1, columns: [t1.v1, t1.v2, t1._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t1._row_id], pk: [_row_id], dist: UpstreamHashShard(t1._row_id) } + └─StreamExchange { dist: HashShard(t2.v1) } + └─StreamTableScan { table: t2, columns: [t2.v1, t2.v2, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) } +- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 and t1.v2 < t2.v2 and t1.v3 < t2.v3; + stream_error: 'Invalid input syntax: AsOf join requires exactly 1 ineuquality condition' +- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v2 < t2.v2; + stream_error: 'Invalid input syntax: AsOf join requires at least 1 equal condition' diff --git a/src/frontend/src/binder/relation/join.rs b/src/frontend/src/binder/relation/join.rs index d13b683be08b0..30bd0a2906222 100644 --- a/src/frontend/src/binder/relation/join.rs +++ b/src/frontend/src/binder/relation/join.rs @@ -92,6 +92,8 @@ impl Binder { JoinOperator::FullOuter(constraint) => (constraint, JoinType::FullOuter), // Cross join equals to inner join with with no constraint. JoinOperator::CrossJoin => (JoinConstraint::None, JoinType::Inner), + JoinOperator::AsOfInner(constraint) => (constraint, JoinType::AsofInner), + JoinOperator::AsOfLeft(constraint) => (constraint, JoinType::AsofLeftOuter), }; let right: Relation; let cond: ExprImpl; diff --git a/src/frontend/src/optimizer/plan_node/batch_hash_join.rs b/src/frontend/src/optimizer/plan_node/batch_hash_join.rs index 399817336e4ca..bb5bca88d2b19 100644 --- a/src/frontend/src/optimizer/plan_node/batch_hash_join.rs +++ b/src/frontend/src/optimizer/plan_node/batch_hash_join.rs @@ -66,7 +66,9 @@ impl BatchHashJoin { // we can not derive the hash distribution from the side where outer join can generate a // NULL row (Distribution::HashShard(_), Distribution::HashShard(_)) => match join.join_type { - JoinType::Unspecified => unreachable!(), + JoinType::AsofInner | JoinType::AsofLeftOuter | JoinType::Unspecified => { + unreachable!() + } JoinType::FullOuter => Distribution::SomeShard, JoinType::Inner | JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti => { let l2o = join.l2i_col_mapping().composite(&join.i2o_col_mapping()); diff --git a/src/frontend/src/optimizer/plan_node/generic/join.rs b/src/frontend/src/optimizer/plan_node/generic/join.rs index 105f8bebb32bd..f7ce096e73eb2 100644 --- a/src/frontend/src/optimizer/plan_node/generic/join.rs +++ b/src/frontend/src/optimizer/plan_node/generic/join.rs @@ -277,7 +277,7 @@ impl GenericPlanNode for Join { .rewrite_functional_dependency_set(right_fd_set) }; let fd_set: FunctionalDependencySet = match self.join_type { - JoinType::Inner => { + JoinType::Inner | JoinType::AsofInner => { let mut fd_set = FunctionalDependencySet::new(full_out_col_num); for i in &self.on.conjunctions { if let Some((col, _)) = i.as_eq_const() { @@ -300,7 +300,7 @@ impl GenericPlanNode for Join { .for_each(|fd| fd_set.add_functional_dependency(fd)); fd_set } - JoinType::LeftOuter => get_new_left_fd_set(left_fd_set), + JoinType::LeftOuter | JoinType::AsofLeftOuter => get_new_left_fd_set(left_fd_set), JoinType::RightOuter => get_new_right_fd_set(right_fd_set), JoinType::FullOuter => FunctionalDependencySet::new(full_out_col_num), JoinType::LeftSemi | JoinType::LeftAnti => left_fd_set, @@ -325,9 +325,12 @@ impl Join { pub fn full_out_col_num(left_len: usize, right_len: usize, join_type: JoinType) -> usize { match join_type { - JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => { - left_len + right_len - } + JoinType::Inner + | JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => left_len + right_len, JoinType::LeftSemi | JoinType::LeftAnti => left_len, JoinType::RightSemi | JoinType::RightAnti => right_len, JoinType::Unspecified => unreachable!(), @@ -371,7 +374,12 @@ impl Join { let right_len = self.right.schema().len(); match self.join_type { - JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => { + JoinType::Inner + | JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { ColIndexMapping::identity_or_none(left_len + right_len, left_len) } @@ -389,7 +397,12 @@ impl Join { let right_len = self.right.schema().len(); match self.join_type { - JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => { + JoinType::Inner + | JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { ColIndexMapping::with_shift_offset(left_len + right_len, -(left_len as isize)) } JoinType::LeftSemi | JoinType::LeftAnti => ColIndexMapping::empty(left_len, right_len), @@ -445,13 +458,16 @@ impl Join { pub fn add_which_join_key_to_pk(&self) -> EitherOrBoth<(), ()> { match self.join_type { - JoinType::Inner => { + JoinType::Inner | JoinType::AsofInner => { // Theoretically adding either side is ok, but the distribution key of the inner // join derived based on the left side by default, so we choose the left side here // to ensure the pk comprises the distribution key. EitherOrBoth::Left(()) } - JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti => EitherOrBoth::Left(()), + JoinType::LeftOuter + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::AsofLeftOuter => EitherOrBoth::Left(()), JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => { EitherOrBoth::Right(()) } diff --git a/src/frontend/src/optimizer/plan_node/logical_join.rs b/src/frontend/src/optimizer/plan_node/logical_join.rs index 2b64b5fd93ad5..0f642e3c3e88a 100644 --- a/src/frontend/src/optimizer/plan_node/logical_join.rs +++ b/src/frontend/src/optimizer/plan_node/logical_join.rs @@ -33,6 +33,7 @@ use crate::error::{ErrorCode, Result, RwError}; use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::generic::DynamicFilter; +use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin; use crate::optimizer::plan_node::utils::IndicesDisplay; use crate::optimizer::plan_node::{ BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate, @@ -837,14 +838,13 @@ impl PredicatePushdown for LogicalJoin { } impl LogicalJoin { - fn to_stream_hash_join( + fn get_stream_input_for_hash_join( &self, - predicate: EqJoinPredicate, + predicate: &EqJoinPredicate, ctx: &mut ToStreamContext, - ) -> Result { + ) -> Result<(PlanRef, PlanRef)> { use super::stream::prelude::*; - assert!(predicate.has_eq()); let mut right = self.right().to_stream_with_dist_required( &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()), ctx, @@ -888,6 +888,18 @@ impl LogicalJoin { } _ => unreachable!(), } + Ok((left, right)) + } + + fn to_stream_hash_join( + &self, + predicate: EqJoinPredicate, + ctx: &mut ToStreamContext, + ) -> Result { + use super::stream::prelude::*; + + assert!(predicate.has_eq()); + let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?; let logical_join = self.clone_with_left_right(left, right); @@ -1260,10 +1272,45 @@ impl LogicalJoin { .expect("Fail to convert to lookup join") .into()) } + + fn to_stream_asof_join( + &self, + predicate: EqJoinPredicate, + ctx: &mut ToStreamContext, + ) -> Result { + use super::stream::prelude::*; + + if predicate.eq_keys().is_empty() { + return Err(ErrorCode::InvalidInputSyntax( + "AsOf join requires at least 1 equal condition".to_string(), + ) + .into()); + } + + let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?; + let left_len = left.schema().len(); + let logical_join = self.clone_with_left_right(left, right); + + let inequality_desc = + StreamAsOfJoin::get_inequality_desc_from_predicate(predicate.clone(), left_len)?; + + Ok(StreamAsOfJoin::new( + logical_join.core.clone(), + predicate, + inequality_desc, + )) + } } impl ToBatch for LogicalJoin { fn to_batch(&self) -> Result { + if JoinType::AsofInner == self.join_type() || JoinType::AsofLeftOuter == self.join_type() { + return Err(ErrorCode::NotSupported( + "AsOf join in batch query".to_string(), + "AsOf join is only supported in streaming query".to_string(), + ) + .into()); + } let predicate = EqJoinPredicate::create( self.left().schema().len(), self.right().schema().len(), @@ -1320,7 +1367,9 @@ impl ToStream for LogicalJoin { self.on().clone(), ); - if predicate.has_eq() { + if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter { + self.to_stream_asof_join(predicate, ctx).map(|x| x.into()) + } else if predicate.has_eq() { if !predicate.eq_keys_are_type_aligned() { return Err(ErrorCode::InternalError(format!( "Join eq keys are not aligned for predicate: {predicate:?}" diff --git a/src/frontend/src/optimizer/plan_node/mod.rs b/src/frontend/src/optimizer/plan_node/mod.rs index db1200de2a27a..0ec266cd2339d 100644 --- a/src/frontend/src/optimizer/plan_node/mod.rs +++ b/src/frontend/src/optimizer/plan_node/mod.rs @@ -883,6 +883,7 @@ mod logical_topn; mod logical_union; mod logical_update; mod logical_values; +mod stream_asof_join; mod stream_changelog; mod stream_dedup; mod stream_delta_join; @@ -898,6 +899,7 @@ mod stream_group_topn; mod stream_hash_agg; mod stream_hash_join; mod stream_hop_window; +mod stream_join_common; mod stream_local_approx_percentile; mod stream_materialize; mod stream_now; @@ -994,6 +996,7 @@ pub use logical_topn::LogicalTopN; pub use logical_union::LogicalUnion; pub use logical_update::LogicalUpdate; pub use logical_values::LogicalValues; +pub use stream_asof_join::StreamAsOfJoin; pub use stream_cdc_table_scan::StreamCdcTableScan; pub use stream_changelog::StreamChangeLog; pub use stream_dedup::StreamDedup; @@ -1010,6 +1013,7 @@ pub use stream_group_topn::StreamGroupTopN; pub use stream_hash_agg::StreamHashAgg; pub use stream_hash_join::StreamHashJoin; pub use stream_hop_window::StreamHopWindow; +use stream_join_common::StreamJoinCommon; pub use stream_local_approx_percentile::StreamLocalApproxPercentile; pub use stream_materialize::StreamMaterialize; pub use stream_now::StreamNow; @@ -1159,6 +1163,7 @@ macro_rules! for_all_plan_nodes { , { Stream, GlobalApproxPercentile } , { Stream, LocalApproxPercentile } , { Stream, RowMerge } + , { Stream, AsOfJoin } } }; } @@ -1288,6 +1293,7 @@ macro_rules! for_stream_plan_nodes { , { Stream, GlobalApproxPercentile } , { Stream, LocalApproxPercentile } , { Stream, RowMerge } + , { Stream, AsOfJoin } } }; } diff --git a/src/frontend/src/optimizer/plan_node/stream_asof_join.rs b/src/frontend/src/optimizer/plan_node/stream_asof_join.rs new file mode 100644 index 0000000000000..f241769168604 --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/stream_asof_join.rs @@ -0,0 +1,350 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use fixedbitset::FixedBitSet; +use itertools::Itertools; +use pretty_xmlish::{Pretty, XmlNode}; +use risingwave_common::util::sort_util::OrderType; +use risingwave_expr::bail; +use risingwave_pb::expr::expr_node::PbType; +use risingwave_pb::plan_common::{AsOfJoinDesc, AsOfJoinType, JoinType, PbAsOfJoinInequalityType}; +use risingwave_pb::stream_plan::stream_node::NodeBody; +use risingwave_pb::stream_plan::AsOfJoinNode; + +use super::generic::GenericPlanNode; +use super::stream::prelude::*; +use super::utils::{ + childless_record, plan_node_name, watermark_pretty, Distill, TableCatalogBuilder, +}; +use super::{ + generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeBinary, StreamJoinCommon, StreamNode, +}; +use crate::error::{ErrorCode, Result}; +use crate::expr::{ExprImpl, ExprRewriter, ExprVisitor}; +use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::plan_node::utils::IndicesDisplay; +use crate::optimizer::plan_node::{EqJoinPredicate, EqJoinPredicateDisplay}; +use crate::optimizer::property::MonotonicityMap; +use crate::stream_fragmenter::BuildFragmentGraphState; +use crate::TableCatalog; + +/// [`StreamAsOfJoin`] implements [`super::LogicalJoin`] with hash tables. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct StreamAsOfJoin { + pub base: PlanBase, + core: generic::Join, + + /// The join condition must be equivalent to `logical.on`, but separated into equal and + /// non-equal parts to facilitate execution later + eq_join_predicate: EqJoinPredicate, + + /// Whether can optimize for append-only stream. + /// It is true if input of both side is append-only + is_append_only: bool, + + /// inequality description + inequality_desc: AsOfJoinDesc, +} + +impl StreamAsOfJoin { + pub fn new( + core: generic::Join, + eq_join_predicate: EqJoinPredicate, + inequality_desc: AsOfJoinDesc, + ) -> Self { + assert!(core.join_type == JoinType::AsofInner || core.join_type == JoinType::AsofLeftOuter); + + // Inner join won't change the append-only behavior of the stream. The rest might. + let append_only = match core.join_type { + JoinType::Inner => core.left.append_only() && core.right.append_only(), + _ => false, + }; + + let dist = StreamJoinCommon::derive_dist( + core.left.distribution(), + core.right.distribution(), + &core, + ); + + // TODO: derive watermarks + let watermark_columns = FixedBitSet::with_capacity(core.schema().len()); + + // TODO: derive from input + let base = PlanBase::new_stream_with_core( + &core, + dist, + append_only, + false, // TODO(rc): derive EOWC property from input + watermark_columns, + MonotonicityMap::new(), // TODO: derive monotonicity + ); + + Self { + base, + core, + eq_join_predicate, + is_append_only: append_only, + inequality_desc, + } + } + + pub fn get_inequality_desc_from_predicate( + predicate: EqJoinPredicate, + left_input_len: usize, + ) -> Result { + let expr: ExprImpl = predicate.other_cond().clone().into(); + if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() { + if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len + { + Ok(AsOfJoinDesc { + left_idx: left_input_ref.index() as u32, + right_idx: (right_input_ref.index() - left_input_len) as u32, + inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(), + }) + } else { + bail!("inequal condition from the same side should be push down in optimizer"); + } + } else { + Err(ErrorCode::InvalidInputSyntax( + "AsOf join requires exactly 1 ineuquality condition".to_string(), + ) + .into()) + } + } + + fn expr_type_to_comparison_type(expr_type: PbType) -> Result { + match expr_type { + PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt), + PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe), + PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt), + PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe), + _ => Err(ErrorCode::InvalidInputSyntax(format!( + "Invalid comparison type: {}", + expr_type.as_str_name() + )) + .into()), + } + } + + /// Get join type + pub fn join_type(&self) -> JoinType { + self.core.join_type + } + + /// Get a reference to the `AsOf` join's eq join predicate. + pub fn eq_join_predicate(&self) -> &EqJoinPredicate { + &self.eq_join_predicate + } + + pub fn derive_dist_key_in_join_key(&self) -> Vec { + let left_dk_indices = self.left().distribution().dist_column_indices().to_vec(); + let right_dk_indices = self.right().distribution().dist_column_indices().to_vec(); + + StreamJoinCommon::get_dist_key_in_join_key( + &left_dk_indices, + &right_dk_indices, + self.eq_join_predicate(), + ) + } + + /// Return stream asof join internal table catalog. + pub fn infer_internal_table_catalog( + input: I, + join_key_indices: Vec, + dk_indices_in_jk: Vec, + inequality_key_idx: usize, + ) -> (TableCatalog, Vec) { + let schema = input.schema(); + + let internal_table_dist_keys = dk_indices_in_jk + .iter() + .map(|idx| join_key_indices[*idx]) + .collect_vec(); + + // The pk of AsOf join internal table should be join_key + inequality_key + input_pk. + let join_key_len = join_key_indices.len(); + let mut pk_indices = join_key_indices; + + // dedup the pk in dist key.. + let mut deduped_input_pk_indices = vec![]; + for input_pk_idx in input.stream_key().unwrap() { + if !pk_indices.contains(input_pk_idx) + && !deduped_input_pk_indices.contains(input_pk_idx) + { + deduped_input_pk_indices.push(*input_pk_idx); + } + } + + pk_indices.push(inequality_key_idx); + pk_indices.extend(deduped_input_pk_indices.clone()); + + // Build internal table + let mut internal_table_catalog_builder = TableCatalogBuilder::default(); + let internal_columns_fields = schema.fields().to_vec(); + + internal_columns_fields.iter().for_each(|field| { + internal_table_catalog_builder.add_column(field); + }); + pk_indices.iter().for_each(|idx| { + internal_table_catalog_builder.add_order_column(*idx, OrderType::ascending()) + }); + + internal_table_catalog_builder.set_dist_key_in_pk(dk_indices_in_jk.clone()); + + ( + internal_table_catalog_builder.build(internal_table_dist_keys, join_key_len), + deduped_input_pk_indices, + ) + } +} + +impl Distill for StreamAsOfJoin { + fn distill<'a>(&self) -> XmlNode<'a> { + let (ljk, rjk) = self + .eq_join_predicate + .eq_indexes() + .first() + .cloned() + .expect("first join key"); + + let name = plan_node_name!("StreamAsOfJoin", + { "window", self.left().watermark_columns().contains(ljk) && self.right().watermark_columns().contains(rjk) }, + { "append_only", self.is_append_only }, + ); + let verbose = self.base.ctx().is_explain_verbose(); + let mut vec = Vec::with_capacity(6); + vec.push(("type", Pretty::debug(&self.core.join_type))); + + let concat_schema = self.core.concat_schema(); + vec.push(( + "predicate", + Pretty::debug(&EqJoinPredicateDisplay { + eq_join_predicate: self.eq_join_predicate(), + input_schema: &concat_schema, + }), + )); + + if let Some(ow) = watermark_pretty(self.base.watermark_columns(), self.schema()) { + vec.push(("output_watermarks", ow)); + } + + if verbose { + let data = IndicesDisplay::from_join(&self.core, &concat_schema); + vec.push(("output", data)); + } + + childless_record(name, vec) + } +} + +impl PlanTreeNodeBinary for StreamAsOfJoin { + fn left(&self) -> PlanRef { + self.core.left.clone() + } + + fn right(&self) -> PlanRef { + self.core.right.clone() + } + + fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self { + let mut core = self.core.clone(); + core.left = left; + core.right = right; + Self::new(core, self.eq_join_predicate.clone(), self.inequality_desc) + } +} + +impl_plan_tree_node_for_binary! { StreamAsOfJoin } + +impl StreamNode for StreamAsOfJoin { + fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> NodeBody { + let left_jk_indices = self.eq_join_predicate.left_eq_indexes(); + let right_jk_indices = self.eq_join_predicate.right_eq_indexes(); + let left_jk_indices_prost = left_jk_indices.iter().map(|idx| *idx as i32).collect_vec(); + let right_jk_indices_prost = right_jk_indices.iter().map(|idx| *idx as i32).collect_vec(); + + let dk_indices_in_jk = self.derive_dist_key_in_join_key(); + + let (left_table, left_deduped_input_pk_indices) = Self::infer_internal_table_catalog( + self.left().plan_base(), + left_jk_indices, + dk_indices_in_jk.clone(), + self.inequality_desc.left_idx as usize, + ); + let (right_table, right_deduped_input_pk_indices) = Self::infer_internal_table_catalog( + self.right().plan_base(), + right_jk_indices, + dk_indices_in_jk, + self.inequality_desc.right_idx as usize, + ); + + let left_deduped_input_pk_indices = left_deduped_input_pk_indices + .iter() + .map(|idx| *idx as u32) + .collect_vec(); + + let right_deduped_input_pk_indices = right_deduped_input_pk_indices + .iter() + .map(|idx| *idx as u32) + .collect_vec(); + + let left_table = left_table.with_id(state.gen_table_id_wrapped()); + let right_table = right_table.with_id(state.gen_table_id_wrapped()); + + let null_safe_prost = self.eq_join_predicate.null_safes().into_iter().collect(); + + let asof_join_type = match self.core.join_type { + JoinType::AsofInner => AsOfJoinType::Inner, + JoinType::AsofLeftOuter => AsOfJoinType::LeftOuter, + _ => unreachable!(), + }; + + NodeBody::AsOfJoin(AsOfJoinNode { + join_type: asof_join_type.into(), + left_key: left_jk_indices_prost, + right_key: right_jk_indices_prost, + null_safe: null_safe_prost, + left_table: Some(left_table.to_internal_table_prost()), + right_table: Some(right_table.to_internal_table_prost()), + left_deduped_input_pk_indices, + right_deduped_input_pk_indices, + output_indices: self.core.output_indices.iter().map(|&x| x as u32).collect(), + asof_desc: Some(self.inequality_desc), + }) + } +} + +impl ExprRewritable for StreamAsOfJoin { + fn has_rewritable_expr(&self) -> bool { + true + } + + fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef { + let mut core = self.core.clone(); + core.rewrite_exprs(r); + let eq_join_predicate = self.eq_join_predicate.rewrite_exprs(r); + let desc = Self::get_inequality_desc_from_predicate( + eq_join_predicate.clone(), + core.left.schema().len(), + ) + .unwrap(); + Self::new(core, eq_join_predicate, desc).into() + } +} + +impl ExprVisitable for StreamAsOfJoin { + fn visit_exprs(&self, v: &mut dyn ExprVisitor) { + self.core.visit_exprs(v); + } +} diff --git a/src/frontend/src/optimizer/plan_node/stream_delta_join.rs b/src/frontend/src/optimizer/plan_node/stream_delta_join.rs index f53d4331ae617..84592aee1829a 100644 --- a/src/frontend/src/optimizer/plan_node/stream_delta_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_delta_join.rs @@ -86,7 +86,7 @@ impl StreamDeltaJoin { } } - /// Get a reference to the batch hash join's eq join predicate. + /// Get a reference to the delta hash join's eq join predicate. pub fn eq_join_predicate(&self) -> &EqJoinPredicate { &self.eq_join_predicate } diff --git a/src/frontend/src/optimizer/plan_node/stream_hash_join.rs b/src/frontend/src/optimizer/plan_node/stream_hash_join.rs index cbce1e1caf45a..0d7863a247d9c 100644 --- a/src/frontend/src/optimizer/plan_node/stream_hash_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_hash_join.rs @@ -15,13 +15,13 @@ use fixedbitset::FixedBitSet; use itertools::Itertools; use pretty_xmlish::{Pretty, XmlNode}; -use risingwave_common::util::iter_util::ZipEqFast; use risingwave_pb::plan_common::JoinType; use risingwave_pb::stream_plan::stream_node::NodeBody; use risingwave_pb::stream_plan::{DeltaExpression, HashJoinNode, PbInequalityPair}; use super::generic::Join; use super::stream::prelude::*; +use super::stream_join_common::StreamJoinCommon; use super::utils::{childless_record, plan_node_name, watermark_pretty, Distill}; use super::{ generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeBinary, StreamDeltaJoin, StreamNode, @@ -30,7 +30,7 @@ use crate::expr::{Expr, ExprDisplay, ExprRewriter, ExprVisitor, InequalityInputP use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::utils::IndicesDisplay; use crate::optimizer::plan_node::{EqJoinPredicate, EqJoinPredicateDisplay}; -use crate::optimizer::property::{Distribution, MonotonicityMap}; +use crate::optimizer::property::MonotonicityMap; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::utils::ColIndexMappingRewriteExt; @@ -72,7 +72,11 @@ impl StreamHashJoin { _ => false, }; - let dist = Self::derive_dist(core.left.distribution(), core.right.distribution(), &core); + let dist = StreamJoinCommon::derive_dist( + core.left.distribution(), + core.right.distribution(), + &core, + ); let mut inequality_pairs = vec![]; let mut clean_left_state_conjunction_idx = None; @@ -215,48 +219,11 @@ impl StreamHashJoin { self.core.join_type } - /// Get a reference to the batch hash join's eq join predicate. + /// Get a reference to the hash join's eq join predicate. pub fn eq_join_predicate(&self) -> &EqJoinPredicate { &self.eq_join_predicate } - pub(super) fn derive_dist( - left: &Distribution, - right: &Distribution, - logical: &generic::Join, - ) -> Distribution { - match (left, right) { - (Distribution::Single, Distribution::Single) => Distribution::Single, - (Distribution::HashShard(_), Distribution::HashShard(_)) => { - // we can not derive the hash distribution from the side where outer join can - // generate a NULL row - match logical.join_type { - JoinType::Unspecified => unreachable!(), - JoinType::FullOuter => Distribution::SomeShard, - JoinType::Inner - | JoinType::LeftOuter - | JoinType::LeftSemi - | JoinType::LeftAnti => { - let l2o = logical - .l2i_col_mapping() - .composite(&logical.i2o_col_mapping()); - l2o.rewrite_provided_distribution(left) - } - JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => { - let r2o = logical - .r2i_col_mapping() - .composite(&logical.i2o_col_mapping()); - r2o.rewrite_provided_distribution(right) - } - } - } - (_, _) => unreachable!( - "suspicious distribution: left: {:?}, right: {:?}", - left, right - ), - } - } - /// Convert this hash join to a delta join plan pub fn into_delta_join(self) -> StreamDeltaJoin { StreamDeltaJoin::new(self.core, self.eq_join_predicate) @@ -265,24 +232,12 @@ impl StreamHashJoin { pub fn derive_dist_key_in_join_key(&self) -> Vec { let left_dk_indices = self.left().distribution().dist_column_indices().to_vec(); let right_dk_indices = self.right().distribution().dist_column_indices().to_vec(); - let left_jk_indices = self.eq_join_predicate.left_eq_indexes(); - let right_jk_indices = self.eq_join_predicate.right_eq_indexes(); - - assert_eq!(left_jk_indices.len(), right_jk_indices.len()); - - let mut dk_indices_in_jk = vec![]; - - for (l_dk_idx, r_dk_idx) in left_dk_indices.iter().zip_eq_fast(right_dk_indices.iter()) { - for dk_idx_in_jk in left_jk_indices.iter().positions(|idx| idx == l_dk_idx) { - if right_jk_indices[dk_idx_in_jk] == *r_dk_idx { - dk_indices_in_jk.push(dk_idx_in_jk); - break; - } - } - } - assert_eq!(dk_indices_in_jk.len(), left_dk_indices.len()); - dk_indices_in_jk + StreamJoinCommon::get_dist_key_in_join_key( + &left_dk_indices, + &right_dk_indices, + self.eq_join_predicate(), + ) } pub fn inequality_pairs(&self) -> &Vec<(bool, InequalityInputPair)> { diff --git a/src/frontend/src/optimizer/plan_node/stream_join_common.rs b/src/frontend/src/optimizer/plan_node/stream_join_common.rs new file mode 100644 index 0000000000000..f44ab8291f444 --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/stream_join_common.rs @@ -0,0 +1,88 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use itertools::Itertools; +use risingwave_common::util::iter_util::ZipEqFast; +use risingwave_pb::plan_common::JoinType; + +use super::{generic, EqJoinPredicate}; +use crate::optimizer::property::Distribution; +use crate::utils::ColIndexMappingRewriteExt; +use crate::PlanRef; + +pub struct StreamJoinCommon; + +impl StreamJoinCommon { + pub(super) fn get_dist_key_in_join_key( + left_dk_indices: &[usize], + right_dk_indices: &[usize], + eq_join_predicate: &EqJoinPredicate, + ) -> Vec { + let left_jk_indices = eq_join_predicate.left_eq_indexes(); + let right_jk_indices = &eq_join_predicate.right_eq_indexes(); + assert_eq!(left_jk_indices.len(), right_jk_indices.len()); + let mut dk_indices_in_jk = vec![]; + for (l_dk_idx, r_dk_idx) in left_dk_indices.iter().zip_eq_fast(right_dk_indices.iter()) { + for dk_idx_in_jk in left_jk_indices.iter().positions(|idx| idx == l_dk_idx) { + if right_jk_indices[dk_idx_in_jk] == *r_dk_idx { + dk_indices_in_jk.push(dk_idx_in_jk); + break; + } + } + } + assert_eq!(dk_indices_in_jk.len(), left_dk_indices.len()); + dk_indices_in_jk + } + + pub(super) fn derive_dist( + left: &Distribution, + right: &Distribution, + logical: &generic::Join, + ) -> Distribution { + match (left, right) { + (Distribution::Single, Distribution::Single) => Distribution::Single, + (Distribution::HashShard(_), Distribution::HashShard(_)) => { + // we can not derive the hash distribution from the side where outer join can + // generate a NULL row + match logical.join_type { + JoinType::Unspecified => { + unreachable!() + } + JoinType::FullOuter => Distribution::SomeShard, + JoinType::Inner + | JoinType::LeftOuter + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { + let l2o = logical + .l2i_col_mapping() + .composite(&logical.i2o_col_mapping()); + l2o.rewrite_provided_distribution(left) + } + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => { + let r2o = logical + .r2i_col_mapping() + .composite(&logical.i2o_col_mapping()); + r2o.rewrite_provided_distribution(right) + } + } + } + (_, _) => unreachable!( + "suspicious distribution: left: {:?}, right: {:?}", + left, right + ), + } + } +} diff --git a/src/frontend/src/optimizer/plan_visitor/cardinality_visitor.rs b/src/frontend/src/optimizer/plan_visitor/cardinality_visitor.rs index b17a8318d2b1a..51d61b3afc776 100644 --- a/src/frontend/src/optimizer/plan_visitor/cardinality_visitor.rs +++ b/src/frontend/src/optimizer/plan_visitor/cardinality_visitor.rs @@ -181,6 +181,11 @@ impl PlanVisitor for CardinalityVisitor { // TODO: refine the cardinality of full outer join JoinType::FullOuter => Cardinality::unknown(), + + // For each row from one side, we match `0..=1` rows from the other side. + JoinType::AsofInner => left.mul(right.min(0..=1)), + // For each row from left side, we match exactly 1 row from the right side or a `NULL` row`. + JoinType::AsofLeftOuter => left, } } diff --git a/src/frontend/src/optimizer/rule/apply_join_transpose_rule.rs b/src/frontend/src/optimizer/rule/apply_join_transpose_rule.rs index 3da0348936238..fc6cbdd477539 100644 --- a/src/frontend/src/optimizer/rule/apply_join_transpose_rule.rs +++ b/src/frontend/src/optimizer/rule/apply_join_transpose_rule.rs @@ -130,7 +130,10 @@ impl Rule for ApplyJoinTransposeRule { let (push_left, push_right) = match join.join_type() { // `LeftSemi`, `LeftAnti`, `LeftOuter` can only push to left side if it's right side has // no correlated id. Otherwise push to both sides. - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftOuter => { + JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftOuter + | JoinType::AsofLeftOuter => { if !join_right_has_correlated_id { (true, false) } else { @@ -147,7 +150,7 @@ impl Rule for ApplyJoinTransposeRule { } } // `Inner` can push to one side if the other side is not dependent on it. - JoinType::Inner => { + JoinType::Inner | JoinType::AsofInner => { if join_cond_has_correlated_id && !join_right_has_correlated_id && !join_left_has_correlated_id @@ -236,7 +239,12 @@ impl ApplyJoinTransposeRule { JoinType::LeftSemi | JoinType::LeftAnti => { left_apply_condition.extend(apply_on); } - JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => { + JoinType::Inner + | JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { let apply_len = apply_left_len + join.schema().len(); let mut d_t1_bit_set = FixedBitSet::with_capacity(apply_len); d_t1_bit_set.set_range(0..apply_left_len + join_left_len, true); @@ -316,7 +324,12 @@ impl ApplyJoinTransposeRule { JoinType::RightSemi | JoinType::RightAnti => { right_apply_condition.extend(apply_on); } - JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => { + JoinType::Inner + | JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { let apply_len = apply_left_len + join.schema().len(); let mut d_t2_bit_set = FixedBitSet::with_capacity(apply_len); d_t2_bit_set.set_range(0..apply_left_len, true); @@ -456,7 +469,12 @@ impl ApplyJoinTransposeRule { JoinType::RightSemi | JoinType::RightAnti => { right_apply_condition.extend(apply_on); } - JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => { + JoinType::Inner + | JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { let apply_len = apply_left_len + join.schema().len(); let mut d_t1_bit_set = FixedBitSet::with_capacity(apply_len); let mut d_t2_bit_set = FixedBitSet::with_capacity(apply_len); @@ -555,7 +573,12 @@ impl ApplyJoinTransposeRule { JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi | JoinType::RightAnti => { new_join.into() } - JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => { + JoinType::Inner + | JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { let mut output_indices_mapping = ColIndexMapping::new( output_indices.iter().map(|x| Some(*x)).collect(), target_size, diff --git a/src/frontend/src/optimizer/rule/join_commute_rule.rs b/src/frontend/src/optimizer/rule/join_commute_rule.rs index 405e28d6825fc..55b975ccb9717 100644 --- a/src/frontend/src/optimizer/rule/join_commute_rule.rs +++ b/src/frontend/src/optimizer/rule/join_commute_rule.rs @@ -72,6 +72,8 @@ impl Rule for JoinCommuteRule { | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter | JoinType::Unspecified => None, } } @@ -116,6 +118,7 @@ impl JoinCommuteRule { JoinType::LeftAnti => JoinType::RightAnti, JoinType::RightSemi => JoinType::LeftSemi, JoinType::RightAnti => JoinType::LeftAnti, + JoinType::AsofInner | JoinType::AsofLeftOuter => unreachable!(), } } } diff --git a/src/frontend/src/optimizer/rule/translate_apply_rule.rs b/src/frontend/src/optimizer/rule/translate_apply_rule.rs index 876ca7d6285b2..87ccbd5924728 100644 --- a/src/frontend/src/optimizer/rule/translate_apply_rule.rs +++ b/src/frontend/src/optimizer/rule/translate_apply_rule.rs @@ -233,8 +233,9 @@ impl TranslateApplyRule { | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti - | JoinType::RightOuter => rewrite(join.right(), right_idxs, true), - JoinType::LeftOuter | JoinType::FullOuter => None, + | JoinType::RightOuter + | JoinType::AsofInner => rewrite(join.right(), right_idxs, true), + JoinType::LeftOuter | JoinType::FullOuter | JoinType::AsofLeftOuter => None, JoinType::Unspecified => unreachable!(), } } @@ -246,7 +247,9 @@ impl TranslateApplyRule { | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti - | JoinType::LeftOuter => rewrite(join.left(), left_idxs, false), + | JoinType::LeftOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => rewrite(join.left(), left_idxs, false), JoinType::RightOuter | JoinType::FullOuter => None, JoinType::Unspecified => unreachable!(), } @@ -258,14 +261,18 @@ impl TranslateApplyRule { | JoinType::LeftSemi | JoinType::RightSemi | JoinType::LeftAnti - | JoinType::RightAnti => { + | JoinType::RightAnti + | JoinType::AsofInner => { let left = rewrite(join.left(), left_idxs, false)?; let right = rewrite(join.right(), right_idxs, true)?; let new_join = LogicalJoin::new(left, right, join.join_type(), Condition::true_cond()); Some(new_join.into()) } - JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => None, + JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofLeftOuter => None, JoinType::Unspecified => unreachable!(), } } @@ -300,7 +307,12 @@ impl TranslateApplyRule { if !left_idxs.is_empty() && right_idxs.is_empty() { // Deal with multi scalar subqueries match apply.join_type() { - JoinType::Inner | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftOuter => { + JoinType::Inner + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { let plan = apply.left(); Self::rewrite(&plan, left_idxs, offset, index_mapping, data_types, index) } diff --git a/src/prost/build.rs b/src/prost/build.rs index 0afbaef2ea730..ee04705ef19e5 100644 --- a/src/prost/build.rs +++ b/src/prost/build.rs @@ -166,6 +166,7 @@ fn main() -> Result<(), Box> { "plan_common.AdditionalCollectionName", "#[derive(Eq, Hash)]", ) + .type_attribute("plan_common.AsOfJoinDesc", "#[derive(Eq, Hash)]") .type_attribute("common.ColumnOrder", "#[derive(Eq, Hash)]") .type_attribute("common.OrderType", "#[derive(Eq, Hash)]") .type_attribute("common.Buffer", "#[derive(Eq)]") diff --git a/src/sqlparser/src/ast/query.rs b/src/sqlparser/src/ast/query.rs index b16a3075f90d9..be03a5f1133ec 100644 --- a/src/sqlparser/src/ast/query.rs +++ b/src/sqlparser/src/ast/query.rs @@ -584,6 +584,20 @@ impl fmt::Display for Join { suffix(constraint) ), JoinOperator::CrossJoin => write!(f, " CROSS JOIN {}", self.relation), + JoinOperator::AsOfInner(constraint) => write!( + f, + " {}ASOF JOIN {}{}", + prefix(constraint), + self.relation, + suffix(constraint) + ), + JoinOperator::AsOfLeft(constraint) => write!( + f, + " {}ASOF LEFT JOIN {}{}", + prefix(constraint), + self.relation, + suffix(constraint) + ), } } } @@ -596,6 +610,8 @@ pub enum JoinOperator { RightOuter(JoinConstraint), FullOuter(JoinConstraint), CrossJoin, + AsOfInner(JoinConstraint), + AsOfLeft(JoinConstraint), } #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/src/sqlparser/src/keywords.rs b/src/sqlparser/src/keywords.rs index 151d66e8083ba..8ec7191f749c2 100644 --- a/src/sqlparser/src/keywords.rs +++ b/src/sqlparser/src/keywords.rs @@ -88,6 +88,7 @@ define_keywords!( AS, ASC, ASENSITIVE, + ASOF, ASYMMETRIC, ASYNC, AT, @@ -616,6 +617,7 @@ pub const RESERVED_FOR_TABLE_ALIAS: &[Keyword] = &[ Keyword::LEFT, Keyword::RIGHT, Keyword::NATURAL, + Keyword::ASOF, Keyword::USING, Keyword::CLUSTER, // for MSSQL-specific OUTER APPLY (seems reserved in most dialects) diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 1f755a0dfa334..a36844b5619fb 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -4636,7 +4636,13 @@ impl Parser<'_> { join_operator, } } else { - let natural = self.parse_keyword(Keyword::NATURAL); + let (natural, asof) = + match self.parse_one_of_keywords(&[Keyword::NATURAL, Keyword::ASOF]) { + Some(Keyword::NATURAL) => (true, false), + Some(Keyword::ASOF) => (false, true), + Some(_) => unreachable!(), + None => (false, false), + }; let peek_keyword = if let Token::Word(w) = self.peek_token().token { w.keyword } else { @@ -4647,17 +4653,33 @@ impl Parser<'_> { Keyword::INNER | Keyword::JOIN => { let _ = self.parse_keyword(Keyword::INNER); self.expect_keyword(Keyword::JOIN)?; - JoinOperator::Inner + if asof { + JoinOperator::AsOfInner + } else { + JoinOperator::Inner + } } kw @ Keyword::LEFT | kw @ Keyword::RIGHT | kw @ Keyword::FULL => { + let checkpoint = *self; let _ = self.next_token(); let _ = self.parse_keyword(Keyword::OUTER); self.expect_keyword(Keyword::JOIN)?; - match kw { - Keyword::LEFT => JoinOperator::LeftOuter, - Keyword::RIGHT => JoinOperator::RightOuter, - Keyword::FULL => JoinOperator::FullOuter, - _ => unreachable!(), + if asof { + if Keyword::LEFT == kw { + JoinOperator::AsOfLeft + } else { + return self.expected_at( + checkpoint, + "LEFT after ASOF. RIGHT or FULL are not supported", + ); + } + } else { + match kw { + Keyword::LEFT => JoinOperator::LeftOuter, + Keyword::RIGHT => JoinOperator::RightOuter, + Keyword::FULL => JoinOperator::FullOuter, + _ => unreachable!(), + } } } Keyword::OUTER => { @@ -4666,14 +4688,24 @@ impl Parser<'_> { _ if natural => { return self.expected("a join type after NATURAL"); } + _ if asof => { + return self.expected("a join type after ASOF"); + } _ => break, }; let relation = self.parse_table_factor()?; let join_constraint = self.parse_join_constraint(natural)?; let join_operator = join_operator_type(join_constraint); - if let JoinOperator::Inner(JoinConstraint::None) = join_operator { - return self.expected("join constraint after INNER JOIN"); + let need_constraint = match join_operator { + JoinOperator::Inner(JoinConstraint::None) => Some("INNER JOIN"), + JoinOperator::AsOfInner(JoinConstraint::None) => Some("ASOF INNER JOIN"), + JoinOperator::AsOfLeft(JoinConstraint::None) => Some("ASOF LEFT JOIN"), + _ => None, + }; + if let Some(join_type) = need_constraint { + return self.expected(&format!("join constraint after {join_type}")); } + Join { relation, join_operator, diff --git a/src/sqlparser/tests/testdata/asof_join.yaml b/src/sqlparser/tests/testdata/asof_join.yaml new file mode 100644 index 0000000000000..b7ee5b1461b76 --- /dev/null +++ b/src/sqlparser/tests/testdata/asof_join.yaml @@ -0,0 +1,17 @@ +# This file is automatically generated by `src/sqlparser/tests/parser_test.rs`. +- input: SELECT * FROM t1 asof JOIN t2 where t1.v1 = t2.v1 + error_msg: |- + sql parser error: expected join constraint after ASOF INNER JOIN, found: where + LINE 1: SELECT * FROM t1 asof JOIN t2 where t1.v1 = t2.v1 + ^ +- input: SELECT * FROM t1 asof LEFT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 > t2.v2 + formatted_sql: SELECT * FROM t1 ASOF LEFT JOIN t2 ON t1.v1 = t2.v1 AND t1.v2 > t2.v2 + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [Wildcard(None)], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "t1", quote_style: None }]), alias: None, as_of: None }, joins: [Join { relation: Table { name: ObjectName([Ident { value: "t2", quote_style: None }]), alias: None, as_of: None }, join_operator: AsOfLeft(On(BinaryOp { left: BinaryOp { left: CompoundIdentifier([Ident { value: "t1", quote_style: None }, Ident { value: "v1", quote_style: None }]), op: Eq, right: CompoundIdentifier([Ident { value: "t2", quote_style: None }, Ident { value: "v1", quote_style: None }]) }, op: And, right: BinaryOp { left: CompoundIdentifier([Ident { value: "t1", quote_style: None }, Ident { value: "v2", quote_style: None }]), op: Gt, right: CompoundIdentifier([Ident { value: "t2", quote_style: None }, Ident { value: "v2", quote_style: None }]) } })) }] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' +- input: SELECT * FROM t1 asof INNER JOIN t2 ON t1.v1 = t2.v1 and t1.v2 > t2.v2 + formatted_sql: SELECT * FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 AND t1.v2 > t2.v2 + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [Wildcard(None)], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "t1", quote_style: None }]), alias: None, as_of: None }, joins: [Join { relation: Table { name: ObjectName([Ident { value: "t2", quote_style: None }]), alias: None, as_of: None }, join_operator: AsOfInner(On(BinaryOp { left: BinaryOp { left: CompoundIdentifier([Ident { value: "t1", quote_style: None }, Ident { value: "v1", quote_style: None }]), op: Eq, right: CompoundIdentifier([Ident { value: "t2", quote_style: None }, Ident { value: "v1", quote_style: None }]) }, op: And, right: BinaryOp { left: CompoundIdentifier([Ident { value: "t1", quote_style: None }, Ident { value: "v2", quote_style: None }]), op: Gt, right: CompoundIdentifier([Ident { value: "t2", quote_style: None }, Ident { value: "v2", quote_style: None }]) } })) }] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' +- input: SELECT * FROM t1 asof RIGHT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 > t2.v2 + error_msg: |- + sql parser error: expected LEFT after ASOF. RIGHT or FULL are not supported, found: RIGHT + LINE 1: SELECT * FROM t1 asof RIGHT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 > t2.v2 + ^ diff --git a/src/stream/src/from_proto/hash_join.rs b/src/stream/src/from_proto/hash_join.rs index 2d421274cec39..42034b64b0af5 100644 --- a/src/stream/src/from_proto/hash_join.rs +++ b/src/stream/src/from_proto/hash_join.rs @@ -223,7 +223,9 @@ impl HashKeyDispatcher for HashJoinExecutorDispatcherArgs { }; } match self.join_type_proto { - JoinTypeProto::Unspecified => unreachable!(), + JoinTypeProto::AsofInner + | JoinTypeProto::AsofLeftOuter + | JoinTypeProto::Unspecified => unreachable!(), JoinTypeProto::Inner => build!(Inner), JoinTypeProto::LeftOuter => build!(LeftOuter), JoinTypeProto::RightOuter => build!(RightOuter), diff --git a/src/stream/src/from_proto/mod.rs b/src/stream/src/from_proto/mod.rs index 9a51dd10ddfb7..5ac5379ca57cf 100644 --- a/src/stream/src/from_proto/mod.rs +++ b/src/stream/src/from_proto/mod.rs @@ -67,6 +67,7 @@ use risingwave_storage::StateStore; use self::append_only_dedup::*; use self::approx_percentile::global::*; use self::approx_percentile::local::*; +use self::asof_join::AsOfJoinExecutorBuilder; use self::barrier_recv::*; use self::batch_query::*; use self::cdc_filter::CdcFilterExecutorBuilder; @@ -186,5 +187,6 @@ pub async fn create_executor( NodeBody::GlobalApproxPercentile => GlobalApproxPercentileExecutorBuilder, NodeBody::LocalApproxPercentile => LocalApproxPercentileExecutorBuilder, NodeBody::RowMerge => RowMergeExecutorBuilder, + NodeBody::AsOfJoin => AsOfJoinExecutorBuilder, } }