Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CH-186] support RangePartitioning #189

Merged
merged 12 commits into from
Nov 22, 2022
8 changes: 6 additions & 2 deletions utils/local-engine/Parser/CHColumnToSparkRow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Common/Exception.h>

namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int UNKNOWN_TYPE;
}
}
Expand Down Expand Up @@ -341,8 +343,10 @@ int64_t SparkRowInfo::getTotalBytes() const

std::unique_ptr<SparkRowInfo> CHColumnToSparkRow::convertCHColumnToSparkRow(const Block & block)
{
if (!block.rows() || !block.columns())
return {};
if (!block.columns())
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A block with empty columns");
}

std::unique_ptr<SparkRowInfo> spark_row_info = std::make_unique<SparkRowInfo>(block);
spark_row_info->setBufferAddress(reinterpret_cast<char *>(alloc(spark_row_info->getTotalBytes(), 64)));
Expand Down
203 changes: 126 additions & 77 deletions utils/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "SerializedPlanParser.h"
#include <memory>
#include <base/logger_useful.h>
#include <base/Decimal.h>
Expand Down Expand Up @@ -54,7 +55,11 @@

#include <Processors/QueryPlan/QueryPlan.h>
#include <Processors/QueryPlan/SortingStep.h>
#include <DataTypes/DataTypeFactory.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <base/types.h>
#include <Storages/IStorage.h>
#include <sys/select.h>
#include <Common/CHUtil.h>
#include "SerializedPlanParser.h"

Expand Down Expand Up @@ -92,6 +97,91 @@ bool isTypeMatched(const substrait::Type & substrait_type, const DataTypePtr & c
return parsed_ch_type->equals(*ch_type);
}

void SerializedPlanParser::parseExtensions(
    const ::google::protobuf::RepeatedPtrField<substrait::extensions::SimpleExtensionDeclaration> & extensions)
{
    /// Register every declared extension function so that later expression
    /// parsing can resolve a function anchor (numeric id) to its name.
    for (const auto & declaration : extensions)
    {
        if (!declaration.has_extension_function())
            continue;
        const auto & func = declaration.extension_function();
        function_mapping.emplace(std::to_string(func.function_anchor()), func.name());
    }
}

std::shared_ptr<DB::ActionsDAG> SerializedPlanParser::expressionsToActionsDAG(
    const ::google::protobuf::RepeatedPtrField<substrait::Expression> & expressions,
    const DB::Block & header,
    const DB::Block & read_schema)
{
    auto actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(header));
    NamesWithAliases required_columns;
    std::set<String> distinct_columns;

    /// Append `name` to the projection output. If the name was already emitted,
    /// alias it to a fresh unique name so the resulting header has no duplicates.
    auto add_output_column = [&](const String & name)
    {
        if (distinct_columns.contains(name))
        {
            auto unique_name = getUniqueName(name);
            required_columns.emplace_back(NameWithAlias(name, unique_name));
            distinct_columns.emplace(unique_name);
        }
        else
        {
            required_columns.emplace_back(NameWithAlias(name, name));
            distinct_columns.emplace(name);
        }
    };

    for (const auto & expr : expressions)
    {
        if (expr.has_selection())
        {
            /// Direct field reference: resolve the column by its position in the read schema.
            auto position = expr.selection().direct_reference().struct_field().field();
            const auto & column_name = read_schema.getByPosition(position).name;
            const ActionsDAG::Node * field = actions_dag->tryFindInIndex(column_name);
            if (!field)
                /// tryFindInIndex returns nullptr when the column is absent; the
                /// original code dereferenced it unconditionally.
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "Not found column {} in actions DAG.", column_name);
            add_output_column(field->result_name);
        }
        else if (expr.has_scalar_function())
        {
            std::string name;
            std::vector<String> useless;
            actions_dag = parseFunction(header, expr, name, useless, actions_dag, true);
            /// parseFunction may leave `name` empty when the function adds no new output.
            if (!name.empty())
                add_output_column(name);
        }
        else if (expr.has_cast() || expr.has_if_then() || expr.has_literal())
        {
            const auto * node = parseArgument(actions_dag, expr);
            actions_dag->addOrReplaceInIndex(*node);
            add_output_column(node->result_name);
        }
        else
        {
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "unsupported projection type {}.", magic_enum::enum_name(expr.rex_type_case()));
        }
    }
    actions_dag->project(required_columns);
    return actions_dag;
}

