Skip to content

Commit

Permalink
[pipelineX](fix) Fix BE crash caused by join and constant expr (apach…
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabriel39 authored and pull[bot] committed Sep 28, 2023
1 parent bd9c6a5 commit 2579243
Show file tree
Hide file tree
Showing 14 changed files with 96 additions and 18 deletions.
10 changes: 9 additions & 1 deletion be/src/pipeline/exec/hashjoin_build_sink.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,6 @@ Status HashJoinBuildSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* st
_build_expr_ctxs.push_back(ctx);

const auto vexpr = _build_expr_ctxs.back()->root();
const auto& data_type = vexpr->data_type();

bool null_aware = eq_join_conjunct.__isset.opcode &&
eq_join_conjunct.opcode == TExprOpcode::EQ_FOR_NULL;
Expand All @@ -421,7 +420,10 @@ Status HashJoinBuildSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* st
_store_null_in_hash_table.emplace_back(
null_aware ||
(_build_expr_ctxs.back()->root()->is_nullable() && build_stores_null));
}

for (const auto& expr : _build_expr_ctxs) {
const auto& data_type = expr->root()->data_type();
if (!data_type->have_maximum_size_of_value()) {
break;
}
Expand Down Expand Up @@ -589,6 +591,12 @@ Status HashJoinBuildSinkOperatorX::sink(RuntimeState* state, vectorized::Block*

local_state.init_short_circuit_for_probe();
if (source_state == SourceState::FINISHED) {
// Since the comparison of null values is meaningless, null aware left anti join should not output null
// when the build side is not empty.
if (!local_state._shared_state->build_blocks->empty() &&
_join_op == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN) {
local_state._shared_state->probe_ignore_null = true;
}
local_state._dependency->set_ready_for_read();
}

Expand Down
48 changes: 40 additions & 8 deletions be/src/pipeline/exec/hashjoin_probe_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Status HashJoinProbeLocalState::init(RuntimeState* state, LocalStateInfo& info)
SCOPED_TIMER(profile()->total_time_counter());
SCOPED_TIMER(_open_timer);
auto& p = _parent->cast<HashJoinProbeOperatorX>();
_probe_ignore_null = p._probe_ignore_null;
_shared_state->probe_ignore_null = p._probe_ignore_null;
_probe_expr_ctxs.resize(p._probe_expr_ctxs.size());
for (size_t i = 0; i < _probe_expr_ctxs.size(); i++) {
RETURN_IF_ERROR(p._probe_expr_ctxs[i]->clone(state, _probe_expr_ctxs[i]));
Expand All @@ -43,11 +43,6 @@ Status HashJoinProbeLocalState::init(RuntimeState* state, LocalStateInfo& info)
for (size_t i = 0; i < _other_join_conjuncts.size(); i++) {
RETURN_IF_ERROR(p._other_join_conjuncts[i]->clone(state, _other_join_conjuncts[i]));
}
// Since the comparison of null values is meaningless, null aware left anti join should not output null
// when the build side is not empty.
if (!_shared_state->build_blocks->empty() && p._join_op == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN) {
_probe_ignore_null = true;
}
_construct_mutable_join_block();
_probe_column_disguise_null.reserve(_probe_expr_ctxs.size());
_probe_arena_memory_usage =
Expand Down Expand Up @@ -189,6 +184,42 @@ Status HashJoinProbeOperatorX::pull(doris::RuntimeState* state, vectorized::Bloc
local_state.init_for_probe(state);
SCOPED_TIMER(local_state._probe_timer);
if (local_state._shared_state->short_circuit_for_probe) {
/// If `short_circuit_for_probe` is true, no rows match the join condition.
/// For a mark join the probe rows must still be output, so we append a mark
/// column with every row set to 0.
if (_is_mark_join) {
auto block_rows = local_state._probe_block.rows();
if (block_rows == 0) {
if (local_state._probe_eos) {
source_state = SourceState::FINISHED;
}
return Status::OK();
}

vectorized::Block temp_block;
// Get the probe-side output columns.
for (int i = 0; i < _left_output_slot_flags.size(); ++i) {
if (_left_output_slot_flags[i]) {
temp_block.insert(local_state._probe_block.get_by_position(i));
}
}
auto mark_column = vectorized::ColumnUInt8::create(block_rows, 0);
temp_block.insert(
{std::move(mark_column), std::make_shared<vectorized::DataTypeUInt8>(), ""});

{
SCOPED_TIMER(local_state._join_filter_timer);
RETURN_IF_ERROR(vectorized::VExprContext::filter_block(
local_state._conjuncts, &temp_block, temp_block.columns()));
}

RETURN_IF_ERROR(local_state._build_output_block(&temp_block, output_block, false));
temp_block.clear();
local_state._probe_block.clear_column_data(
_child_x->row_desc().num_materialized_slots());
local_state.reached_limit(output_block, source_state);
return Status::OK();
}
// If we use a short-circuit strategy, should return empty block directly.
source_state = SourceState::FINISHED;
return Status::OK();
Expand Down Expand Up @@ -241,7 +272,7 @@ Status HashJoinProbeOperatorX::pull(doris::RuntimeState* state, vectorized::Bloc
*local_state._shared_state->hash_table_variants,
*local_state._process_hashtable_ctx_variants,
vectorized::make_bool_variant(local_state._need_null_map_for_probe),
vectorized::make_bool_variant(local_state._probe_ignore_null));
vectorized::make_bool_variant(local_state._shared_state->probe_ignore_null));
});
} else if (local_state._probe_eos) {
if (_is_right_semi_anti || (_is_outer_join && _join_op != TJoinOp::LEFT_OUTER_JOIN)) {
Expand Down Expand Up @@ -299,7 +330,8 @@ bool HashJoinProbeOperatorX::need_more_input_data(RuntimeState* state) const {
auto& local_state = state->get_local_state(id())->cast<HashJoinProbeLocalState>();
return (local_state._probe_block.rows() == 0 ||
local_state._probe_index == local_state._probe_block.rows()) &&
!local_state._probe_eos && !local_state._shared_state->short_circuit_for_probe;
!local_state._probe_eos &&
(!local_state._shared_state->short_circuit_for_probe || _is_mark_join);
}

Status HashJoinProbeOperatorX::_do_evaluate(vectorized::Block& block,
Expand Down
1 change: 0 additions & 1 deletion be/src/pipeline/exec/hashjoin_probe_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ class HashJoinProbeLocalState final

bool _need_null_map_for_probe = false;
bool _has_set_need_null_map_for_probe = false;
bool _probe_ignore_null = false;
std::unique_ptr<vectorized::HashJoinProbeContext> _probe_context;
vectorized::ColumnUInt8::MutablePtr _null_map_column;
// for cases when a probe row matches more than batch size build rows.
Expand Down
5 changes: 4 additions & 1 deletion be/src/pipeline/exec/nested_loop_join_probe_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ Status NestedLoopJoinProbeLocalState::init(RuntimeState* state, LocalStateInfo&
RETURN_IF_ERROR(p._join_conjuncts[i]->clone(state, _join_conjuncts[i]));
}
_construct_mutable_join_block();

_loop_join_timer = ADD_TIMER(profile(), "LoopGenerateJoin");
return Status::OK();
}

Expand Down Expand Up @@ -349,7 +351,7 @@ void NestedLoopJoinProbeLocalState::_finalize_current_phase(vectorized::MutableB
DCHECK_LE(_left_block_start_pos + _left_side_process_count, _child_block->rows());
for (int j = _left_block_start_pos;
j < _left_block_start_pos + _left_side_process_count; ++j) {
mark_data.emplace_back(IsSemi != _cur_probe_row_visited_flags[j]);
mark_data.emplace_back(IsSemi == _cur_probe_row_visited_flags[j]);
}
for (size_t i = 0; i < p._num_probe_side_columns; ++i) {
const vectorized::ColumnWithTypeAndName src_column =
Expand Down Expand Up @@ -562,6 +564,7 @@ Status NestedLoopJoinProbeOperatorX::pull(RuntimeState* state, vectorized::Block
set_build_side_flag, set_probe_side_flag>(
state, join_op_variants);
};
SCOPED_TIMER(local_state._loop_join_timer);
RETURN_IF_ERROR(std::visit(
func, local_state._shared_state->join_op_variants,
vectorized::make_bool_variant(_match_all_build || _is_right_semi_anti),
Expand Down
2 changes: 2 additions & 0 deletions be/src/pipeline/exec/nested_loop_join_probe_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ class NestedLoopJoinProbeLocalState final
std::stack<uint16_t> _probe_offset_stack;
uint64_t _output_null_idx_build_side = 0;
vectorized::VExprContextSPtrs _join_conjuncts;

RuntimeProfile::Counter* _loop_join_timer;
};

class NestedLoopJoinProbeOperatorX final
Expand Down
2 changes: 1 addition & 1 deletion be/src/pipeline/exec/scan_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ template <typename Derived>
class ScanLocalState : public ScanLocalStateBase {
ENABLE_FACTORY_CREATOR(ScanLocalState);
ScanLocalState(RuntimeState* state, OperatorXBase* parent);
virtual ~ScanLocalState() = default;
~ScanLocalState() override = default;

Status init(RuntimeState* state, LocalStateInfo& info) override;
Status open(RuntimeState* state) override;
Expand Down
1 change: 1 addition & 0 deletions be/src/pipeline/pipeline_x/dependency.h
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ struct HashJoinSharedState : public JoinSharedState {
size_t build_exprs_size = 0;
std::shared_ptr<std::vector<vectorized::Block>> build_blocks =
std::make_shared<std::vector<vectorized::Block>>();
bool probe_ignore_null = false;
};

class HashJoinDependency final : public WriteDependency {
Expand Down
7 changes: 6 additions & 1 deletion be/src/vec/exprs/vcase_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,13 @@ Status VCaseExpr::prepare(RuntimeState* state, const RowDescriptor& desc, VExprC

Status VCaseExpr::open(RuntimeState* state, VExprContext* context,
FunctionContext::FunctionStateScope scope) {
RETURN_IF_ERROR(VExpr::open(state, context, scope));
for (int i = 0; i < _children.size(); ++i) {
RETURN_IF_ERROR(_children[i]->open(state, context, scope));
}
RETURN_IF_ERROR(VExpr::init_function_context(context, scope, _function));
if (scope == FunctionContext::FRAGMENT_LOCAL) {
RETURN_IF_ERROR(VExpr::get_const_col(context, nullptr));
}
return Status::OK();
}

Expand Down
7 changes: 6 additions & 1 deletion be/src/vec/exprs/vcast_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,13 @@ doris::Status VCastExpr::prepare(doris::RuntimeState* state, const doris::RowDes

doris::Status VCastExpr::open(doris::RuntimeState* state, VExprContext* context,
FunctionContext::FunctionStateScope scope) {
RETURN_IF_ERROR(VExpr::open(state, context, scope));
for (int i = 0; i < _children.size(); ++i) {
RETURN_IF_ERROR(_children[i]->open(state, context, scope));
}
RETURN_IF_ERROR(VExpr::init_function_context(context, scope, _function));
if (scope == FunctionContext::FRAGMENT_LOCAL) {
RETURN_IF_ERROR(VExpr::get_const_col(context, nullptr));
}
return Status::OK();
}

Expand Down
7 changes: 6 additions & 1 deletion be/src/vec/exprs/vectorized_fn_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,13 @@ Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc,

Status VectorizedFnCall::open(RuntimeState* state, VExprContext* context,
FunctionContext::FunctionStateScope scope) {
RETURN_IF_ERROR(VExpr::open(state, context, scope));
for (int i = 0; i < _children.size(); ++i) {
RETURN_IF_ERROR(_children[i]->open(state, context, scope));
}
RETURN_IF_ERROR(VExpr::init_function_context(context, scope, _function));
if (scope == FunctionContext::FRAGMENT_LOCAL) {
RETURN_IF_ERROR(VExpr::get_const_col(context, nullptr));
}
return Status::OK();
}

Expand Down
9 changes: 8 additions & 1 deletion be/src/vec/exprs/vexpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ Status VExpr::open(RuntimeState* state, VExprContext* context,
for (int i = 0; i < _children.size(); ++i) {
RETURN_IF_ERROR(_children[i]->open(state, context, scope));
}
if (scope == FunctionContext::FRAGMENT_LOCAL) {
RETURN_IF_ERROR(VExpr::get_const_col(context, nullptr));
}
return Status::OK();
}

Expand Down Expand Up @@ -466,6 +469,7 @@ Status VExpr::get_const_col(VExprContext* context,
}

if (_constant_col != nullptr) {
DCHECK(column_wrapper != nullptr);
*column_wrapper = _constant_col;
return Status::OK();
}
Expand All @@ -479,7 +483,10 @@ Status VExpr::get_const_col(VExprContext* context,
DCHECK(result != -1);
const auto& column = block.get_by_position(result).column;
_constant_col = std::make_shared<ColumnPtrWrapper>(column);
*column_wrapper = _constant_col;
if (column_wrapper != nullptr) {
*column_wrapper = _constant_col;
}

return Status::OK();
}

Expand Down
7 changes: 6 additions & 1 deletion be/src/vec/exprs/vin_predicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,13 @@ Status VInPredicate::prepare(RuntimeState* state, const RowDescriptor& desc,

Status VInPredicate::open(RuntimeState* state, VExprContext* context,
FunctionContext::FunctionStateScope scope) {
RETURN_IF_ERROR(VExpr::open(state, context, scope));
for (int i = 0; i < _children.size(); ++i) {
RETURN_IF_ERROR(_children[i]->open(state, context, scope));
}
RETURN_IF_ERROR(VExpr::init_function_context(context, scope, _function));
if (scope == FunctionContext::FRAGMENT_LOCAL) {
RETURN_IF_ERROR(VExpr::get_const_col(context, nullptr));
}
return Status::OK();
}

Expand Down
7 changes: 6 additions & 1 deletion be/src/vec/exprs/vmatch_predicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,16 @@ Status VMatchPredicate::prepare(RuntimeState* state, const RowDescriptor& desc,

Status VMatchPredicate::open(RuntimeState* state, VExprContext* context,
FunctionContext::FunctionStateScope scope) {
RETURN_IF_ERROR(VExpr::open(state, context, scope));
for (int i = 0; i < _children.size(); ++i) {
RETURN_IF_ERROR(_children[i]->open(state, context, scope));
}
RETURN_IF_ERROR(VExpr::init_function_context(context, scope, _function));
if (scope == FunctionContext::THREAD_LOCAL || scope == FunctionContext::FRAGMENT_LOCAL) {
context->fn_context(_fn_context_index)->set_function_state(scope, _inverted_index_ctx);
}
if (scope == FunctionContext::FRAGMENT_LOCAL) {
RETURN_IF_ERROR(VExpr::get_const_col(context, nullptr));
}
return Status::OK();
}

Expand Down
1 change: 1 addition & 0 deletions be/src/vec/functions/function_fake.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class FunctionFake : public IFunction {
}

bool use_default_implementation_for_nulls() const override { return true; }
bool use_default_implementation_for_constants() const override { return false; }

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) override {
Expand Down

0 comments on commit 2579243

Please sign in to comment.