Skip to content

Commit

Permalink
[CH-186] support RangePartitioning (#189)
Browse files Browse the repository at this point in the history
* WIP

* two ways to partition now

* WIP: using the actions DAG to compare is too slow

* WIP

* fixed code style

* remove debug code

* fixed code style

* fixed header include

* fixed a bug in tcp q20

* support range partition in shuffle splitter

* remove unused headers

* support expression calculation in range partitioning
  • Loading branch information
lgbo-ustc authored Nov 22, 2022
1 parent e225f47 commit e5b7449
Show file tree
Hide file tree
Showing 11 changed files with 683 additions and 156 deletions.
8 changes: 6 additions & 2 deletions utils/local-engine/Parser/CHColumnToSparkRow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Common/Exception.h>

namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int UNKNOWN_TYPE;
}
}
Expand Down Expand Up @@ -341,8 +343,10 @@ int64_t SparkRowInfo::getTotalBytes() const

std::unique_ptr<SparkRowInfo> CHColumnToSparkRow::convertCHColumnToSparkRow(const Block & block)
{
if (!block.rows() || !block.columns())
return {};
if (!block.columns())
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A block with empty columns");
}

std::unique_ptr<SparkRowInfo> spark_row_info = std::make_unique<SparkRowInfo>(block);
spark_row_info->setBufferAddress(reinterpret_cast<char *>(alloc(spark_row_info->getTotalBytes(), 64)));
Expand Down
203 changes: 126 additions & 77 deletions utils/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "SerializedPlanParser.h"
#include <memory>
#include <base/logger_useful.h>
#include <base/Decimal.h>
Expand Down Expand Up @@ -54,7 +55,11 @@

#include <Processors/QueryPlan/QueryPlan.h>
#include <Processors/QueryPlan/SortingStep.h>
#include <DataTypes/DataTypeFactory.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <base/types.h>
#include <Storages/IStorage.h>
#include <sys/select.h>
#include <Common/CHUtil.h>
#include "SerializedPlanParser.h"

Expand Down Expand Up @@ -92,6 +97,91 @@ bool isTypeMatched(const substrait::Type & substrait_type, const DataTypePtr & c
return parsed_ch_type->equals(*ch_type);
}

/// Records every extension function declared in `extensions` into `function_mapping`,
/// keyed by the decimal string form of its function anchor and mapped to the function name.
/// Declarations that do not carry an extension function are ignored.
/// `emplace` is used, so an anchor that is already present keeps its existing name.
void SerializedPlanParser::parseExtensions(
    const ::google::protobuf::RepeatedPtrField<substrait::extensions::SimpleExtensionDeclaration> & extensions)
{
    for (const auto & declaration : extensions)
    {
        if (!declaration.has_extension_function())
            continue;

        const auto & extension_function = declaration.extension_function();
        function_mapping.emplace(std::to_string(extension_function.function_anchor()), extension_function.name());
    }
}

/// Builds an ActionsDAG over the columns of `header` that evaluates `expressions` and then
/// projects the resulting columns, in expression order.
///
/// Supported expression kinds:
///  - selection: a direct field reference, resolved by position in `read_schema`;
///  - scalar function: parsed by parseFunction (skipped in the projection if it yields an empty name);
///  - cast / if-then / literal: parsed by parseArgument.
///
/// When two expressions resolve to the same column name, the later one is projected under an
/// alias obtained from getUniqueName so the output names stay distinct.
///
/// Throws Exception(BAD_ARGUMENTS) for an unsupported expression kind, or when a selected
/// column cannot be found in the DAG index.
std::shared_ptr<DB::ActionsDAG> SerializedPlanParser::expressionsToActionsDAG(
    const ::google::protobuf::RepeatedPtrField<substrait::Expression> & expressions,
    const DB::Block & header,
    const DB::Block & read_schema)
{
    auto actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(header));
    NamesWithAliases required_columns;
    std::set<String> distinct_columns;

    /// Appends `column_name` to the projection. If that name was already projected, it is
    /// aliased to a fresh unique name instead of being emitted twice under the same name.
    auto add_required_column = [&](const String & column_name)
    {
        if (distinct_columns.contains(column_name))
        {
            auto unique_name = getUniqueName(column_name);
            required_columns.emplace_back(NameWithAlias(column_name, unique_name));
            distinct_columns.emplace(unique_name);
        }
        else
        {
            required_columns.emplace_back(NameWithAlias(column_name, column_name));
            distinct_columns.emplace(column_name);
        }
    };

    for (const auto & expr : expressions)
    {
        if (expr.has_selection())
        {
            auto position = expr.selection().direct_reference().struct_field().field();
            const auto & column_name = read_schema.getByPosition(position).name;
            const ActionsDAG::Node * field = actions_dag->tryFindInIndex(column_name);
            /// tryFindInIndex may return a null pointer; fail with a clear error instead of dereferencing it.
            if (!field)
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "cannot find column {} in the actions dag index.", column_name);
            add_required_column(field->result_name);
        }
        else if (expr.has_scalar_function())
        {
            std::string name;
            std::vector<String> useless;
            actions_dag = parseFunction(header, expr, name, useless, actions_dag, true);
            if (!name.empty())
                add_required_column(name);
        }
        else if (expr.has_cast() || expr.has_if_then() || expr.has_literal())
        {
            const auto * node = parseArgument(actions_dag, expr);
            actions_dag->addOrReplaceInIndex(*node);
            add_required_column(node->result_name);
        }
        else
        {
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "unsupported projection type {}.", magic_enum::enum_name(expr.rex_type_case()));
        }
    }
    actions_dag->project(required_columns);
    return actions_dag;
}