/// TODO: This function needs to be improved for Decimal/Array/Map/Tuple types.
std::string getCastFunction(const substrait::Type & type)
{
Expand Down Expand Up @@ -458,6 +548,28 @@ DataTypePtr SerializedPlanParser::parseType(const substrait::Type & substrait_ty
return std::move(ch_type);
}

/// Construct a ClickHouse data type from a Spark SQL type name
/// (e.g. "IntegerType" -> Int32). Throws UNKNOWN_TYPE for unmapped names.
DB::DataTypePtr SerializedPlanParser::parseType(const std::string & type)
{
    /// Mapping from Spark type names to ClickHouse type names; the table is
    /// immutable, so declare it const.
    static const std::map<std::string, std::string> type2type = {
        {"BooleanType", "UInt8"},
        {"ByteType", "Int8"},
        {"ShortType", "Int16"},
        {"IntegerType", "Int32"},
        {"LongType", "Int64"},
        {"FloatType", "Float32"},
        {"DoubleType", "Float64"},
        {"StringType", "String"},
        {"DateType", "Date"}
    };

    auto it = type2type.find(type);
    if (it == type2type.end())
    {
        /// Fixed typo in the original message: "Unknow" -> "Unknown".
        throw DB::Exception(DB::ErrorCodes::UNKNOWN_TYPE, "Unknown spark type: {}", type);
    }
    return DB::DataTypeFactory::instance().get(it->second);
}

QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan)
{
auto * logger = &Poco::Logger::get("SerializedPlanParser");
Expand All @@ -469,17 +581,7 @@ QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan)
pb_util::MessageToJsonString(*plan, &json, options);
LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "substrait plan:{}", json);
}
if (plan->extensions_size() > 0)
{
for (const auto & extension : plan->extensions())
{
if (extension.has_extension_function())
{
this->function_mapping.emplace(
std::to_string(extension.extension_function().function_anchor()), extension.extension_function().name());
}
}
}
parseExtensions(plan->extensions());
if (plan->relations_size() == 1)
{
auto root_rel = plan->relations().at(0);
Expand Down Expand Up @@ -572,71 +674,7 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel)
read_schema = query_plan->getCurrentDataStream().header;
}
const auto & expressions = project.expressions();
auto actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(query_plan->getCurrentDataStream().header));
NamesWithAliases required_columns;
std::set<String> distinct_columns;

