Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(batch): introduce batch AsOf join #19790

Merged
merged 12 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions e2e_test/batch/join/asof_join.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
statement ok
SET RW_IMPLICIT_FLUSH TO true;

statement ok
create table t1 (v1 int, v2 int, v3 int primary key);

statement ok
create table t2 (v1 int, v2 int, v3 int primary key);

statement ok
insert into t1 values (1, 2, 3), (2, 3, 4);

statement ok
insert into t2 values (1, 3, 4), (1, 2, 5), (1, 2, 6);

# asof inner join
query IIIIII
SELECT t1.v1 t1_v1, t1.v2 t1_v2, t1.v3 t1_v3, t2.v1 t2_v1, t2.v2 t2_v2, t2.v3 t2_v3 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 and t1.v2 < t2.v2 order by t1.v1;
----
1 2 3 1 3 4

# asof left join
query IIIIII
SELECT t1.v1 t1_v1, t1.v2 t1_v2, t1.v3 t1_v3, t2.v1 t2_v1, t2.v2 t2_v2, t2.v3 t2_v3 FROM t1 ASOF LEFT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 < t2.v2 order by t1.v1;
----
1 2 3 1 3 4
2 3 4 NULL NULL NULL

# asof left join
query IIIIII
SELECT t1.v1 t1_v1, t1.v2 t1_v2, t1.v3 t1_v3, t2.v1 t2_v1, t2.v2 t2_v2, t2.v3 t2_v3 FROM t1 ASOF LEFT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 > t2.v2 order by t1.v1;
----
1 2 3 NULL NULL NULL
2 3 4 NULL NULL NULL

statement ok
drop table t1;

statement ok
drop table t2;
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,36 @@
SELECT * FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1;
expected_outputs:
- stream_error
- batch_error

- sql:
CREATE TABLE t1(v1 varchar, v2 int, v3 int);
CREATE TABLE t2(v1 varchar, v2 int, v3 int);
SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 || 'a' and t1.v2 > t2.v2;
expected_outputs:
- batch_error
- stream_plan
- batch_plan

- sql:
CREATE TABLE t1(v1 varchar, v2 int, v3 int);
CREATE TABLE t2(v1 varchar, v2 int, v3 int);
SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF LEFT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 *2 < t2.v2;
expected_outputs:
- stream_plan
- batch_plan

- sql:
CREATE TABLE t1(v1 varchar, v2 int, v3 int);
CREATE TABLE t2(v1 varchar, v2 int, v3 int);
SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 and t1.v2 < t2.v2 and t1.v3 < t2.v3;
expected_outputs:
- stream_error
- batch_error

- sql:
CREATE TABLE t1(v1 varchar, v2 int, v3 int);
CREATE TABLE t2(v1 varchar, v2 int, v3 int);
SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v2 < t2.v2;
expected_outputs:
- stream_error
- batch_error
27 changes: 24 additions & 3 deletions src/frontend/planner_test/tests/testdata/output/asof_join.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
# This file is automatically generated. See `src/frontend/planner_test/README.md` for more information.
- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT * FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1;
batch_error: 'Invalid input syntax: AsOf join requires exactly 1 ineuquality condition'
stream_error: 'Invalid input syntax: AsOf join requires exactly 1 ineuquality condition'
- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 || 'a' and t1.v2 > t2.v2;
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchGroupTopN { order: [t2.v2 DESC], limit: 1, offset: 0, group_key: [t1.v1, t1.v2] }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This plan seems incorrect because it will make the rows less then expected

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the implementation and this should be resolved now.