/// TODO: This function needs to be improved for Decimal/Array/Map/Tuple types.
std::string getCastFunction(const substrait::Type & type)
{
Expand Down Expand Up @@ -458,6 +548,28 @@ DataTypePtr SerializedPlanParser::parseType(const substrait::Type & substrait_ty
return std::move(ch_type);
}

/// Builds a ClickHouse data type from a Spark type name (e.g. "IntegerType" -> Int32).
/// Only the names listed in the table below are recognised.
/// Throws DB::Exception(UNKNOWN_TYPE) for any other name.
DB::DataTypePtr SerializedPlanParser::parseType(const std::string & type)
{
    /// Spark type name -> ClickHouse type name. Kept const so the shared table
    /// cannot be modified after initialisation (e.g. by an accidental operator[]).
    static const std::map<std::string, std::string> type2type = {
        {"BooleanType", "UInt8"},
        {"ByteType", "Int8"},
        {"ShortType", "Int16"},
        {"IntegerType", "Int32"},
        {"LongType", "Int64"},
        {"FloatType", "Float32"},
        {"DoubleType", "Float64"},
        {"StringType", "String"},
        {"DateType", "Date"}
    };

    const auto it = type2type.find(type);
    if (it == type2type.end())
    {
        throw DB::Exception(DB::ErrorCodes::UNKNOWN_TYPE, "Unknown spark type: {}", type);
    }
    return DB::DataTypeFactory::instance().get(it->second);
}

QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan)
{
auto * logger = &Poco::Logger::get("SerializedPlanParser");
Expand All @@ -469,17 +581,7 @@ QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan)
pb_util::MessageToJsonString(*plan, &json, options);
LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "substrait plan:{}", json);
}
if (plan->extensions_size() > 0)
{
for (const auto & extension : plan->extensions())
{
if (extension.has_extension_function())
{
this->function_mapping.emplace(
std::to_string(extension.extension_function().function_anchor()), extension.extension_function().name());
}
}
}
parseExtensions(plan->extensions());
if (plan->relations_size() == 1)
{
auto root_rel = plan->relations().at(0);
Expand Down Expand Up @@ -572,71 +674,7 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel)
read_schema = query_plan->getCurrentDataStream().header;
}
const auto & expressions = project.expressions();
auto actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(query_plan->getCurrentDataStream().header));
NamesWithAliases required_columns;
std::set<String> distinct_columns;

