Skip to content

Commit

Permalink
Support expression calculation in range partitioning
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Nov 18, 2022
1 parent a78c749 commit ba4f6fe
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 78 deletions.
165 changes: 88 additions & 77 deletions utils/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,91 @@ bool isTypeMatched(const substrait::Type & substrait_type, const DataTypePtr & c
return parsed_ch_type->equals(*ch_type);
}

/// Records the extension functions declared by a substrait plan.
///
/// For every declaration that carries an extension function, an entry is emplaced into
/// `function_mapping`: the key is the function anchor rendered as a decimal string, the value is
/// the function name. Declarations without an extension function are skipped.
void SerializedPlanParser::parseExtensions(
    const ::google::protobuf::RepeatedPtrField<substrait::extensions::SimpleExtensionDeclaration> & extensions)
{
    for (const auto & declaration : extensions)
    {
        if (!declaration.has_extension_function())
            continue;

        const auto & function = declaration.extension_function();
        this->function_mapping.emplace(std::to_string(function.function_anchor()), function.name());
    }
}

/// Builds an ActionsDAG that evaluates a list of substrait projection expressions.
///
/// The DAG is created from the columns of `header`. Each expression contributes one output column:
///  - a field selection is resolved by position against `read_schema` and then looked up by name
///    in the DAG index;
///  - a scalar function is added through parseFunction() (skipped if it reports an empty name);
///  - cast / if-then / literal expressions are added through parseArgument() and registered in
///    the DAG index.
/// When a result name has already been emitted, the column is aliased to the name returned by
/// getUniqueName() so that the projected columns stay distinct. Finally the DAG is projected onto
/// the collected columns, in expression order.
///
/// Throws BAD_ARGUMENTS when an expression kind is not supported, or when a selected field is not
/// present in the DAG index.
std::shared_ptr<DB::ActionsDAG> SerializedPlanParser::expressionsToActionsDAG(
    const ::google::protobuf::RepeatedPtrField<substrait::Expression> & expressions,
    const DB::Block & header,
    const DB::Block & read_schema)
{
    auto actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(header));
    NamesWithAliases required_columns;
    std::set<String> distinct_columns;

    /// Appends `column_name` to the projection, aliasing it when the same name was already emitted.
    auto add_required_column = [&](const String & column_name)
    {
        if (distinct_columns.contains(column_name))
        {
            auto unique_name = getUniqueName(column_name);
            required_columns.emplace_back(NameWithAlias(column_name, unique_name));
            distinct_columns.emplace(unique_name);
        }
        else
        {
            required_columns.emplace_back(NameWithAlias(column_name, column_name));
            distinct_columns.emplace(column_name);
        }
    };

    for (const auto & expr : expressions)
    {
        if (expr.has_selection())
        {
            auto position = expr.selection().direct_reference().struct_field().field();
            const auto & selected_name = read_schema.getByPosition(position).name;
            const ActionsDAG::Node * field = actions_dag->tryFindInIndex(selected_name);
            /// tryFindInIndex() signals a miss with nullptr; fail loudly instead of dereferencing it.
            if (!field)
                throw Exception(
                    ErrorCodes::BAD_ARGUMENTS, "column {} at position {} is not found in the actions dag.", selected_name, position);
            add_required_column(field->result_name);
        }
        else if (expr.has_scalar_function())
        {
            std::string name;
            std::vector<String> useless;
            actions_dag = parseFunction(header, expr, name, useless, actions_dag, true);
            if (!name.empty())
                add_required_column(name);
        }
        else if (expr.has_cast() || expr.has_if_then() || expr.has_literal())
        {
            const auto * node = parseArgument(actions_dag, expr);
            actions_dag->addOrReplaceInIndex(*node);
            add_required_column(node->result_name);
        }
        else
        {
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "unsupported projection type {}.", magic_enum::enum_name(expr.rex_type_case()));
        }
    }
    actions_dag->project(required_columns);
    return actions_dag;
}

/// TODO: This function needs to be improved for Decimal/Array/Map/Tuple types.
std::string getCastFunction(const substrait::Type & type)
{
Expand Down Expand Up @@ -496,17 +581,7 @@ QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan)
pb_util::MessageToJsonString(*plan, &json, options);
LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "substrait plan:{}", json);
}
if (plan->extensions_size() > 0)
{
for (const auto & extension : plan->extensions())
{
if (extension.has_extension_function())
{
this->function_mapping.emplace(
std::to_string(extension.extension_function().function_anchor()), extension.extension_function().name());
}
}
}
parseExtensions(plan->extensions());
if (plan->relations_size() == 1)
{
auto root_rel = plan->relations().at(0);
Expand Down Expand Up @@ -599,71 +674,7 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel)
read_schema = query_plan->getCurrentDataStream().header;
}
const auto & expressions = project.expressions();
auto actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(query_plan->getCurrentDataStream().header));
NamesWithAliases required_columns;
std::set<String> distinct_columns;

