Skip to content

Commit

Permalink
[POAE7-1448] pass the round-trip test of aggregatesNode (facebookincu…
Browse files Browse the repository at this point in the history
…bator#2)

* Pass project Tests for round-trip trans when batchSize=1

* clean some debug info

* change code style and using log instead of cout

* Pass project Tests for round-trip plan transform

* use full names for more readable in tests

* Pass FilterNode Tests for round-trip plan transform

* [POAE7-1448] Add AggregateNode, nullValue APIs and pass six tests about transform from velox to substrait

* address the comments

* [POAE7-1448] pass the round-trip test of aggregatesNode

* address comments and update the url of substrait submodule
  • Loading branch information
ZJie1 authored Feb 16, 2022
1 parent a920651 commit 7b89cf2
Show file tree
Hide file tree
Showing 5 changed files with 442 additions and 183 deletions.
3 changes: 2 additions & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[submodule "substrait"]
path = substrait
url = https://github.com/Intel-bigdata/substrait
url = https://github.com/intel-innersource/frameworks.ai.modular-sql.substrait.git
branch = modular_sql
2 changes: 1 addition & 1 deletion substrait
167 changes: 86 additions & 81 deletions velox/exec/SubstraitIRConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,18 @@ variant SubstraitVeloxConvertor::transformSLiteralType(
const io::substrait::Expression_Literal& sLiteralExpr) {
switch (sLiteralExpr.literal_type_case()) {
case io::substrait::Expression_Literal::LiteralTypeCase::kDecimal: {
return velox::variant(sLiteralExpr.decimal());
// Mapping the kDecimal in Substrait to DOUBLE in Velox
return velox::variant(sLiteralExpr.fp64());
}
case io::substrait::Expression_Literal::LiteralTypeCase::kString: {
return velox::variant(sLiteralExpr.var_char());
}
case io::substrait::Expression_Literal::LiteralTypeCase::kVarChar: {
return velox::variant(sLiteralExpr.var_char());
}
case io::substrait::Expression_Literal::LiteralTypeCase::kFixedChar: {
return velox::variant(sLiteralExpr.var_char());
}
case io::substrait::Expression_Literal::LiteralTypeCase::kBoolean: {
return velox::variant(sLiteralExpr.boolean());
}
Expand All @@ -210,47 +217,49 @@ variant SubstraitVeloxConvertor::transformSLiteralType(
return processSubstraitLiteralNullType(sLiteralExpr, nullValue);
}
default:
throw std::runtime_error("Unsupported liyeral_type in transformSLiteralType " +
throw std::runtime_error(
"Unsupported liyeral_type in transformSLiteralType " +
std::to_string(sLiteralExpr.literal_type_case()));
}
}

variant SubstraitVeloxConvertor::processSubstraitLiteralNullType(
const io::substrait::Expression_Literal& sLiteralExpr,
io::substrait::Type nullType) {
switch (nullType.kind_case()) {
case io::substrait::Type::kDecimal: {
return velox::variant(sLiteralExpr.decimal());
}
case io::substrait::Type::kString: {
return velox::variant(sLiteralExpr.var_char());
}
case io::substrait::Type::kBool: {
return velox::variant(sLiteralExpr.boolean());
}
case io::substrait::Type::kI64: {
return velox::variant(sLiteralExpr.i64());
}
case io::substrait::Type::kI32: {
return velox::variant(sLiteralExpr.i32());
}
case io::substrait::Type::kI16: {
return velox::variant(static_cast<int16_t>(sLiteralExpr.i16()));
}
case io::substrait::Type::kI8: {
return velox::variant(static_cast<int8_t>(sLiteralExpr.i8()));
}
case io::substrait::Type::kFp64: {
return velox::variant(sLiteralExpr.fp64());
}
case io::substrait::Type::kFp32: {
return velox::variant(sLiteralExpr.fp32());
const io::substrait::Expression_Literal& sLiteralExpr,
io::substrait::Type nullType) {
switch (nullType.kind_case()) {
case io::substrait::Type::kDecimal: {
// mapping to DOUBLE
return velox::variant(sLiteralExpr.fp64());
}
case io::substrait::Type::kString: {
return velox::variant(sLiteralExpr.var_char());
}
case io::substrait::Type::kBool: {
return velox::variant(sLiteralExpr.boolean());
}
case io::substrait::Type::kI64: {
return velox::variant(sLiteralExpr.i64());
}
case io::substrait::Type::kI32: {
return velox::variant(sLiteralExpr.i32());
}
case io::substrait::Type::kI16: {
return velox::variant(static_cast<int16_t>(sLiteralExpr.i16()));
}
case io::substrait::Type::kI8: {
return velox::variant(static_cast<int8_t>(sLiteralExpr.i8()));
}
case io::substrait::Type::kFp64: {
return velox::variant(sLiteralExpr.fp64());
}
case io::substrait::Type::kFp32: {
return velox::variant(sLiteralExpr.fp32());
}
default:
throw std::runtime_error(
"Unsupported type in processSubstraitLiteralNullType " +
std::to_string(nullType.kind_case()));
}
}
}

std::shared_ptr<const ITypedExpr> SubstraitVeloxConvertor::transformSExpr(
Expand Down Expand Up @@ -481,13 +490,13 @@ std::shared_ptr<PlanNode> SubstraitVeloxConvertor::transformSRead(
sRead.virtual_table().values(numRows - 1).fields_size();

std::vector<RowVectorPtr> vectors;
std::vector<VectorPtr> children;
bool nullFlag = false;
std::shared_ptr<RowVector> rowVector;

int64_t batchSize = valueFieldNums / numColumns;

for (int32_t row = 0; row < numRows; ++row) {
std::vector<VectorPtr> children;
std::shared_ptr<RowVector> rowVector;
io::substrait::Expression_Literal_Struct sRowValue =
sRead.virtual_table().values(row);
int64_t sFieldSize = sRowValue.fields_size();
Expand All @@ -503,12 +512,10 @@ std::shared_ptr<PlanNode> SubstraitVeloxConvertor::transformSRead(
// for the null value
if (sFieldType == 29) {
nullFlag = true;
childrenValue = BaseVector::createConstant(
transformSLiteralType(sField), batchSize, pool_);
childrenValue = BaseVector::createNullConstant(
vOutputChildType, batchSize, pool_);
} else {
// TODO need to confirm whether using
// VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH or VELOX_DYNAMIC_TYPE_DISPATCH
childrenValue = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
childrenValue = VELOX_DYNAMIC_TYPE_DISPATCH(
test::BatchMaker::createVector,
vOutputChildType->kind(),
vOutputType->childAt(col),
Expand All @@ -520,13 +527,12 @@ std::shared_ptr<PlanNode> SubstraitVeloxConvertor::transformSRead(
if (nullFlag) {
rowVector = std::make_shared<RowVector>(
pool_, vOutputType, BufferPtr(nullptr), batchSize, children);
vectors.push_back(rowVector);

} else {
auto vector = std::dynamic_pointer_cast<RowVector>(
test::BatchMaker::createBatch(vOutputType, batchSize, *pool_));
vectors.push_back(vector);
rowVector = std::make_shared<RowVector>(
pool_, vOutputType, BufferPtr(), batchSize, children);
}
vectors.emplace_back(rowVector);
}

return std::make_shared<ValuesNode>(
Expand All @@ -541,7 +547,8 @@ std::shared_ptr<ProjectNode> SubstraitVeloxConvertor::transformSProject(
std::vector<std::shared_ptr<const ITypedExpr>> vExpressions;
std::vector<std::string> names;

std::shared_ptr<const PlanNode> vSource = fromSubstraitIR(sProj.input(), depth + 1);
std::shared_ptr<const PlanNode> vSource =
fromSubstraitIR(sProj.input(), depth + 1);

for (auto& sExpr : sProj.expressions()) {
std::shared_ptr<const ITypedExpr> vExpr =
Expand Down Expand Up @@ -579,8 +586,9 @@ std::shared_ptr<AggregationNode> SubstraitVeloxConvertor::transformSAggregate(
std::vector<std::shared_ptr<const FieldAccessTypedExpr>> groupingKeys;
std::shared_ptr<const FieldAccessTypedExpr> groupingKey;

const io::substrait::AggregateRel &sAgg = sRel.aggregate();
std::shared_ptr<const PlanNode> vSource = fromSubstraitIR(sAgg.input(), depth + 1);
const io::substrait::AggregateRel& sAgg = sRel.aggregate();
std::shared_ptr<const PlanNode> vSource =
fromSubstraitIR(sAgg.input(), depth + 1);

// TODO need to confirm whether this is only for one grouping set, GROUP BY
// a,b,c. Not fit for GROUPING SETS ???
Expand All @@ -594,36 +602,39 @@ std::shared_ptr<AggregationNode> SubstraitVeloxConvertor::transformSAggregate(
}
}
// for velox sum(c) is ok, but sum(c + d) is not.
for (auto &sMeas: sAgg.measures()) {
for (auto& sMeas : sAgg.measures()) {
io::substrait::Expression_AggregateFunction sMeasure = sMeas.measure();
if (sMeas.has_filter()) {
io::substrait::Expression sAggMask = sMeas.filter();
// handle the case sum(IF(linenumber = 7, partkey)) <=>sum(partkey) FILTER
// (where linenumber = 7) For each measure, an optional boolean input
// column that is used to mask out rows for this particular measure.

std::shared_ptr<const ITypedExpr> vAggMask =
transformSExpr(sAggMask, sGlobalMapping);
aggregateMask =
std::dynamic_pointer_cast<const FieldAccessTypedExpr>(vAggMask);
size_t sAggMaskLength = sAggMask.ByteSizeLong();
if (sAggMaskLength == 0) {
aggregateMask = {};
} else {
std::shared_ptr<const ITypedExpr> vAggMask =
transformSExpr(sAggMask, sGlobalMapping);
aggregateMask =
std::dynamic_pointer_cast<const FieldAccessTypedExpr>(vAggMask);
}
aggregateMasks.push_back(aggregateMask);
}

std::vector<std::shared_ptr<const ITypedExpr>> children;
std::string out_name;
std::string function_name = FindFunction(sMeasure.id().id());
out_name = function_name;
// AggregateFunction.args should be one for velox . if not, should do project firstly
int64_t sMeasureArgSize = sMeasure.args_size();
// the very simple case for sum(a) not very sure if this will contain the situation with maskExpression.
// AggregateFunction.args should be one for velox . if not, should do
// project firstly
int64_t sMeasureArgSize = sMeasure.args_size();
// the very simple case for sum(a) need to check if this will contain the
// situation with maskExpression.
if (sMeasureArgSize == 1) {
auto vMeasureArgExpr = transformSExpr(sMeasure.args()[0], sGlobalMapping);
if (auto vMeasureArg =
std::dynamic_pointer_cast<const CallTypedExpr>(vMeasureArgExpr)) {
aggregates.push_back(vMeasureArg);
// TODO : should be decided which aggregateNames should be
// the first way is re-construct the names according the res.
// debug to see if it's sum_a
out_name += vMeasureArg->toString();
aggregateNames.push_back(out_name);
}
Expand Down Expand Up @@ -653,21 +664,10 @@ std::shared_ptr<AggregationNode> SubstraitVeloxConvertor::transformSAggregate(
step = AggregationNode::Step::kSingle;
break;
}
default: VELOX_UNSUPPORTED("Unsupported aggregation step");
default:
VELOX_UNSUPPORTED("Unsupported aggregation step");
}
}
/*// TODO one of them is ok, need to be decided
// the second way to get the aggregateNames
io::substrait::Type_NamedStruct sAggOutMap = sAgg.common().emit().output_mapping(0);
// the proj common is always start from 0. because the way we trans from velox to substrait.
int64_t sAggOutMapSize = sAggOutMap.index_size();
for (int64_t i = 0; i < sAggOutMapSize; i++) {
aggregateNames.push_back(sAggOutMap.names(i));
}*/

//TODO Agg don't have emit outputMapping
//aggregateNames != vSource->outputType()->names();
//need to use global variable or the first way.

return std::make_shared<AggregationNode>(
std::to_string(depth),
Expand All @@ -678,7 +678,6 @@ std::shared_ptr<AggregationNode> SubstraitVeloxConvertor::transformSAggregate(
aggregateMasks,
ignoreNullKeys,
vSource);

}

std::shared_ptr<OrderByNode> SubstraitVeloxConvertor::transformSSort(
Expand Down Expand Up @@ -1423,15 +1422,15 @@ void SubstraitVeloxConvertor::transformVExpr(
// different by function names.
if (vCallTypeExprFunName == "if") {
io::substrait::Expression_IfThen* sFun = sExpr->mutable_if_then();
int64_t vCallTypeInputSize = vCallTypeInputs.size();
for (int64_t i = 0; i < vCallTypeInputSize; i++) {
std::shared_ptr<const ITypedExpr> vCallTypeInput =
vCallTypeInputs.at(i);
// TODO
// need to judge according the names in the expr, and then set them to
// the if or then or else expr can debug to find when process project
// node
}
int64_t vCallTypeInputSize = vCallTypeInputs.size();
for (int64_t i = 0; i < vCallTypeInputSize; i++) {
std::shared_ptr<const ITypedExpr> vCallTypeInput =
vCallTypeInputs.at(i);
// TODO
// need to judge according the names in the expr, and then set them to
// the if or then or else expr can debug to find when process project
// node
}
} else if (vCallTypeExprFunName == "switch") {
io::substrait::Expression_SwitchExpression* sFun =
sExpr->mutable_switch_expression();
Expand Down Expand Up @@ -1538,6 +1537,12 @@ void SubstraitVeloxConvertor::transformVConstantExpr(
sLiteralExpr->set_fp32(vConstExpr.value<TypeKind::REAL>());
break;
}
case velox::TypeKind::TIMESTAMP: {
// TODO
sLiteralExpr->set_timestamp(
vConstExpr.value<TypeKind::TIMESTAMP>().getNanos());
break;
}
default:
throw std::runtime_error(
"Unsupported constant Type" + mapTypeKindToName(vConstExpr.kind()));
Expand Down
Loading

0 comments on commit 7b89cf2

Please sign in to comment.