for (const auto & expr : expressions)
{
if (expr.has_selection())
{
auto position = expr.selection().direct_reference().struct_field().field();
const ActionsDAG::Node * field = actions_dag->tryFindInIndex(read_schema.getByPosition(position).name);
if (distinct_columns.contains(field->result_name))
{
auto unique_name = getUniqueName(field->result_name);
required_columns.emplace_back(NameWithAlias(field->result_name, unique_name));
distinct_columns.emplace(unique_name);
}
else
{
required_columns.emplace_back(NameWithAlias(field->result_name, field->result_name));
distinct_columns.emplace(field->result_name);
}
}
else if (expr.has_scalar_function())
{
std::string name;
std::vector<String> useless;
actions_dag = parseFunction(query_plan->getCurrentDataStream().header, expr, name, useless, actions_dag, true);
if (!name.empty())
{
if (distinct_columns.contains(name))
{
auto unique_name = getUniqueName(name);
required_columns.emplace_back(NameWithAlias(name, unique_name));
distinct_columns.emplace(unique_name);
}
else
{
required_columns.emplace_back(NameWithAlias(name, name));
distinct_columns.emplace(name);
}
}
}
else if (expr.has_cast() || expr.has_if_then() || expr.has_literal())
{
const auto * node = parseArgument(actions_dag, expr);
actions_dag->addOrReplaceInIndex(*node);
if (distinct_columns.contains(node->result_name))
{
auto unique_name = getUniqueName(node->result_name);
required_columns.emplace_back(NameWithAlias(node->result_name, unique_name));
distinct_columns.emplace(unique_name);
}
else
{
required_columns.emplace_back(NameWithAlias(node->result_name, node->result_name));
distinct_columns.emplace(node->result_name);
}
}
else
{
throw Exception(
ErrorCodes::BAD_ARGUMENTS, "unsupported projection type {}.", magic_enum::enum_name(expr.rex_type_case()));
}
}
actions_dag->project(required_columns);
auto actions_dag = expressionsToActionsDAG(expressions, query_plan->getCurrentDataStream().header, read_schema);
auto expression_step = std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(), actions_dag);
expression_step->setStepDescription("Project");
query_plan->addStep(std::move(expression_step));
Expand Down Expand Up @@ -730,7 +768,7 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel)
break;
}
default:
throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support relation type");
throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support relation type: {}.\n{}", rel.rel_type_case(), rel.DebugString());
}
return query_plan;
}
Expand Down Expand Up @@ -1697,6 +1735,16 @@ DB::SortDescription SerializedPlanParser::parseSortDescription(const substrait::
}
SharedContextHolder SerializedPlanParser::shared_context;

LocalExecutor::~LocalExecutor()
{
    /// Hand the last converted Spark-row buffer back to the converter's
    /// allocator before the executor goes away.
    if (spark_buffer)
    {
        ch_column_to_spark_row->freeMem(spark_buffer->address, spark_buffer->size);
        spark_buffer.reset();
    }
}


void LocalExecutor::execute(QueryPlanPtr query_plan)
{
current_query_plan = std::move(query_plan);
Expand All @@ -1720,6 +1768,7 @@ void LocalExecutor::execute(QueryPlanPtr query_plan)
t_executor / 1000.0);
this->header = current_query_plan->getCurrentDataStream().header.cloneEmpty();
this->ch_column_to_spark_row = std::make_unique<CHColumnToSparkRow>();

}
std::unique_ptr<SparkRowInfo> LocalExecutor::writeBlockToSparkRow(Block & block)
{
Expand Down
19 changes: 11 additions & 8 deletions utils/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <arrow/ipc/writer.h>
#include <substrait/plan.pb.h>
#include <Common/BlockIterator.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <base/types.h>
#include <Core/SortDescription.h>

namespace local_engine
Expand Down Expand Up @@ -143,9 +145,17 @@ class SerializedPlanParser
static bool isReadRelFromJava(const substrait::ReadRel & rel);
static DB::Block parseNameStruct(const substrait::NamedStruct & struct_);
static DB::DataTypePtr parseType(const substrait::Type & type);
// This is used for construct a data type from spark type name;
static DB::DataTypePtr parseType(const std::string & type);

void addInputIter(jobject iter) { input_iters.emplace_back(iter); }

void parseExtensions(const ::google::protobuf::RepeatedPtrField<substrait::extensions::SimpleExtensionDeclaration> & extensions);
std::shared_ptr<DB::ActionsDAG> expressionsToActionsDAG(
const ::google::protobuf::RepeatedPtrField<substrait::Expression> & expressions,
const DB::Block & header,
const DB::Block & read_schema);

static ContextMutablePtr global_context;
static Context::ConfigurationPtr config;
static SharedContextHolder shared_context;
Expand Down Expand Up @@ -246,14 +256,7 @@ class LocalExecutor : public BlockIterator
SparkRowInfoPtr next();
Block * nextColumnar();
bool hasNext();
~LocalExecutor()
{
if (this->spark_buffer)
{
this->ch_column_to_spark_row->freeMem(spark_buffer->address, spark_buffer->size);
this->spark_buffer.reset();
}
}
~LocalExecutor();

Block & getHeader();

Expand Down
Loading