for (const auto & expr : expressions)
{
if (expr.has_selection())
{
auto position = expr.selection().direct_reference().struct_field().field();
const ActionsDAG::Node * field = actions_dag->tryFindInIndex(read_schema.getByPosition(position).name);
if (distinct_columns.contains(field->result_name))
{
auto unique_name = getUniqueName(field->result_name);
required_columns.emplace_back(NameWithAlias(field->result_name, unique_name));
distinct_columns.emplace(unique_name);
}
else
{
required_columns.emplace_back(NameWithAlias(field->result_name, field->result_name));
distinct_columns.emplace(field->result_name);
}
}
else if (expr.has_scalar_function())
{
std::string name;
std::vector<String> useless;
actions_dag = parseFunction(query_plan->getCurrentDataStream().header, expr, name, useless, actions_dag, true);
if (!name.empty())
{
if (distinct_columns.contains(name))
{
auto unique_name = getUniqueName(name);
required_columns.emplace_back(NameWithAlias(name, unique_name));
distinct_columns.emplace(unique_name);
}
else
{
required_columns.emplace_back(NameWithAlias(name, name));
distinct_columns.emplace(name);
}
}
}
else if (expr.has_cast() || expr.has_if_then() || expr.has_literal())
{
const auto * node = parseArgument(actions_dag, expr);
actions_dag->addOrReplaceInIndex(*node);
if (distinct_columns.contains(node->result_name))
{
auto unique_name = getUniqueName(node->result_name);
required_columns.emplace_back(NameWithAlias(node->result_name, unique_name));
distinct_columns.emplace(unique_name);
}
else
{
required_columns.emplace_back(NameWithAlias(node->result_name, node->result_name));
distinct_columns.emplace(node->result_name);
}
}
else
{
throw Exception(
ErrorCodes::BAD_ARGUMENTS, "unsupported projection type {}.", magic_enum::enum_name(expr.rex_type_case()));
}
}
actions_dag->project(required_columns);
auto actions_dag = expressionsToActionsDAG(expressions, query_plan->getCurrentDataStream().header, read_schema);
auto expression_step = std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(), actions_dag);
expression_step->setStepDescription("Project");
query_plan->addStep(std::move(expression_step));
Expand Down Expand Up @@ -757,7 +768,7 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel)
break;
}
default:
throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support relation type");
throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support relation type: {}.\n{}", rel.rel_type_case(), rel.DebugString());
}
return query_plan;
}
Expand Down
6 changes: 6 additions & 0 deletions utils/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ class SerializedPlanParser

void addInputIter(jobject iter) { input_iters.emplace_back(iter); }

void parseExtensions(const ::google::protobuf::RepeatedPtrField<substrait::extensions::SimpleExtensionDeclaration> & extensions);
std::shared_ptr<DB::ActionsDAG> expressionsToActionsDAG(
const ::google::protobuf::RepeatedPtrField<substrait::Expression> & expressions,
const DB::Block & header,
const DB::Block & read_schema);

static ContextMutablePtr global_context;
static Context::ConfigurationPtr config;
static SharedContextHolder shared_context;
Expand Down
53 changes: 52 additions & 1 deletion utils/local-engine/Shuffle/SelectorBuilder.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
#include "SelectorBuilder.h"
#include <memory>
#include <mutex>
#include <Poco/Base64Decoder.h>
#include <Poco/JSON/JSON.h>
#include <Poco/JSON/Parser.h>
#include <Poco/MemoryStream.h>
#include <Poco/StreamCopier.h>
#include <Common/Exception.h>
#include <Parser/SerializedPlanParser.h>
#include <Functions/FunctionFactory.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h>