└─BatchExchange { order: [], dist: HashShard(t1.v1, t1.v2) }
└─BatchHashJoin { type: Inner, predicate: t1.v1 = $expr1 AND (t1.v2 > t2.v2), output: [t1.v1, t1.v2, t2.v1, t2.v2] }
├─BatchExchange { order: [], dist: HashShard(t1.v1) }
│ └─BatchScan { table: t1, columns: [t1.v1, t1.v2], distribution: SomeShard }
└─BatchExchange { order: [], dist: HashShard($expr1) }
└─BatchProject { exprs: [t2.v1, t2.v2, ConcatOp(t2.v1, 'a':Varchar) as $expr1] }
└─BatchScan { table: t2, columns: [t2.v1, t2.v2], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [t1_v1, t1_v2, t2_v1, t2_v2, t1._row_id(hidden), t2._row_id(hidden)], stream_key: [t1._row_id, t2._row_id, t1_v1], pk_columns: [t1._row_id, t2._row_id, t1_v1], pk_conflict: NoCheck }
└─StreamAsOfJoin { type: AsofInner, predicate: t1.v1 = $expr1 AND (t1.v2 > t2.v2), output: [t1.v1, t1.v2, t2.v1, t2.v2, t1._row_id, t2._row_id] }
Expand All @@ -10,10 +21,18 @@
└─StreamExchange { dist: HashShard($expr1) }
└─StreamProject { exprs: [t2.v1, t2.v2, ConcatOp(t2.v1, 'a':Varchar) as $expr1, t2._row_id] }
└─StreamTableScan { table: t2, columns: [t2.v1, t2.v2, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) }
batch_error: |-
Not supported: AsOf join in batch query
HINT: AsOf join is only supported in streaming query
- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF LEFT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 *2 < t2.v2;
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchProject { exprs: [t1.v1, t1.v2, t2.v1, t2.v2] }
└─BatchGroupTopN { order: [t2.v2 ASC], limit: 1, offset: 0, group_key: [t1.v1, t2.v1] }
└─BatchExchange { order: [], dist: HashShard(t1.v1, t2.v1) }
└─BatchHashJoin { type: LeftOuter, predicate: t1.v1 = t2.v1 AND ($expr1 < t2.v2), output: [t1.v1, t1.v2, t2.v1, t2.v2, $expr1] }
├─BatchExchange { order: [], dist: HashShard(t1.v1) }
│ └─BatchProject { exprs: [t1.v1, t1.v2, (t1.v2 * 2:Int32) as $expr1] }
│ └─BatchScan { table: t1, columns: [t1.v1, t1.v2], distribution: SomeShard }
└─BatchExchange { order: [], dist: HashShard(t2.v1) }
└─BatchScan { table: t2, columns: [t2.v1, t2.v2], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [t1_v1, t1_v2, t2_v1, t2_v2, t1._row_id(hidden), t2._row_id(hidden)], stream_key: [t1._row_id, t2._row_id, t1_v1], pk_columns: [t1._row_id, t2._row_id, t1_v1], pk_conflict: NoCheck }
└─StreamAsOfJoin { type: AsofLeftOuter, predicate: t1.v1 = t2.v1 AND ($expr1 < t2.v2), output: [t1.v1, t1.v2, t2.v1, t2.v2, t1._row_id, t2._row_id] }
Expand All @@ -23,6 +42,8 @@
└─StreamExchange { dist: HashShard(t2.v1) }
└─StreamTableScan { table: t2, columns: [t2.v1, t2.v2, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) }
- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 and t1.v2 < t2.v2 and t1.v3 < t2.v3;
batch_error: 'Invalid input syntax: AsOf join requires exactly 1 ineuquality condition'
stream_error: 'Invalid input syntax: AsOf join requires exactly 1 ineuquality condition'
- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v2 < t2.v2;
batch_error: 'Invalid input syntax: AsOf join requires at least 1 equal condition'
stream_error: 'Invalid input syntax: AsOf join requires at least 1 equal condition'
9 changes: 7 additions & 2 deletions src/frontend/src/optimizer/plan_node/batch_hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,16 @@ impl BatchHashJoin {
// we can not derive the hash distribution from the side where outer join can generate a
// NULL row
(Distribution::HashShard(_), Distribution::HashShard(_)) => match join.join_type {
JoinType::AsofInner | JoinType::AsofLeftOuter | JoinType::Unspecified => {
JoinType::Unspecified => {
unreachable!()
}
JoinType::FullOuter => Distribution::SomeShard,
JoinType::Inner | JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti => {
JoinType::Inner
| JoinType::LeftOuter
| JoinType::LeftSemi
| JoinType::LeftAnti
| JoinType::AsofInner
| JoinType::AsofLeftOuter => {
let l2o = join.l2i_col_mapping().composite(&join.i2o_col_mapping());
l2o.rewrite_provided_distribution(left)
}
Expand Down
10 changes: 10 additions & 0 deletions src/frontend/src/optimizer/plan_node/eq_join_predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ impl EqJoinPredicate {
&mut self.other_cond
}

/// Get the equal predicate
pub fn eq_predicate(&self) -> Self {
Self {
other_cond: Condition::true_cond(),
eq_keys: self.eq_keys.clone(),
left_cols_num: self.left_cols_num,
right_cols_num: self.right_cols_num,
}
}

/// Get a reference to the join predicate's eq keys.
///
/// Note: `right_col_index` starts from `left_cols_num`
Expand Down
151 changes: 138 additions & 13 deletions src/frontend/src/optimizer/plan_node/logical_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ use std::collections::HashMap;
use fixedbitset::FixedBitSet;
use itertools::{EitherOrBoth, Itertools};
use pretty_xmlish::{Pretty, XmlNode};
use risingwave_pb::plan_common::JoinType;
use risingwave_expr::bail;
use risingwave_pb::expr::expr_node::PbType;
use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
use risingwave_pb::stream_plan::StreamScanType;
use risingwave_sqlparser::ast::AsOf;

Expand All @@ -36,9 +38,10 @@ use crate::optimizer::plan_node::generic::DynamicFilter;
use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
use crate::optimizer::plan_node::utils::IndicesDisplay;
use crate::optimizer::plan_node::{
BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
BatchGroupTopN, BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, BatchProject,
ColumnPruningContext, EqJoinPredicate, LogicalFilter, LogicalScan, PredicatePushdownContext,
RewriteStreamContext, StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin,
ToStreamContext,
};
use crate::optimizer::plan_visitor::LogicalCardinalityExt;
use crate::optimizer::property::{Distribution, Order, RequiredDist};
Expand Down Expand Up @@ -1379,25 +1382,145 @@ impl LogicalJoin {
let logical_join = self.clone_with_left_right(left, right);

let inequality_desc =
StreamAsOfJoin::get_inequality_desc_from_predicate(predicate.clone(), left_len)?;
Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;

Ok(StreamAsOfJoin::new(
logical_join.core.clone(),
predicate,
inequality_desc,
))
}
}

impl ToBatch for LogicalJoin {
fn to_batch(&self) -> Result<PlanRef> {
if JoinType::AsofInner == self.join_type() || JoinType::AsofLeftOuter == self.join_type() {
return Err(ErrorCode::NotSupported(
"AsOf join in batch query".to_owned(),
"AsOf join is only supported in streaming query".to_owned(),
/// Convert the logical `AsOf` join to a Hash join + a Group top 1.
fn to_batch_asof_join(
&self,
mut logical_join: generic::Join<PlanRef>,
predicate: EqJoinPredicate,
) -> Result<PlanRef> {
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};

use super::batch::prelude::*;

if predicate.eq_keys().is_empty() {
return Err(ErrorCode::InvalidInputSyntax(
"AsOf join requires at least 1 equal condition".to_owned(),
)
.into());
}

logical_join.join_type = match logical_join.join_type {
JoinType::AsofInner => JoinType::Inner,
JoinType::AsofLeftOuter => JoinType::LeftOuter,
_ => unreachable!(),
};
let left_schema_len = logical_join.left.schema().len();
let asof_desc =
Self::get_inequality_desc_from_predicate(predicate.non_eq_cond(), left_schema_len)?;

let (left_asof_idx, right_asof_idx) = (
asof_desc.left_idx as usize,
asof_desc.right_idx as usize + left_schema_len,
);

// Add the AsOf columns to the output indices
let original_output_indices = logical_join.output_indices.clone();
if !logical_join.output_indices.contains(&left_asof_idx) {
logical_join.output_indices.push(left_asof_idx);
}
if !logical_join.output_indices.contains(&right_asof_idx) {
logical_join.output_indices.push(right_asof_idx);
}

let mapping = logical_join.i2o_col_mapping();

let batch_join = BatchHashJoin::new(logical_join, predicate.clone());

let right_output_asof_idx = mapping.map(right_asof_idx);

// Add a Group Top1 operator that group by LHS's join key and sort by RHS's asof column.
let order = match asof_desc.inequality_type() {
PbAsOfJoinInequalityType::AsOfInequalityTypeLt
| PbAsOfJoinInequalityType::AsOfInequalityTypeLe => Order::new(vec![ColumnOrder::new(
right_output_asof_idx,
OrderType::ascending(),
)]),
PbAsOfJoinInequalityType::AsOfInequalityTypeGt
| PbAsOfJoinInequalityType::AsOfInequalityTypeGe => Order::new(vec![ColumnOrder::new(
right_output_asof_idx,
OrderType::descending(),
)]),
PbAsOfJoinInequalityType::AsOfInequalityTypeUnspecified => {
bail!("unspecified AsOf join inequality type")
}
};
let group_key = [
predicate.left_eq_indexes(),
vec![asof_desc.left_idx as usize],
]
.concat();
let logical_group_top1 = generic::TopN::with_group(
batch_join.into(),
generic::TopNLimit::new(1, false),
0,
order,
group_key,
);
let batch_group_top1 = BatchGroupTopN::new(logical_group_top1);

let group_top_1_schema_len = batch_group_top1.schema().len();
if original_output_indices.len() != group_top_1_schema_len {
assert!(original_output_indices.len() < group_top_1_schema_len);
let logical_project = generic::Project::with_out_col_idx(
batch_group_top1.into(),
0..original_output_indices.len(),
);
Ok(BatchProject::new(logical_project).into())
} else {
Ok(batch_group_top1.into())
}
}

pub fn get_inequality_desc_from_predicate(
predicate: Condition,
left_input_len: usize,
) -> Result<AsOfJoinDesc> {
let expr: ExprImpl = predicate.into();
if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
{
Ok(AsOfJoinDesc {
left_idx: left_input_ref.index() as u32,
right_idx: (right_input_ref.index() - left_input_len) as u32,
inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
})
} else {
bail!("inequal condition from the same side should be push down in optimizer");
}
} else {
Err(ErrorCode::InvalidInputSyntax(
"AsOf join requires exactly 1 ineuquality condition".to_owned(),
)
.into())
}
}

fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
match expr_type {
PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
_ => Err(ErrorCode::InvalidInputSyntax(format!(
"Invalid comparison type: {}",
expr_type.as_str_name()
))
.into()),
}
}
}

impl ToBatch for LogicalJoin {
fn to_batch(&self) -> Result<PlanRef> {
let predicate = EqJoinPredicate::create(
self.left().schema().len(),
self.right().schema().len(),
Expand All @@ -1411,7 +1534,9 @@ impl ToBatch for LogicalJoin {
let ctx = self.base.ctx();
let config = ctx.session_ctx().config();

if predicate.has_eq() {
if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
self.to_batch_asof_join(logical_join, predicate)
} else if predicate.has_eq() {
if !predicate.eq_keys_are_type_aligned() {
return Err(ErrorCode::InternalError(format!(
"Join eq keys are not aligned for predicate: {predicate:?}"
Expand Down
Loading
Loading