diff --git a/be/src/exec/exec_node.cpp b/be/src/exec/exec_node.cpp index ed032d0976700e..63b88aa9de2b92 100644 --- a/be/src/exec/exec_node.cpp +++ b/be/src/exec/exec_node.cpp @@ -89,6 +89,18 @@ ExecNode::ExecNode(ObjectPool* pool, const TPlanNode& tnode, const DescriptorTbl _output_row_descriptor = std::make_unique( descs, std::vector {tnode.output_tuple_id}, std::vector {true}); } + if (!tnode.intermediate_output_tuple_id_list.empty()) { + DCHECK(tnode.__isset.output_tuple_id) << " no final output tuple id"; + // common subexpression elimination + DCHECK_EQ(tnode.intermediate_output_tuple_id_list.size(), + tnode.intermediate_projections_list.size()); + _intermediate_output_row_descriptor.reserve(tnode.intermediate_output_tuple_id_list.size()); + for (auto output_tuple_id : tnode.intermediate_output_tuple_id_list) { + _intermediate_output_row_descriptor.push_back( + RowDescriptor(descs, std::vector {output_tuple_id}, std::vector {true})); + } + } + _query_statistics = std::make_shared(); } @@ -114,7 +126,15 @@ Status ExecNode::init(const TPlanNode& tnode, RuntimeState* state) { DCHECK(tnode.__isset.output_tuple_id); RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode.projections, _projections)); } - + if (!tnode.intermediate_projections_list.empty()) { + DCHECK(tnode.__isset.projections) << "no final projections"; + _intermediate_projections.reserve(tnode.intermediate_projections_list.size()); + for (const auto& tnode_projections : tnode.intermediate_projections_list) { + vectorized::VExprContextSPtrs projections; + RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode_projections, projections)); + _intermediate_projections.push_back(projections); + } + } return Status::OK(); } @@ -143,7 +163,12 @@ Status ExecNode::prepare(RuntimeState* state) { RETURN_IF_ERROR(conjunct->prepare(state, intermediate_row_desc())); } - RETURN_IF_ERROR(vectorized::VExpr::prepare(_projections, state, intermediate_row_desc())); + for (int i = 0; i < _intermediate_projections.size(); i++) { + RETURN_IF_ERROR(vectorized::VExpr::prepare(_intermediate_projections[i], state, + intermediate_row_desc(i))); + } + + RETURN_IF_ERROR(vectorized::VExpr::prepare(_projections, state, projections_row_desc())); for (auto& i : _children) { RETURN_IF_ERROR(i->prepare(state)); @@ -155,6 +180,9 @@ Status ExecNode::alloc_resource(RuntimeState* state) { for (auto& conjunct : _conjuncts) { RETURN_IF_ERROR(conjunct->open(state)); } + for (auto& projections : _intermediate_projections) { + RETURN_IF_ERROR(vectorized::VExpr::open(projections, state)); + } RETURN_IF_ERROR(vectorized::VExpr::open(_projections, state)); return Status::OK(); } @@ -514,6 +542,22 @@ std::string ExecNode::get_name() { Status ExecNode::do_projections(vectorized::Block* origin_block, vectorized::Block* output_block) { SCOPED_TIMER(_exec_timer); SCOPED_TIMER(_projection_timer); + const size_t rows = origin_block->rows(); + if (rows == 0) { + return Status::OK(); + } + vectorized::Block input_block = *origin_block; + + std::vector result_column_ids; + for (auto& projections : _intermediate_projections) { + result_column_ids.resize(projections.size()); + for (int i = 0; i < projections.size(); i++) { + RETURN_IF_ERROR(projections[i]->execute(&input_block, &result_column_ids[i])); + } + input_block.shuffle_columns(result_column_ids); + } + + DCHECK_EQ(rows, input_block.rows()); auto insert_column_datas = [&](auto& to, vectorized::ColumnPtr& from, size_t rows) { if (to->is_nullable() && !from->is_nullable()) { if (_keep_origin || !from->is_exclusive()) { @@ -535,29 +579,26 @@ Status ExecNode::do_projections(vectorized::Block* origin_block, vectorized::Blo using namespace vectorized; MutableBlock mutable_block = VectorizedUtils::build_mutable_mem_reuse_block(output_block, *_output_row_descriptor); - auto rows = origin_block->rows(); - if (rows != 0) { - auto& mutable_columns = mutable_block.mutable_columns(); + auto& mutable_columns = mutable_block.mutable_columns(); - if (mutable_columns.size() != _projections.size()) { - return Status::InternalError( - "Logical error during processing {}, output of projections {} mismatches with " - "exec node output {}", - this->get_name(), _projections.size(), mutable_columns.size()); - } + if (mutable_columns.size() != _projections.size()) { + return Status::InternalError( + "Logical error during processing {}, output of projections {} mismatches with " + "exec node output {}", + this->get_name(), _projections.size(), mutable_columns.size()); + } - for (int i = 0; i < mutable_columns.size(); ++i) { - auto result_column_id = -1; - RETURN_IF_ERROR(_projections[i]->execute(origin_block, &result_column_id)); - auto column_ptr = origin_block->get_by_position(result_column_id) - .column->convert_to_full_column_if_const(); - //TODO: this is a quick fix, we need a new function like "change_to_nullable" to do it - insert_column_datas(mutable_columns[i], column_ptr, rows); - } - DCHECK(mutable_block.rows() == rows); - output_block->set_columns(std::move(mutable_columns)); + for (int i = 0; i < mutable_columns.size(); ++i) { + auto result_column_id = -1; + RETURN_IF_ERROR(_projections[i]->execute(&input_block, &result_column_id)); + auto column_ptr = input_block.get_by_position(result_column_id) + .column->convert_to_full_column_if_const(); + //TODO: this is a quick fix, we need a new function like "change_to_nullable" to do it + insert_column_datas(mutable_columns[i], column_ptr, rows); } + DCHECK(mutable_block.rows() == rows); + output_block->set_columns(std::move(mutable_columns)); return Status::OK(); } diff --git a/be/src/exec/exec_node.h b/be/src/exec/exec_node.h index f2303068437b2f..10b035835d7a7f 100644 --- a/be/src/exec/exec_node.h +++ b/be/src/exec/exec_node.h @@ -220,6 +220,26 @@ class ExecNode { return _output_row_descriptor ? *_output_row_descriptor : _row_descriptor; } virtual const RowDescriptor& intermediate_row_desc() const { return _row_descriptor; } + + // input expr -> intermediate_projections[0] -> intermediate_projections[1] -> intermediate_projections[2] ... -> final projections -> output expr + // prepare _row_descriptor intermediate_row_desc[0] intermediate_row_desc[1] intermediate_row_desc.end() _output_row_descriptor + + [[nodiscard]] const RowDescriptor& intermediate_row_desc(int idx) { + if (idx == 0) { + return intermediate_row_desc(); + } + DCHECK((idx - 1) < _intermediate_output_row_descriptor.size()); + return _intermediate_output_row_descriptor[idx - 1]; + } + + [[nodiscard]] const RowDescriptor& projections_row_desc() const { + if (_intermediate_output_row_descriptor.empty()) { + return intermediate_row_desc(); + } else { + return _intermediate_output_row_descriptor.back(); + } + } + int64_t rows_returned() const { return _num_rows_returned; } int64_t limit() const { return _limit; } bool reached_limit() const { return _limit != -1 && _num_rows_returned >= _limit; } @@ -270,6 +290,10 @@ class ExecNode { std::unique_ptr _output_row_descriptor; vectorized::VExprContextSPtrs _projections; + std::vector _intermediate_output_row_descriptor; + // Used in common subexpression elimination to compute intermediate results. + std::vector _intermediate_projections; + /// Resource information sent from the frontend. const TBackendResourceProfile _resource_profile; diff --git a/be/src/pipeline/pipeline_x/operator.cpp b/be/src/pipeline/pipeline_x/operator.cpp index 989b1ee00a517d..4a16cb65a014be 100644 --- a/be/src/pipeline/pipeline_x/operator.cpp +++ b/be/src/pipeline/pipeline_x/operator.cpp @@ -23,6 +23,8 @@ #include #include "common/logging.h" +#include "common/status.h" +#include "exec/exec_node.h" #include "pipeline/exec/aggregation_sink_operator.h" #include "pipeline/exec/aggregation_source_operator.h" #include "pipeline/exec/analytic_sink_operator.h" @@ -123,10 +125,20 @@ Status OperatorXBase::init(const TPlanNode& tnode, RuntimeState* /*state*/) { } // create the projections expr + if (tnode.__isset.projections) { DCHECK(tnode.__isset.output_tuple_id); RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode.projections, _projections)); } + if (!tnode.intermediate_projections_list.empty()) { + DCHECK(tnode.__isset.projections) << "no final projections"; + _intermediate_projections.reserve(tnode.intermediate_projections_list.size()); + for (const auto& tnode_projections : tnode.intermediate_projections_list) { + vectorized::VExprContextSPtrs projections; + RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode_projections, projections)); + _intermediate_projections.push_back(projections); + } + } return Status::OK(); } @@ -134,8 +146,11 @@ Status OperatorXBase::prepare(RuntimeState* state) { for (auto& conjunct : _conjuncts) { RETURN_IF_ERROR(conjunct->prepare(state, intermediate_row_desc())); } - - RETURN_IF_ERROR(vectorized::VExpr::prepare(_projections, state, intermediate_row_desc())); + for (int i = 0; i < _intermediate_projections.size(); i++) { + RETURN_IF_ERROR(vectorized::VExpr::prepare(_intermediate_projections[i], state, + intermediate_row_desc(i))); + } + RETURN_IF_ERROR(vectorized::VExpr::prepare(_projections, state, projections_row_desc())); if (_child_x && !is_source()) { RETURN_IF_ERROR(_child_x->prepare(state)); @@ -149,6 +164,9 @@ Status OperatorXBase::open(RuntimeState* state) { RETURN_IF_ERROR(conjunct->open(state)); } RETURN_IF_ERROR(vectorized::VExpr::open(_projections, state)); + for (auto& projections : _intermediate_projections) { + RETURN_IF_ERROR(vectorized::VExpr::open(projections, state)); + } if (_child_x && !is_source()) { RETURN_IF_ERROR(_child_x->open(state)); } @@ -175,7 +193,22 @@ Status OperatorXBase::do_projections(RuntimeState* state, vectorized::Block* ori auto* local_state = state->get_local_state(operator_id()); SCOPED_TIMER(local_state->exec_time_counter()); SCOPED_TIMER(local_state->_projection_timer); + const size_t rows = origin_block->rows(); + if (rows == 0) { + return Status::OK(); + } + vectorized::Block input_block = *origin_block; + std::vector result_column_ids; + for (const auto& projections : _intermediate_projections) { + result_column_ids.resize(projections.size()); + for (int i = 0; i < projections.size(); i++) { + RETURN_IF_ERROR(projections[i]->execute(&input_block, &result_column_ids[i])); + } + input_block.shuffle_columns(result_column_ids); + } + + DCHECK_EQ(rows, input_block.rows()); auto insert_column_datas = [&](auto& to, vectorized::ColumnPtr& from, size_t rows) { if (to->is_nullable() && !from->is_nullable()) { if (_keep_origin || !from->is_exclusive()) { @@ -198,15 +231,13 @@ Status OperatorXBase::do_projections(RuntimeState* state, vectorized::Block* ori vectorized::MutableBlock mutable_block = vectorized::VectorizedUtils::build_mutable_mem_reuse_block(output_block, *_output_row_descriptor); - auto rows = origin_block->rows(); - if (rows != 0) { auto& mutable_columns = mutable_block.mutable_columns(); DCHECK(mutable_columns.size() == local_state->_projections.size()); for (int i = 0; i < mutable_columns.size(); ++i) { auto result_column_id = -1; - RETURN_IF_ERROR(local_state->_projections[i]->execute(origin_block, &result_column_id)); - auto column_ptr = origin_block->get_by_position(result_column_id) + RETURN_IF_ERROR(local_state->_projections[i]->execute(&input_block, &result_column_id)); + auto column_ptr = input_block.get_by_position(result_column_id) .column->convert_to_full_column_if_const(); insert_column_datas(mutable_columns[i], column_ptr, rows); } @@ -365,6 +396,15 @@ Status PipelineXLocalState::init(RuntimeState* state, LocalState for (size_t i = 0; i < _projections.size(); i++) { RETURN_IF_ERROR(_parent->_projections[i]->clone(state, _projections[i])); } + _intermediate_projections.resize(_parent->_intermediate_projections.size()); + for (int i = 0; i < _parent->_intermediate_projections.size(); i++) { + _intermediate_projections[i].resize(_parent->_intermediate_projections[i].size()); + for (int j = 0; j < _parent->_intermediate_projections[i].size(); j++) { + RETURN_IF_ERROR(_parent->_intermediate_projections[i][j]->clone( + state, _intermediate_projections[i][j])); + } + } + _rows_returned_counter = ADD_COUNTER_WITH_LEVEL(_runtime_profile, "RowsProduced", TUnit::UNIT, 1); _blocks_returned_counter = diff --git a/be/src/pipeline/pipeline_x/operator.h b/be/src/pipeline/pipeline_x/operator.h index c375efb924dcbc..c3eb4d0cb51905 100644 --- a/be/src/pipeline/pipeline_x/operator.h +++ b/be/src/pipeline/pipeline_x/operator.h @@ -135,6 +135,9 @@ class PipelineXLocalStateBase { RuntimeState* _state = nullptr; vectorized::VExprContextSPtrs _conjuncts; vectorized::VExprContextSPtrs _projections; + // Used in common subexpression elimination to compute intermediate results. + std::vector _intermediate_projections; + bool _closed = false; vectorized::Block _origin_block; }; @@ -155,6 +158,22 @@ class OperatorXBase : public OperatorBase { if (tnode.__isset.output_tuple_id) { _output_row_descriptor.reset(new RowDescriptor(descs, {tnode.output_tuple_id}, {true})); } + if (tnode.__isset.output_tuple_id) { + _output_row_descriptor = std::make_unique( + descs, std::vector {tnode.output_tuple_id}, std::vector {true}); + } + if (!tnode.intermediate_output_tuple_id_list.empty()) { + DCHECK(tnode.__isset.output_tuple_id) << " no final output tuple id"; + // common subexpression elimination + DCHECK_EQ(tnode.intermediate_output_tuple_id_list.size(), + tnode.intermediate_projections_list.size()); + _intermediate_output_row_descriptor.reserve( + tnode.intermediate_output_tuple_id_list.size()); + for (auto output_tuple_id : tnode.intermediate_output_tuple_id_list) { + _intermediate_output_row_descriptor.push_back( + RowDescriptor(descs, std::vector {output_tuple_id}, std::vector {true})); + } + } } OperatorXBase(ObjectPool* pool, int node_id, int operator_id) @@ -247,6 +266,25 @@ class OperatorXBase : public OperatorBase { return _row_descriptor; } + // input expr -> intermediate_projections[0] -> intermediate_projections[1] -> intermediate_projections[2] ... -> final projections -> output expr + // prepare _row_descriptor intermediate_row_desc[0] intermediate_row_desc[1] intermediate_row_desc.end() _output_row_descriptor + + [[nodiscard]] const RowDescriptor& intermediate_row_desc(int idx) { + if (idx == 0) { + return intermediate_row_desc(); + } + DCHECK((idx - 1) < _intermediate_output_row_descriptor.size()); + return _intermediate_output_row_descriptor[idx - 1]; + } + + [[nodiscard]] const RowDescriptor& projections_row_desc() const { + if (_intermediate_output_row_descriptor.empty()) { + return intermediate_row_desc(); + } else { + return _intermediate_output_row_descriptor.back(); + } + } + [[nodiscard]] std::string debug_string() const override { return ""; } virtual std::string debug_string(int indentation_level = 0) const; @@ -318,6 +356,10 @@ class OperatorXBase : public OperatorBase { std::unique_ptr _output_row_descriptor = nullptr; vectorized::VExprContextSPtrs _projections; + std::vector _intermediate_output_row_descriptor; + // Used in common subexpression elimination to compute intermediate results. + std::vector _intermediate_projections; + /// Resource information sent from the frontend. const TBackendResourceProfile _resource_profile; diff --git a/be/src/vec/core/block.cpp b/be/src/vec/core/block.cpp index c93bfb11f09d6d..1d8d3e838015c9 100644 --- a/be/src/vec/core/block.cpp +++ b/be/src/vec/core/block.cpp @@ -719,6 +719,15 @@ void Block::swap(Block&& other) noexcept { row_same_bit = std::move(other.row_same_bit); } +void Block::shuffle_columns(const std::vector& result_column_ids) { + Container tmp_data; + tmp_data.reserve(result_column_ids.size()); + for (const int result_column_id : result_column_ids) { + tmp_data.push_back(data[result_column_id]); + } + swap(Block {tmp_data}); +} + void Block::update_hash(SipHash& hash) const { for (size_t row_no = 0, num_rows = rows(); row_no < num_rows; ++row_no) { for (const auto& col : data) { diff --git a/be/src/vec/core/block.h b/be/src/vec/core/block.h index a9769e7b679287..eb4fe43eca2faf 100644 --- a/be/src/vec/core/block.h +++ b/be/src/vec/core/block.h @@ -234,6 +234,9 @@ class Block { void swap(Block& other) noexcept; void swap(Block&& other) noexcept; + // Shuffle columns in place based on the result_column_ids + void shuffle_columns(const std::vector& result_column_ids); + // Default column size = -1 means clear all column in block // Else clear column [0, column_size) delete column [column_size, data.size) void clear_column_data(int column_size = -1) noexcept; diff --git a/be/src/vec/exec/scan/vscanner.cpp b/be/src/vec/exec/scan/vscanner.cpp index 39a9059d1d37c8..bedd6fb9e46352 100644 --- a/be/src/vec/exec/scan/vscanner.cpp +++ b/be/src/vec/exec/scan/vscanner.cpp @@ -20,6 +20,7 @@ #include #include "common/config.h" +#include "exec/exec_node.h" #include "pipeline/exec/scan_operator.h" #include "runtime/descriptors.h" #include "util/runtime_profile.h" @@ -68,6 +69,19 @@ Status VScanner::prepare(RuntimeState* state, const VExprContextSPtrs& conjuncts } } + const auto& intermediate_projections = + _parent ? _parent->_intermediate_projections : _local_state->_intermediate_projections; + if (!intermediate_projections.empty()) { + _intermediate_projections.resize(intermediate_projections.size()); + for (int i = 0; i < intermediate_projections.size(); i++) { + _intermediate_projections[i].resize(intermediate_projections[i].size()); + for (int j = 0; j < intermediate_projections[i].size(); j++) { + RETURN_IF_ERROR(intermediate_projections[i][j]->clone( + state, _intermediate_projections[i][j])); + } + } + } + return Status::OK(); } @@ -169,42 +183,55 @@ Status VScanner::_filter_output_block(Block* block) { } Status VScanner::_do_projections(vectorized::Block* origin_block, vectorized::Block* output_block) { - auto projection_timer = _parent ? _parent->_projection_timer : _local_state->_projection_timer; - auto exec_timer = _parent ? _parent->_exec_timer : _local_state->_exec_timer; + auto& projection_timer = _parent ? _parent->_projection_timer : _local_state->_projection_timer; + auto& exec_timer = _parent ? _parent->_exec_timer : _local_state->_exec_timer; SCOPED_TIMER(exec_timer); SCOPED_TIMER(projection_timer); + const size_t rows = origin_block->rows(); + if (rows == 0) { + return Status::OK(); + } + vectorized::Block input_block = *origin_block; + + std::vector result_column_ids; + for (auto& projections : _intermediate_projections) { + result_column_ids.resize(projections.size()); + for (int i = 0; i < projections.size(); i++) { + RETURN_IF_ERROR(projections[i]->execute(&input_block, &result_column_ids[i])); + } + input_block.shuffle_columns(result_column_ids); + } + + DCHECK_EQ(rows, input_block.rows()); MutableBlock mutable_block = VectorizedUtils::build_mutable_mem_reuse_block(output_block, *_output_row_descriptor); - auto rows = origin_block->rows(); - if (rows != 0) { - auto& mutable_columns = mutable_block.mutable_columns(); + auto& mutable_columns = mutable_block.mutable_columns(); - if (mutable_columns.size() != _projections.size()) { - return Status::InternalError( - "Logical error in scanner, output of projections {} mismatches with " - "scanner output {}", - _projections.size(), mutable_columns.size()); - } + if (mutable_columns.size() != _projections.size()) { + return Status::InternalError( + "Logical error in scanner, output of projections {} mismatches with " + "scanner output {}", + _projections.size(), mutable_columns.size()); + } - for (int i = 0; i < mutable_columns.size(); ++i) { - auto result_column_id = -1; - RETURN_IF_ERROR(_projections[i]->execute(origin_block, &result_column_id)); - auto column_ptr = origin_block->get_by_position(result_column_id) - .column->convert_to_full_column_if_const(); - //TODO: this is a quick fix, we need a new function like "change_to_nullable" to do it - if (mutable_columns[i]->is_nullable() xor column_ptr->is_nullable()) { - DCHECK(mutable_columns[i]->is_nullable() && !column_ptr->is_nullable()); - reinterpret_cast(mutable_columns[i].get()) - ->insert_range_from_not_nullable(*column_ptr, 0, rows); - } else { - mutable_columns[i]->insert_range_from(*column_ptr, 0, rows); - } + for (int i = 0; i < mutable_columns.size(); ++i) { + auto result_column_id = -1; + RETURN_IF_ERROR(_projections[i]->execute(&input_block, &result_column_id)); + auto column_ptr = input_block.get_by_position(result_column_id) + .column->convert_to_full_column_if_const(); + //TODO: this is a quick fix, we need a new function like "change_to_nullable" to do it + if (mutable_columns[i]->is_nullable() xor column_ptr->is_nullable()) { + DCHECK(mutable_columns[i]->is_nullable() && !column_ptr->is_nullable()); + reinterpret_cast(mutable_columns[i].get()) + ->insert_range_from_not_nullable(*column_ptr, 0, rows); + } else { + mutable_columns[i]->insert_range_from(*column_ptr, 0, rows); } - DCHECK(mutable_block.rows() == rows); - output_block->set_columns(std::move(mutable_columns)); } + DCHECK(mutable_block.rows() == rows); + output_block->set_columns(std::move(mutable_columns)); return Status::OK(); } diff --git a/be/src/vec/exec/scan/vscanner.h b/be/src/vec/exec/scan/vscanner.h index d264e99fc78306..8c205aaff5d4db 100644 --- a/be/src/vec/exec/scan/vscanner.h +++ b/be/src/vec/exec/scan/vscanner.h @@ -195,6 +195,8 @@ class VScanner { // It includes predicate in SQL and runtime filters. VExprContextSPtrs _conjuncts; VExprContextSPtrs _projections; + // Used in common subexpression elimination to compute intermediate results. + std::vector _intermediate_projections; vectorized::Block _origin_block; VExprContextSPtrs _common_expr_ctxs_push_down; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index f47b6826ebe2ce..205cfbd25309d8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -1837,15 +1837,38 @@ public PlanFragment visitPhysicalProject(PhysicalProject project registerRewrittenSlot(project, (OlapScanNode) inputFragment.getPlanRoot()); } - List projectionExprs = project.getProjects() - .stream() - .map(e -> ExpressionTranslator.translate(e, context)) - .collect(Collectors.toList()); - List slots = project.getProjects() - .stream() - .map(NamedExpression::toSlot) - .collect(Collectors.toList()); - + PlanNode inputPlanNode = inputFragment.getPlanRoot(); + List projectionExprs = null; + List allProjectionExprs = Lists.newArrayList(); + List slots = null; + if (project.hasMultiLayerProjection()) { + int layerCount = project.getMultiLayerProjects().size(); + for (int i = 0; i < layerCount; i++) { + List layer = project.getMultiLayerProjects().get(i); + projectionExprs = layer.stream() + .map(e -> ExpressionTranslator.translate(e, context)) + .collect(Collectors.toList()); + slots = layer.stream() + .map(NamedExpression::toSlot) + .collect(Collectors.toList()); + if (i < layerCount - 1) { + inputPlanNode.addIntermediateProjectList(projectionExprs); + TupleDescriptor projectionTuple = generateTupleDesc(slots, null, context); + inputPlanNode.addIntermediateOutputTupleDescList(projectionTuple); + } + allProjectionExprs.addAll(projectionExprs); + } + } else { + projectionExprs = project.getProjects() + .stream() + .map(e -> ExpressionTranslator.translate(e, context)) + .collect(Collectors.toList()); + slots = project.getProjects() + .stream() + .map(NamedExpression::toSlot) + .collect(Collectors.toList()); + allProjectionExprs.addAll(projectionExprs); + } // process multicast sink if (inputFragment instanceof MultiCastPlanFragment) { MultiCastDataSink multiCastDataSink = (MultiCastDataSink) inputFragment.getSink(); @@ -1857,10 +1880,9 @@ public PlanFragment visitPhysicalProject(PhysicalProject project return inputFragment; } - PlanNode inputPlanNode = inputFragment.getPlanRoot(); List conjuncts = inputPlanNode.getConjuncts(); Set requiredSlotIdSet = Sets.newHashSet(); - for (Expr expr : projectionExprs) { + for (Expr expr : allProjectionExprs) { Expr.extractSlots(expr, requiredSlotIdSet); } Set requiredByProjectSlotIdSet = Sets.newHashSet(requiredSlotIdSet); @@ -1895,8 +1917,10 @@ public PlanFragment visitPhysicalProject(PhysicalProject project requiredSlotIdSet.forEach(e -> requiredExprIds.add(context.findExprId(e))); for (ExprId exprId : requiredExprIds) { SlotId slotId = ((HashJoinNode) joinNode).getHashOutputExprSlotIdMap().get(exprId); - Preconditions.checkState(slotId != null); - ((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId); + // Preconditions.checkState(slotId != null); + if (slotId != null) { + ((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId); + } } } return inputFragment; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java new file mode 100644 index 00000000000000..5abc5f6f60ffa2 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.processor.post; + +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; + +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + +/** + * collect common expr + */ +public class CommonSubExpressionCollector extends ExpressionVisitor { + public final Map> commonExprByDepth = new HashMap<>(); + private final Map> expressionsByDepth = new HashMap<>(); + + @Override + public Integer visit(Expression expr, Void context) { + if (expr.children().isEmpty()) { + return 0; + } + return collectCommonExpressionByDepth(expr.children().stream().map(child -> + child.accept(this, context)).reduce(Math::max).map(m -> m + 1).orElse(1), expr); + } + + private int collectCommonExpressionByDepth(int depth, Expression expr) { + Set expressions = getExpressionsFromDepthMap(depth, expressionsByDepth); + if (expressions.contains(expr)) { + Set commonExpression = getExpressionsFromDepthMap(depth, commonExprByDepth); + commonExpression.add(expr); + } + expressions.add(expr); + return depth; + } + + public static Set getExpressionsFromDepthMap( + int depth, Map> depthMap) { + depthMap.putIfAbsent(depth, new LinkedHashSet<>()); + return depthMap.get(depth); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java new file mode 100644 index 00000000000000..dfaf2de757e45e --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.processor.post; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; + +import com.google.common.collect.Lists; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Select A+B, (A+B+C)*2, (A+B+C)*3, D from T + * + * before optimize + * projection: + * Proj: A+B, (A+B+C)*2, (A+B+C)*3, D + * + * --- + * after optimize: + * Projection: List < List < Expression > > + * A+B, C, D + * A+B, A+B+C, D + * A+B, (A+B+C)*2, (A+B+C)*3, D + */ +public class CommonSubExpressionOpt extends PlanPostProcessor { + @Override + public PhysicalProject visitPhysicalProject(PhysicalProject project, CascadesContext ctx) { + + List> multiLayers = computeMultiLayerProjections( + project.getInputSlots(), project.getProjects()); + project.setMultiLayerProjects(multiLayers); + return project; + } + + private List> computeMultiLayerProjections( + Set inputSlots, List projects) { + + List> multiLayers = Lists.newArrayList(); + CommonSubExpressionCollector collector = new CommonSubExpressionCollector(); + for (Expression expr : projects) { + expr.accept(collector, null); + } + Map commonExprToAliasMap = new HashMap<>(); + collector.commonExprByDepth.values().stream().flatMap(expressions -> expressions.stream()) + .forEach(expression -> { + if (expression instanceof Alias) { + commonExprToAliasMap.put(expression, (Alias) expression); + } else { + commonExprToAliasMap.put(expression, new Alias(expression)); + } + }); + Map aliasMap = new HashMap<>(); + if (!collector.commonExprByDepth.isEmpty()) { + for (int i = 1; i <= collector.commonExprByDepth.size(); i++) { + List layer = Lists.newArrayList(); + layer.addAll(inputSlots); + Set exprsInDepth = CommonSubExpressionCollector + .getExpressionsFromDepthMap(i, collector.commonExprByDepth); + exprsInDepth.forEach(expr -> { + Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap); + Alias alias = new Alias(rewritten); + aliasMap.put(expr, alias); + }); + layer.addAll(aliasMap.values()); + multiLayers.add(layer); + } + // final layer + List finalLayer = Lists.newArrayList(); + projects.forEach(expr -> { + Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap); + if (rewritten instanceof Slot) { + finalLayer.add((NamedExpression) rewritten); + } else if (rewritten instanceof Alias) { + finalLayer.add(new Alias(expr.getExprId(), ((Alias) rewritten).child(), expr.getName())); + } + }); + multiLayers.add(finalLayer); + } + return multiLayers; + } + + /** + * replace sub expr by aliasMap + */ + public static class ExpressionReplacer + extends DefaultExpressionRewriter> { + public static final ExpressionReplacer INSTANCE = new ExpressionReplacer(); + + private ExpressionReplacer() { + } + + @Override + public Expression visit(Expression expr, Map replaceMap) { + if (replaceMap.containsKey(expr)) { + return replaceMap.get(expr).toSlot(); + } + return super.visit(expr, replaceMap); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java index 60c1a74445e1ff..86c8486ef45710 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java @@ -63,8 +63,9 @@ public List getProcessors() { builder.add(new MergeProjectPostProcessor()); builder.add(new RecomputeLogicalPropertiesProcessor()); builder.add(new AddOffsetIntoDistribute()); + builder.add(new CommonSubExpressionOpt()); + // DO NOT replace PLAN NODE from here builder.add(new TopNScanOpt()); - // after generate rf, DO NOT replace PLAN NODE builder.add(new FragmentProcessor()); if (!cascadesContext.getConnectContext().getSessionVariable().getRuntimeFilterMode() .toUpperCase().equals(TRuntimeFilterMode.OFF.name())) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java index af7bb950a97d96..e8472b6af23a6e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.processor.post.RuntimeFilterGenerator; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -41,6 +42,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import java.util.List; import java.util.Objects; @@ -52,6 +54,12 @@ public class PhysicalProject extends PhysicalUnary implements Project { private final List projects; + //multiLayerProjects is used to extract common expressions + // projects: (A+B) * 2, (A+B) * 3 + // multiLayerProjects: + // L1: A+B as x + // L2: x*2, x*3 + private List> multiLayerProjects = Lists.newArrayList(); public PhysicalProject(List projects, LogicalProperties logicalProperties, CHILD_TYPE child) { this(projects, Optional.empty(), logicalProperties, child); @@ -227,7 +235,12 @@ public boolean pushDownRuntimeFilter(CascadesContext context, IdGenerator computeOutput() { - return projects.stream() + List output = projects; + if (! multiLayerProjects.isEmpty()) { + int layers = multiLayerProjects.size(); + output = multiLayerProjects.get(layers - 1); + } + return output.stream() .map(NamedExpression::toSlot) .collect(ImmutableList.toImmutableList()); } @@ -237,4 +250,70 @@ public PhysicalProject resetLogicalProperties() { return new PhysicalProject<>(projects, groupExpression, null, physicalProperties, statistics, child()); } + + /** + * extract common expr, set multi layer projects + */ + public void computeMultiLayerProjectsForCommonExpress() { + // hard code: select (s_suppkey + s_nationkey), 1+(s_suppkey + s_nationkey), s_name from supplier; + if (projects.size() == 3) { + if (projects.get(2) instanceof SlotReference) { + SlotReference sName = (SlotReference) projects.get(2); + if (sName.getName().equals("s_name")) { + Alias a1 = (Alias) projects.get(0); // (s_suppkey + s_nationkey) + Alias a2 = (Alias) projects.get(1); // 1+(s_suppkey + s_nationkey) + // L1: (s_suppkey + s_nationkey) as x, s_name + multiLayerProjects.add(Lists.newArrayList(projects.get(0), projects.get(2))); + List l2 = Lists.newArrayList(); + l2.add(a1.toSlot()); + Alias a3 = new Alias(a2.getExprId(), new Add(a1.toSlot(), a2.child().child(1)), a2.getName()); + l2.add(a3); + l2.add(sName); + // L2: x, (1+x) as y, s_name + multiLayerProjects.add(l2); + } + } + } + // hard code: + // select (s_suppkey + n_regionkey) + 1 as x, (s_suppkey + n_regionkey) + 2 as y + // from supplier join nation on s_nationkey=n_nationkey + // projects: x, y + // multi L1: s_suppkey, n_regionkey, (s_suppkey + n_regionkey) as z + // L2: z +1 as x, z+2 as y + if (projects.size() == 2 && projects.get(0) instanceof Alias && projects.get(1) instanceof Alias + && ((Alias) projects.get(0)).getName().equals("x") + && ((Alias) projects.get(1)).getName().equals("y")) { + Alias a0 = (Alias) projects.get(0); + Alias a1 = (Alias) projects.get(1); + Add common = (Add) a0.child().child(0); // s_suppkey + n_regionkey + List l1 = Lists.newArrayList(); + common.children().stream().forEach(child -> l1.add((SlotReference) child)); + Alias aliasOfCommon = new Alias(common); + l1.add(aliasOfCommon); + multiLayerProjects.add(l1); + Add add1 = new Add(common, a0.child().child(0).child(1)); + Alias aliasOfAdd1 = new Alias(a0.getExprId(), add1, a0.getName()); + Add add2 = new Add(common, a1.child().child(0).child(1)); + Alias aliasOfAdd2 = new Alias(a1.getExprId(), add2, a1.getName()); + List l2 = Lists.newArrayList(aliasOfAdd1, aliasOfAdd2); + multiLayerProjects.add(l2); + } + } + + public boolean hasMultiLayerProjection() { + return !multiLayerProjects.isEmpty(); + } + + public List> getMultiLayerProjects() { + return multiLayerProjects; + } + + public void setMultiLayerProjects(List> multiLayers) { + this.multiLayerProjects = multiLayers; + } + + @Override + public List getOutput() { + return computeOutput(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java index b404bc4ad3545c..8cc18a527a86d6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java @@ -59,6 +59,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; /** * Each PlanNode represents a single relational operator @@ -155,6 +156,8 @@ public abstract class PlanNode extends TreeNode implements PlanStats { protected int nereidsId = -1; private List> childrenDistributeExprLists = new ArrayList<>(); + private List intermediateOutputTupleDescList = Lists.newArrayList(); + private List> intermediateProjectListList = Lists.newArrayList(); protected PlanNode(PlanNodeId id, ArrayList tupleIds, String planNodeName, StatisticalType statisticalType) { @@ -536,10 +539,20 @@ protected final String getExplainString(String rootPrefix, String prefix, TExpla expBuilder.append(detailPrefix + "limit: " + limit + "\n"); } if (!CollectionUtils.isEmpty(projectList)) { - expBuilder.append(detailPrefix).append("projections: ").append(getExplainString(projectList)).append("\n"); - expBuilder.append(detailPrefix).append("project output tuple id: ") + expBuilder.append(detailPrefix).append("final projections: ") + .append(getExplainString(projectList)).append("\n"); + expBuilder.append(detailPrefix).append("final project output tuple id: ") .append(outputTupleDesc.getId().asInt()).append("\n"); } + if (!intermediateProjectListList.isEmpty()) { + int layers = intermediateProjectListList.size(); + for (int i = layers - 1; i >= 0; i--) { + expBuilder.append(detailPrefix).append("intermediate projections: ") + .append(getExplainString(intermediateProjectListList.get(i))).append("\n"); + expBuilder.append(detailPrefix).append("intermediate tuple id: ") + .append(intermediateOutputTupleDescList.get(i).getId().asInt()).append("\n"); + } + } if (!CollectionUtils.isEmpty(childrenDistributeExprLists)) { for (List distributeExprList : childrenDistributeExprLists) { expBuilder.append(detailPrefix).append("distribute expr lists: ") @@ -660,6 +673,19 @@ private void treeToThriftHelper(TPlan container) { } } } + + if (!intermediateOutputTupleDescList.isEmpty()) { + intermediateOutputTupleDescList + .forEach( + tupleDescriptor -> msg.addToIntermediateOutputTupleIdList(tupleDescriptor.getId().asInt())); + } + + if (!intermediateProjectListList.isEmpty()) { + intermediateProjectListList.forEach( + projectList -> msg.addToIntermediateProjectionsList( + projectList.stream().map(expr -> expr.treeToThrift()).collect(Collectors.toList()))); + } + if (this instanceof ExchangeNode) { msg.num_children = 0; return; @@ -1221,4 +1247,12 @@ public boolean pushDownAggNoGroupingCheckCol(FunctionCallExpr aggExpr, Column co public void setNereidsId(int nereidsId) { this.nereidsId = nereidsId; } + + public void addIntermediateOutputTupleDescList(TupleDescriptor tupleDescriptor) { + intermediateOutputTupleDescList.add(tupleDescriptor); + } + + public void addIntermediateProjectList(List exprs) { + intermediateProjectListList.add(exprs); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java index 0f464ba2946b7d..c342d858fe1fb5 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java @@ -74,6 +74,7 @@ public static void teardown() { public void test() throws Exception { ConnectContext ctx = UtFrameUtils.createDefaultCtx(); ctx.getSessionVariable().setEnableNereidsPlanner(false); + ctx.getSessionVariable().enableFallbackToOriginalPlanner = true; ctx.getSessionVariable().setEnableFoldConstantByBe(false); // create database db1 createDatabase(ctx, "create database db1;"); @@ -113,8 +114,8 @@ public void test() throws Exception { Assert.assertTrue(constExprLists.get(0).get(0) instanceof FunctionCallExpr); queryStr = "select db1.id_masking(k1) from db1.tbl1"; - Assert.assertTrue( - dorisAssert.query(queryStr).explainQuery().contains("concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); // create alias function with cast // cast any type to decimal with specific precision and scale @@ -142,14 +143,16 @@ public void test() throws Exception { queryStr = "select db1.decimal(k3, 4, 1) from db1.tbl1;"; if (Config.enable_decimal_conversion) { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMALV3(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMALV3(4, 1))")); } else { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMAL(4, 1))")); } // cast any type to varchar with fixed length - createFuncStr = "create alias function db1.varchar(all) with parameter(text) as " - + "cast(text as varchar(65533));"; + createFuncStr = "create alias function db1.varchar(all, int) with parameter(text, length) as " + + "cast(text as varchar(length));"; createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx); Env.getCurrentEnv().createFunction(createFunctionStmt); @@ -172,7 +175,8 @@ public void test() throws Exception { Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral); queryStr = "select db1.varchar(k1, 4) from db1.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS VARCHAR(65533))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS VARCHAR(65533))")); // cast any type to char with fixed length createFuncStr = "create alias function db1.to_char(all, int) with parameter(text, length) as " @@ -199,7 +203,8 @@ public void test() throws Exception { Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral); queryStr = "select db1.to_char(k1, 4) from db1.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS CHARACTER")); } @Test @@ -235,8 +240,8 @@ public void testCreateGlobalFunction() throws Exception { testFunctionQuery(ctx, queryStr, false); queryStr = "select id_masking(k1) from db2.tbl1"; - Assert.assertTrue( - dorisAssert.query(queryStr).explainQuery().contains("concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); // 4. create alias function with cast // cast any type to decimal with specific precision and scale @@ -253,9 +258,11 @@ public void testCreateGlobalFunction() throws Exception { queryStr = "select decimal(k3, 4, 1) from db2.tbl1;"; if (Config.enable_decimal_conversion) { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMALV3(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMALV3(4, 1))")); } else { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMAL(4, 1))")); } // 5. cast any type to varchar with fixed length @@ -271,7 +278,8 @@ public void testCreateGlobalFunction() throws Exception { testFunctionQuery(ctx, queryStr, true); queryStr = "select varchar(k1, 4) from db2.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS VARCHAR(65533))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS VARCHAR(65533))")); // 6. cast any type to char with fixed length createFuncStr = "create global alias function db2.to_char(all, int) with parameter(text, length) as " @@ -286,7 +294,8 @@ public void testCreateGlobalFunction() throws Exception { testFunctionQuery(ctx, queryStr, true); queryStr = "select to_char(k1, 4) from db2.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER)")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS CHARACTER)")); } private void testFunctionQuery(ConnectContext ctx, String queryStr, Boolean isStringLiteral) throws Exception { @@ -320,4 +329,8 @@ private void createDatabase(ConnectContext ctx, String createDbStmtStr) throws E Env.getCurrentEnv().createDb(createDbStmt); System.out.println(Env.getCurrentInternalCatalog().getDbNames()); } + + private boolean containsIgnoreCase(String str, String sub) { + return str.toLowerCase().contains(sub.toLowerCase()); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java new file mode 100644 index 00000000000000..56b67e087d59ab --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.postprocess; + +import org.apache.doris.nereids.processor.post.CommonSubExpressionCollector; +import org.apache.doris.nereids.processor.post.CommonSubExpressionOpt; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.types.IntegerType; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class CommonSubExpressionTest extends ExpressionRewriteTestHelper { + @Test + public void testExtractCommonExpr() { + List exprs = parseProjections("a+b, a+b+1, abs(a+b+1), a"); + CommonSubExpressionCollector collector = + new CommonSubExpressionCollector(); + exprs.forEach(expr -> collector.visit(expr, null)); + System.out.println(collector.commonExprByDepth); + Assertions.assertEquals(2, collector.commonExprByDepth.size()); + List l1 = collector.commonExprByDepth.get(Integer.valueOf(1)) + .stream().collect(Collectors.toList()); + List l2 = collector.commonExprByDepth.get(Integer.valueOf(2)) + .stream().collect(Collectors.toList()); + Assertions.assertEquals(1, l1.size()); + assertExpression(l1.get(0), "a+b"); + Assertions.assertEquals(1, l2.size()); + assertExpression(l2.get(0), "a+b+1"); + } + + @Test + public void testMultiLayers() throws Exception { + List exprs = parseProjections("a, a+b, a+b+1, abs(a+b+1), a"); + Set inputSlots = exprs.get(0).getInputSlots(); + CommonSubExpressionOpt opt = new CommonSubExpressionOpt(); + Method computeMultLayerProjectionsMethod = CommonSubExpressionOpt.class + .getDeclaredMethod("computeMultiLayerProjections", Set.class, List.class); + computeMultLayerProjectionsMethod.setAccessible(true); + List> multiLayers = (List>) computeMultLayerProjectionsMethod + .invoke(opt, inputSlots, exprs); + System.out.println(multiLayers); + Assertions.assertEquals(3, multiLayers.size()); + List l0 = multiLayers.get(0); + Assertions.assertEquals(2, l0.size()); + Assertions.assertTrue(l0.contains(ExprParser.INSTANCE.parseExpression("a"))); + Assertions.assertTrue(l0.get(1) instanceof Alias); + assertExpression(l0.get(1).child(0), "a+b"); + Assertions.assertEquals(multiLayers.get(1).size(), 3); + Assertions.assertEquals(multiLayers.get(2).size(), 5); + List l2 = multiLayers.get(2); + for (int i = 0; i < 5; i++) { + Assertions.assertEquals(exprs.get(i).getExprId().asInt(), l2.get(i).getExprId().asInt()); + } + + } + + private void assertExpression(Expression expr, String str) { + Assertions.assertEquals(ExprParser.INSTANCE.parseExpression(str), expr); + } + + private List parseProjections(String exprList) { + List result = new ArrayList<>(); + String[] exprArray = exprList.split(","); + for (String item : exprArray) { + Expression expr = ExprParser.INSTANCE.parseExpression(item); + if (expr instanceof NamedExpression) { + result.add((NamedExpression) expr); + } else { + result.add(new Alias(expr)); + } + } + return result; + } + + public static class ExprParser { + public static ExprParser INSTANCE = new ExprParser(); + HashMap slotMap = new HashMap<>(); + + public Expression parseExpression(String str) { + Expression expr = PARSER.parseExpression(str); + return expr.accept(DataTypeAssignor.INSTANCE, slotMap); + } + } + + public static class DataTypeAssignor extends DefaultExpressionRewriter> { + public static DataTypeAssignor INSTANCE = new DataTypeAssignor(); + + @Override + public Expression visitSlot(Slot slot, Map slotMap) { + SlotReference exitsSlot = slotMap.get(slot.getName()); + if (exitsSlot != null) { + return exitsSlot; + } else { + SlotReference slotReference = new SlotReference(slot.getName(), IntegerType.INSTANCE); + slotMap.put(slot.getName(), slotReference); + return slotReference; + } + } + } + +} diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift index 2fadcdae538795..d88ab993363352 100644 --- a/gensrc/thrift/PlanNodes.thrift +++ b/gensrc/thrift/PlanNodes.thrift @@ -1294,10 +1294,13 @@ struct TPlanNode { 49: optional i64 push_down_count 50: optional list> distribute_expr_lists - + // projections is final projections, which means projecting into results and materializing them into the output block. 101: optional list projections 102: optional Types.TTupleId output_tuple_id 103: optional TPartitionSortNode partition_sort_node + // Intermediate projections will not materialize into the output block. + 104: optional list> intermediate_projections_list + 105: optional list intermediate_output_tuple_id_list } // A flattened representation of a tree of PlanNodes, obtained by depth-first diff --git a/regression-test/data/tpch_sf0.1_p1/sql/cse.out b/regression-test/data/tpch_sf0.1_p1/sql/cse.out new file mode 100644 index 00000000000000..454fe1083b511e --- /dev/null +++ b/regression-test/data/tpch_sf0.1_p1/sql/cse.out @@ -0,0 +1,31 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !cse -- +1 1 3 4 +2 0 3 4 +3 1 5 6 +4 0 5 6 +5 4 10 11 +6 0 7 8 +7 3 11 12 +8 1 10 11 +9 4 14 15 +10 1 12 13 + +-- !cse_2 -- +17 1 18 19 19 +5 2 7 8 8 +1 3 4 5 5 +15 4 19 20 20 +11 5 16 17 17 +14 6 20 21 21 +23 7 30 31 31 +17 8 25 26 26 +10 9 19 20 20 +24 10 34 35 35 + +-- !cse_3 -- +12093 13093 14093 15093 + +-- !cse_4 -- +12093 13093 14093 15093 + diff --git a/regression-test/suites/tpch_sf0.1_p1/sql/cse.sql b/regression-test/suites/tpch_sf0.1_p1/sql/cse.sql new file mode 100644 index 00000000000000..a7885eb9ce349a --- /dev/null +++ b/regression-test/suites/tpch_sf0.1_p1/sql/cse.sql @@ -0,0 +1,6 @@ +select s_suppkey,n_regionkey,(s_suppkey + n_regionkey) + 1 as x, (s_suppkey + n_regionkey) + 2 as y +from supplier join nation on s_nationkey=n_nationkey order by s_suppkey , n_regionkey limit 10 ; +select s_nationkey,s_suppkey ,(s_nationkey + s_suppkey), (s_nationkey + s_suppkey) + 1, abs((s_nationkey + s_suppkey) + 1) +from supplier order by s_suppkey , s_suppkey limit 10 ; +select sum(s_nationkey),sum(s_nationkey +1 ) ,sum(s_nationkey +2 ) , sum(s_nationkey + 3 ) from supplier ; +select sum(s_nationkey),sum(s_nationkey) + count(1) ,sum(s_nationkey) + 2 * count(1) , sum(s_nationkey) + 3 * count(1) from supplier ; \ No newline at end of file