namespace DB
{
Expand Down Expand Up @@ -68,6 +76,17 @@ RangeSelectorBuilder::RangeSelectorBuilder(const std::string & option)
{
Poco::JSON::Parser parser;
auto info = parser.parse(option).extract<Poco::JSON::Object::Ptr>();
if (info->has("projection_plan"))
{
// For convenience, we use a serialized protobuf to store the projection plan.
String encoded_str = info->get("projection_plan").convert<std::string>();
Poco::MemoryInputStream istr(encoded_str.data(), encoded_str.size());
Poco::Base64Decoder decoder(istr);
String decoded_str;
Poco::StreamCopier::copyToString(decoder, decoded_str);
projection_plan_pb = std::make_unique<substrait::Plan>();
projection_plan_pb->ParseFromString(decoded_str);
}
auto ordering_infos = info->get("ordering").extract<Poco::JSON::Array::Ptr>();
initSortInformation(ordering_infos);
initRangeBlock(info->get("range_bounds").extract<Poco::JSON::Array::Ptr>());
Expand All @@ -76,7 +95,26 @@ RangeSelectorBuilder::RangeSelectorBuilder(const std::string & option)
/// Computes the partition selector for the rows of `block`.
///
/// Without a projection plan the block is handed to computePartitionIdByBinarySearch() as is.
/// With a projection plan, the projection actions are executed on a copy of the block and the
/// resulting columns are appended after the original columns; the binary search then runs on that
/// widened block. The partition search is performed exactly once per call.
std::vector<DB::IColumn::ColumnIndex> RangeSelectorBuilder::build(DB::Block & block)
{
    std::vector<DB::IColumn::ColumnIndex> result;
    if (projection_plan_pb)
    {
        // The expression actions are built on first use from the first block seen.
        if (!has_init_actions_dag) [[unlikely]]
            initActionsDAG(block);

        // Execute on a copy so that the caller's block keeps its original columns.
        DB::Block copied_block = block;
        projection_expression_actions->execute(copied_block, block.rows());

        // Append the projected ordering-key columns to the original columns.
        DB::ColumnsWithTypeAndName columns = block.getColumnsWithTypeAndName();
        for (const auto & projected_col : copied_block.getColumnsWithTypeAndName())
        {
            columns.push_back(projected_col);
        }
        DB::Block projected_block(columns);
        computePartitionIdByBinarySearch(projected_block, result);
    }
    else
    {
        computePartitionIdByBinarySearch(block, result);
    }
    return result;
}

Expand Down Expand Up @@ -181,6 +219,19 @@ void RangeSelectorBuilder::initRangeBlock(Poco::JSON::Array::Ptr range_bounds)
range_bounds_block = DB::Block(columns);
}

/// Builds `projection_expression_actions` from the stored projection plan, at most once.
///
/// The whole body runs under `actions_dag_mutex`. The flag is checked again after the lock is
/// taken, so a caller that waited on the mutex while another one finished the initialization
/// returns without rebuilding anything. `block` is passed to expressionsToActionsDAG() as both
/// the header and the read schema.
void RangeSelectorBuilder::initActionsDAG(const DB::Block & block)
{
    std::lock_guard lock(actions_dag_mutex);
    if (has_init_actions_dag)
        return;

    SerializedPlanParser plan_parser(local_engine::SerializedPlanParser::global_context);
    plan_parser.parseExtensions(projection_plan_pb->extensions());

    // The expressions are read from the project relation under the root of the first relation.
    const auto & project_expressions = projection_plan_pb->relations().at(0).root().input().project().expressions();
    auto dag = plan_parser.expressionsToActionsDAG(project_expressions, block, block);
    projection_expression_actions = std::make_unique<DB::ExpressionActions>(dag);

    // Set the flag last, once the actions are fully constructed.
    has_init_actions_dag = true;
}

void RangeSelectorBuilder::computePartitionIdByBinarySearch(DB::Block & block, std::vector<DB::IColumn::ColumnIndex> & selector)
{
Chunks chunks;
Expand Down
11 changes: 11 additions & 0 deletions utils/local-engine/Shuffle/SelectorBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
#include <Core/SortDescription.h>
#include <Core/Block.h>
#include <Common/BlockIterator.h>
#include <memory>
#include <vector>
#include <substrait/plan.pb.h>
#include <Interpreters/ActionsDAG.h>
#include <Interpreters/ExpressionActions.h>
namespace local_engine
{
class RoundRobinSelectorBuilder
Expand Down Expand Up @@ -47,8 +51,15 @@ class RangeSelectorBuilder
std::vector<SortFieldTypeInfo> sort_field_types;
DB::Block range_bounds_block;

// If the ordering keys contain expressions, we calculate the expressions here.
std::mutex actions_dag_mutex;
std::unique_ptr<substrait::Plan> projection_plan_pb;
std::atomic<bool> has_init_actions_dag;
std::unique_ptr<DB::ExpressionActions> projection_expression_actions;

void initSortInformation(Poco::JSON::Array::Ptr orderings);
void initRangeBlock(Poco::JSON::Array::Ptr range_bounds);
void initActionsDAG(const DB::Block & block);

void computePartitionIdByBinarySearch(DB::Block & block, std::vector<DB::IColumn::ColumnIndex> & selector);
int compareRow(
Expand Down

0 comments on commit ba4f6fe

Please sign in to comment.