for (const auto & expr : expressions)
{
if (expr.has_selection())
{
auto position = expr.selection().direct_reference().struct_field().field();
const ActionsDAG::Node * field = actions_dag->tryFindInIndex(read_schema.getByPosition(position).name);
if (distinct_columns.contains(field->result_name))
{
auto unique_name = getUniqueName(field->result_name);
required_columns.emplace_back(NameWithAlias(field->result_name, unique_name));
distinct_columns.emplace(unique_name);
}
else
{
required_columns.emplace_back(NameWithAlias(field->result_name, field->result_name));
distinct_columns.emplace(field->result_name);
}
}
else if (expr.has_scalar_function())
{
std::string name;
std::vector<String> useless;
actions_dag = parseFunction(query_plan->getCurrentDataStream().header, expr, name, useless, actions_dag, true);
if (!name.empty())
{
if (distinct_columns.contains(name))
{
auto unique_name = getUniqueName(name);
required_columns.emplace_back(NameWithAlias(name, unique_name));
distinct_columns.emplace(unique_name);
}
else
{
required_columns.emplace_back(NameWithAlias(name, name));
distinct_columns.emplace(name);
}
}
}
else if (expr.has_cast() || expr.has_if_then() || expr.has_literal())
{
const auto * node = parseArgument(actions_dag, expr);
actions_dag->addOrReplaceInIndex(*node);
if (distinct_columns.contains(node->result_name))
{
auto unique_name = getUniqueName(node->result_name);
required_columns.emplace_back(NameWithAlias(node->result_name, unique_name));
distinct_columns.emplace(unique_name);
}
else
{
required_columns.emplace_back(NameWithAlias(node->result_name, node->result_name));
distinct_columns.emplace(node->result_name);
}
}
else
{
throw Exception(
ErrorCodes::BAD_ARGUMENTS, "unsupported projection type {}.", magic_enum::enum_name(expr.rex_type_case()));
}
}
actions_dag->project(required_columns);
auto actions_dag = expressionsToActionsDAG(expressions, query_plan->getCurrentDataStream().header, read_schema);
auto expression_step = std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(), actions_dag);
expression_step->setStepDescription("Project");
query_plan->addStep(std::move(expression_step));
Expand Down Expand Up @@ -730,7 +768,7 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel)
break;
}
default:
throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support relation type");
throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support relation type: {}.\n{}", rel.rel_type_case(), rel.DebugString());
}
return query_plan;
}
Expand Down Expand Up @@ -1697,6 +1735,16 @@ DB::SortDescription SerializedPlanParser::parseSortDescription(const substrait::
}
SharedContextHolder SerializedPlanParser::shared_context;

/// Frees the memory referenced by `spark_buffer` (if one is held) through
/// `ch_column_to_spark_row`, then resets `spark_buffer`.
LocalExecutor::~LocalExecutor()
{
    // No buffer is held, so there is nothing to free.
    if (!spark_buffer)
        return;

    ch_column_to_spark_row->freeMem(spark_buffer->address, spark_buffer->size);
    spark_buffer.reset();
}


void LocalExecutor::execute(QueryPlanPtr query_plan)
{
current_query_plan = std::move(query_plan);
Expand All @@ -1720,6 +1768,7 @@ void LocalExecutor::execute(QueryPlanPtr query_plan)
t_executor / 1000.0);
this->header = current_query_plan->getCurrentDataStream().header.cloneEmpty();
this->ch_column_to_spark_row = std::make_unique<CHColumnToSparkRow>();

}
std::unique_ptr<SparkRowInfo> LocalExecutor::writeBlockToSparkRow(Block & block)
{
Expand Down
19 changes: 11 additions & 8 deletions utils/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <arrow/ipc/writer.h>
#include <substrait/plan.pb.h>
#include <Common/BlockIterator.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <base/types.h>
#include <Core/SortDescription.h>

namespace local_engine
Expand Down Expand Up @@ -143,9 +145,17 @@ class SerializedPlanParser
static bool isReadRelFromJava(const substrait::ReadRel & rel);
static DB::Block parseNameStruct(const substrait::NamedStruct & struct_);
static DB::DataTypePtr parseType(const substrait::Type & type);
// Constructs a data type from a Spark type name.
static DB::DataTypePtr parseType(const std::string & type);

void addInputIter(jobject iter) { input_iters.emplace_back(iter); }

void parseExtensions(const ::google::protobuf::RepeatedPtrField<substrait::extensions::SimpleExtensionDeclaration> & extensions);
std::shared_ptr<DB::ActionsDAG> expressionsToActionsDAG(
const ::google::protobuf::RepeatedPtrField<substrait::Expression> & expressions,
const DB::Block & header,
const DB::Block & read_schema);

static ContextMutablePtr global_context;
static Context::ConfigurationPtr config;
static SharedContextHolder shared_context;
Expand Down Expand Up @@ -246,14 +256,7 @@ class LocalExecutor : public BlockIterator
SparkRowInfoPtr next();
Block * nextColumnar();
bool hasNext();
~LocalExecutor()
{
if (this->spark_buffer)
{
this->ch_column_to_spark_row->freeMem(spark_buffer->address, spark_buffer->size);
this->spark_buffer.reset();
}
}
~LocalExecutor();

Block & getHeader();

Expand Down
Loading

0 comments on commit e5b7449

Please sign in to comment.