diff --git a/velox/substrait/CMakeLists.txt b/velox/substrait/CMakeLists.txt index a3a9ade8f21f..52af01470e0a 100644 --- a/velox/substrait/CMakeLists.txt +++ b/velox/substrait/CMakeLists.txt @@ -33,10 +33,8 @@ get_filename_component(PROTO_DIR ${substrait_proto_directory}/, DIRECTORY) # Generate Substrait hearders add_custom_command( OUTPUT ${PROTO_OUTPUT_FILES} - COMMAND - ${Protobuf_PROTOC_EXECUTABLE} --proto_path ${PROJECT_SOURCE_DIR}/ - --proto_path ${Protobuf_INCLUDE_DIRS} --cpp_out ${PROJECT_SOURCE_DIR} - ${PROTO_FILES} + COMMAND ${Protobuf_PROTOC_EXECUTABLE} --proto_path ${proto_directory}/ --cpp_out ${PROTO_OUTPUT_DIR} + ${PROTO_FILES} DEPENDS ${PROTO_DIR} COMMENT "Running PROTO compiler" VERBATIM) @@ -54,13 +52,14 @@ set(SRCS VeloxToSubstraitPlan.cpp VeloxToSubstraitType.cpp VeloxSubstraitSignature.cpp - VariantToVectorConverter.cpp) + VariantToVectorConverter.cpp + SubstraitToVeloxPlanValidator.cpp) add_library(velox_substrait_plan_converter ${SRCS}) target_include_directories(velox_substrait_plan_converter PUBLIC ${PROTO_OUTPUT_DIR}) target_link_libraries(velox_substrait_plan_converter velox_connector - velox_dwio_dwrf_common) + velox_dwio_dwrf_common velox_functions_spark) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/substrait/SubstraitParser.cpp b/velox/substrait/SubstraitParser.cpp index 0e51cd50a1e1..4de60c99e7a8 100644 --- a/velox/substrait/SubstraitParser.cpp +++ b/velox/substrait/SubstraitParser.cpp @@ -118,6 +118,25 @@ std::shared_ptr SubstraitParser::parseType( nullability = substraitType.date().nullability(); break; } + case ::substrait::Type::KindCase::kTimestamp: { + typeName = "TIMESTAMP"; + nullability = substraitType.timestamp().nullability(); + break; + } + case ::substrait::Type::KindCase::kDecimal: { + auto precision = substraitType.decimal().precision(); + auto scale = substraitType.decimal().scale(); + if (precision <= 18) { + typeName = "SHORT_DECIMAL<" + std::to_string(precision) + "," + + 
std::to_string(scale) + ">"; + } else { + typeName = "LONG_DECIMAL<" + std::to_string(precision) + "," + + std::to_string(scale) + ">"; + } + + nullability = substraitType.decimal().nullability(); + break; + } default: VELOX_NYI( "Parsing for Substrait type not supported: {}", @@ -144,6 +163,14 @@ std::shared_ptr SubstraitParser::parseType( return std::make_shared(type); } +std::string SubstraitParser::parseType(const std::string& substraitType) { + auto it = typeMap_.find(substraitType); + if (it == typeMap_.end()) { + VELOX_NYI("Substrait parsing for type {} not supported.", substraitType); + } + return it->second; +}; + std::vector> SubstraitParser::parseNamedStruct(const ::substrait::NamedStruct& namedStruct) { // Nte that "names" are not used. @@ -160,6 +187,36 @@ SubstraitParser::parseNamedStruct(const ::substrait::NamedStruct& namedStruct) { return substraitTypeList; } +std::vector SubstraitParser::parsePartitionColumns( + const ::substrait::NamedStruct& namedStruct) { + const auto& columnsTypes = namedStruct.partition_columns().column_type(); + std::vector isPartitionColumns; + if (columnsTypes.size() == 0) { + // Regard all columns as non-partitioned columns. 
+ isPartitionColumns.resize(namedStruct.names().size(), false); + return isPartitionColumns; + } else { + VELOX_CHECK( + columnsTypes.size() == namedStruct.names().size(), + "Invalid partion columns."); + } + + isPartitionColumns.reserve(columnsTypes.size()); + for (const auto& columnType : columnsTypes) { + switch (columnType) { + case ::substrait::PartitionColumns::NORMAL_COL: + isPartitionColumns.emplace_back(false); + break; + case ::substrait::PartitionColumns::PARTITION_COL: + isPartitionColumns.emplace_back(true); + break; + default: + VELOX_FAIL("Patition column type is not supported."); + } + } + return isPartitionColumns; +} + int32_t SubstraitParser::parseReferenceSegment( const ::substrait::Expression::ReferenceSegment& refSegment) { auto typeCase = refSegment.reference_type_case(); @@ -219,17 +276,73 @@ const std::string& SubstraitParser::findFunctionSpec( return map[id]; } +std::string SubstraitParser::getSubFunctionName( + const std::string& subFuncSpec) const { + // Get the position of ":" in the function name. + std::size_t pos = subFuncSpec.find(":"); + if (pos == std::string::npos) { + return subFuncSpec; + } + return subFuncSpec.substr(0, pos); +} + +void SubstraitParser::getSubFunctionTypes( + const std::string& subFuncSpec, + std::vector& types) const { + // Get the position of ":" in the function name. + std::size_t pos = subFuncSpec.find(":"); + // Get the parameter types. + std::string funcTypes; + if (pos == std::string::npos) { + funcTypes = subFuncSpec; + } else { + if (pos == subFuncSpec.size() - 1) { + return; + } + funcTypes = subFuncSpec.substr(pos + 1); + } + // Split the types with delimiter. 
+ std::string delimiter = "_"; + while ((pos = funcTypes.find(delimiter)) != std::string::npos) { + auto type = funcTypes.substr(0, pos); + if (type != "opt" && type != "req") { + types.emplace_back(type); + } + funcTypes.erase(0, pos + delimiter.length()); + } + types.emplace_back(funcTypes); +} + std::string SubstraitParser::findVeloxFunction( const std::unordered_map& functionMap, uint64_t id) const { std::string funcSpec = findFunctionSpec(functionMap, id); std::string_view funcName = getNameBeforeDelimiter(funcSpec, ":"); - return mapToVeloxFunction({funcName.begin(), funcName.end()}); + std::vector types; + getSubFunctionTypes(funcSpec, types); + bool isDecimal = false; + for (auto& type : types) { + if (type.find("dec") != std::string::npos) { + isDecimal = true; + break; + } + } + return mapToVeloxFunction({funcName.begin(), funcName.end()}, isDecimal); } std::string SubstraitParser::mapToVeloxFunction( - const std::string& substraitFunction) const { + const std::string& substraitFunction, + bool isDecimal) const { auto it = substraitVeloxFunctionMap_.find(substraitFunction); + if (isDecimal) { + if (substraitFunction == "add" || substraitFunction == "subtract" || + substraitFunction == "multiply" || substraitFunction == "divide" || + substraitFunction == "avg" || substraitFunction == "avg_merge" || + substraitFunction == "sum" || substraitFunction == "sum_merge" || + substraitFunction == "round") { + return "decimal_" + substraitFunction; + } + } if (it != substraitVeloxFunctionMap_.end()) { return it->second; } @@ -239,4 +352,19 @@ std::string SubstraitParser::mapToVeloxFunction( return substraitFunction; } +bool SubstraitParser::configSetInOptimization( + const ::substrait::extensions::AdvancedExtension& extension, + const std::string& config) const { + if (extension.has_optimization()) { + google::protobuf::StringValue msg; + extension.optimization().UnpackTo(&msg); + std::size_t pos = msg.value().find(config); + if ((pos != std::string::npos) && + 
(msg.value().substr(pos + config.size(), 1) == "1")) { + return true; + } + } + return false; +} + } // namespace facebook::velox::substrait diff --git a/velox/substrait/SubstraitParser.h b/velox/substrait/SubstraitParser.h index a76db95efc1d..5fd37b0c2196 100644 --- a/velox/substrait/SubstraitParser.h +++ b/velox/substrait/SubstraitParser.h @@ -25,6 +25,8 @@ #include "velox/substrait/proto/substrait/type.pb.h" #include "velox/substrait/proto/substrait/type_expressions.pb.h" +#include + namespace facebook::velox::substrait { /// This class contains some common functions used to parse Substrait @@ -37,14 +39,21 @@ class SubstraitParser { bool nullable; }; - /// Parse Substrait NamedStruct. - std::vector> parseNamedStruct( + /// Used to parse Substrait NamedStruct. + std::vector> parseNamedStruct( + const ::substrait::NamedStruct& namedStruct); + + /// Used to parse partition columns from Substrait NamedStruct. + std::vector parsePartitionColumns( const ::substrait::NamedStruct& namedStruct); /// Parse Substrait Type. std::shared_ptr parseType( const ::substrait::Type& substraitType); + // Parse substraitType type such as i32. + std::string parseType(const std::string& substraitType); + /// Parse Substrait ReferenceSegment. int32_t parseReferenceSegment( const ::substrait::Expression::ReferenceSegment& refSegment); @@ -70,14 +79,34 @@ class SubstraitParser { const std::unordered_map& functionMap, uint64_t id) const; - /// Find the Velox function name according to the function id + /// Extracts the function name for a function from specified compound name. + /// When the input is a simple name, it will be returned. + std::string getSubFunctionName(const std::string& functionSpec) const; + + /// This function is used get the types from the compound name. + void getSubFunctionTypes( + const std::string& subFuncSpec, + std::vector& types) const; + + /// Used to find the Velox function name according to the function id /// from a pre-constructed function map. 
std::string findVeloxFunction( const std::unordered_map& functionMap, uint64_t id) const; /// Map the Substrait function keyword into Velox function keyword. - std::string mapToVeloxFunction(const std::string& substraitFunction) const; + std::string mapToVeloxFunction( + const std::string& substraitFunction, + bool isDecimal) const; + + /// @brief Return whether a config is set as true in AdvancedExtension + /// optimization. + /// @param extension Substrait advanced extension. + /// @param config the key string of a config. + /// @return Whether the config is set as true. + bool configSetInOptimization( + const ::substrait::extensions::AdvancedExtension& extension, + const std::string& config) const; private: /// A map used for mapping Substrait function keywords into Velox functions' @@ -85,11 +114,43 @@ class SubstraitParser { /// keyword. For those functions with different names in Substrait and Velox, /// a mapping relation should be added here. std::unordered_map substraitVeloxFunctionMap_ = { - {"add", "plus"}, - {"subtract", "minus"}, - {"modulus", "mod"}, - {"not_equal", "neq"}, - {"equal", "eq"}}; + {"is_not_null", "isnotnull"}, /*Spark functions.*/ + {"is_null", "isnull"}, + {"equal", "equalto"}, + {"lt", "lessthan"}, + {"lte", "lessthanorequal"}, + {"gt", "greaterthan"}, + {"gte", "greaterthanorequal"}, + {"not_equal", "notequalto"}, + {"char_length", "length"}, + {"strpos", "instr"}, + {"ends_with", "endswith"}, + {"starts_with", "startswith"}, + {"datediff", "date_diff"}, + {"named_struct", "row_constructor"}, + {"bit_or", "bitwise_or_agg"}, + {"bit_or_merge", "bitwise_or_agg_merge"}, + {"bit_and", "bitwise_and_agg"}, + {"bit_and_merge", "bitwise_and_agg_merge"}, + {"modulus", "mod"} /*Presto functions.*/}; + + // The map is uesd for mapping substrait type. + // Key: type in function name. + // Value: substrait type name. 
+ const std::unordered_map typeMap_ = { + {"bool", "BOOLEAN"}, + {"i8", "TINYINT"}, + {"i16", "SMALLINT"}, + {"i32", "INTEGER"}, + {"i64", "BIGINT"}, + {"fp32", "REAL"}, + {"fp64", "DOUBLE"}, + {"date", "DATE"}, + {"ts", "TIMESTAMP"}, + {"str", "VARCHAR"}, + {"vbin", "VARBINARY"}, + {"decShort", "SHORT_DECIMAL"}, + {"decLong", "LONG_DECIMAL"}}; }; } // namespace facebook::velox::substrait diff --git a/velox/substrait/SubstraitToVeloxExpr.cpp b/velox/substrait/SubstraitToVeloxExpr.cpp index ace7dafefb21..26550304ee37 100644 --- a/velox/substrait/SubstraitToVeloxExpr.cpp +++ b/velox/substrait/SubstraitToVeloxExpr.cpp @@ -17,6 +17,9 @@ #include "velox/substrait/SubstraitToVeloxExpr.h" #include "velox/substrait/TypeUtils.h" #include "velox/vector/FlatVector.h" +#include "velox/vector/VariantToVector.h" + +#include "velox/type/Timestamp.h" using namespace facebook::velox; namespace { @@ -91,6 +94,18 @@ ArrayVectorPtr makeArrayVector(const VectorPtr& elements) { elements); } +RowVectorPtr makeRowVector(const std::vector& children) { + std::vector> types; + types.resize(children.size()); + for (int i = 0; i < children.size(); i++) { + types[i] = children[i]->type(); + } + const size_t vectorSize = children.empty() ? 
0 : children.front()->size(); + auto rowType = ROW(std::move(types)); + return std::make_shared( + children[0]->pool(), rowType, BufferPtr(nullptr), vectorSize, children); +} + ArrayVectorPtr makeEmptyArrayVector(memory::MemoryPool* pool) { BufferPtr offsets = allocateOffsets(1, pool); BufferPtr sizes = allocateOffsets(1, pool); @@ -98,6 +113,10 @@ ArrayVectorPtr makeEmptyArrayVector(memory::MemoryPool* pool) { pool, ARRAY(UNKNOWN()), nullptr, 1, offsets, sizes, nullptr); } +RowVectorPtr makeEmptyRowVector(memory::MemoryPool* pool) { + return makeRowVector({}); +} + template void setLiteralValue( const ::substrait::Expression::Literal& literal, @@ -110,8 +129,10 @@ void setLiteralValue( vector->set(index, StringView(literal.string())); } else if (literal.has_var_char()) { vector->set(index, StringView(literal.var_char().value())); + } else if (literal.has_binary()) { + vector->set(index, StringView(literal.binary())); } else { - VELOX_FAIL("Unexpected string literal"); + VELOX_FAIL("Unexpected string or binary literal"); } } else { vector->set(index, getLiteralValue(literal)); @@ -154,8 +175,23 @@ bool isNullOnFailure( } } +template +VectorPtr constructFlatVectorForStruct( + const ::substrait::Expression::Literal& child, + const vector_size_t size, + const TypePtr& type, + memory::MemoryPool* pool) { + VELOX_CHECK(type->isPrimitiveType()); + auto vector = BaseVector::create(type, size, pool); + using T = typename TypeTraits::NativeType; + auto flatVector = vector->as>(); + setLiteralValue(child, flatVector, 0); + return vector; +} + } // namespace +using facebook::velox::core::variantArrayToVector; namespace facebook::velox::substrait { std::shared_ptr @@ -166,20 +202,44 @@ SubstraitVeloxExprConverter::toVeloxExpr( switch (typeCase) { case ::substrait::Expression::FieldReference::ReferenceTypeCase:: kDirectReference: { - const auto& directRef = substraitField.direct_reference(); - int32_t colIdx = substraitParser_.parseReferenceSegment(directRef); + const auto& 
dRef = substraitField.direct_reference(); + VELOX_CHECK(dRef.has_struct_field(), "Struct field expected."); + int32_t colIdx = subParser_->parseReferenceSegment(dRef); + std::optional childIdx = std::nullopt; + if (dRef.struct_field().has_child()) { + childIdx = + subParser_->parseReferenceSegment(dRef.struct_field().child()); + } + + const auto& inputTypes = inputType->children(); const auto& inputNames = inputType->names(); const int64_t inputSize = inputNames.size(); - if (colIdx <= inputSize) { - const auto& inputTypes = inputType->children(); - // Convert type to row. + + if (colIdx >= inputSize) { + VELOX_FAIL("Missing the column with id '{}' .", colIdx); + } + + if (!childIdx.has_value()) { return std::make_shared( inputTypes[colIdx], std::make_shared(inputTypes[colIdx]), inputNames[colIdx]); } else { - VELOX_FAIL("Missing the column with id '{}' .", colIdx); + // Select a subfield in a struct by name. + if (auto inputColumnType = asRowType(inputTypes[colIdx])) { + if (childIdx.value() >= inputColumnType->size()) { + VELOX_FAIL("Missing the subfield with id '{}' .", childIdx.value()); + } + return std::make_shared( + inputColumnType->childAt(childIdx.value()), + std::make_shared( + inputTypes[colIdx], inputNames[colIdx]), + inputColumnType->nameOf(childIdx.value())); + } else { + VELOX_FAIL("RowType expected."); + } } + break; } default: VELOX_NYI( @@ -187,6 +247,36 @@ SubstraitVeloxExprConverter::toVeloxExpr( } } +core::TypedExprPtr SubstraitVeloxExprConverter::toExtractExpr( + const std::vector>& params, + const TypePtr& outputType) { + VELOX_CHECK_EQ(params.size(), 2); + auto functionArg = + std::dynamic_pointer_cast(params[0]); + if (functionArg) { + // Get the function argument. + auto variant = functionArg->value(); + if (!variant.hasValue()) { + VELOX_FAIL("Value expected in variant."); + } + // The first parameter specifies extracting from which field. + std::string from = variant.value(); + + // The second parameter is the function parameter. 
+ std::vector> exprParams; + exprParams.reserve(1); + exprParams.emplace_back(params[1]); + auto iter = extractDatetimeFunctionMap_.find(from); + if (iter != extractDatetimeFunctionMap_.end()) { + return std::make_shared( + outputType, std::move(exprParams), iter->second); + } else { + VELOX_NYI("Extract from {} not supported.", from); + } + } + VELOX_FAIL("Constant is expected to be the first parameter in extract."); +} + core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr( const ::substrait::Expression::ScalarFunction& substraitFunc, const RowTypePtr& inputType) { @@ -195,14 +285,65 @@ core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr( for (const auto& sArg : substraitFunc.arguments()) { params.emplace_back(toVeloxExpr(sArg.value(), inputType)); } - const auto& veloxFunction = substraitParser_.findVeloxFunction( + const auto& veloxFunction = subParser_->findVeloxFunction( functionMap_, substraitFunc.function_reference()); std::string typeName = - substraitParser_.parseType(substraitFunc.output_type())->type; + subParser_->parseType(substraitFunc.output_type())->type; + + if (veloxFunction == "extract") { + return toExtractExpr(std::move(params), toVeloxType(typeName)); + } + return std::make_shared( toVeloxType(typeName), std::move(params), veloxFunction); } +std::shared_ptr +SubstraitVeloxExprConverter::literalsToConstantExpr( + const std::vector<::substrait::Expression::Literal>& literals) { + std::vector variants; + variants.reserve(literals.size()); + VELOX_CHECK_GE(literals.size(), 0, "List should have at least one item."); + std::optional literalType = std::nullopt; + for (const auto& literal : literals) { + auto veloxVariant = toVeloxExpr(literal)->value(); + if (!literalType.has_value()) { + literalType = veloxVariant.inferType(); + } + variants.emplace_back(veloxVariant); + } + VELOX_CHECK(literalType.has_value(), "Type expected."); + auto varArray = variant::array(variants); + ArrayVectorPtr arrayVector = + 
variantArrayToVector(varArray.inferType(), varArray.array(), pool_); + // Wrap the array vector into constant vector. + auto constantVector = + BaseVector::wrapInConstant(1 /*length*/, 0 /*index*/, arrayVector); + return std::make_shared(constantVector); +} + +core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr( + const ::substrait::Expression::SingularOrList& singularOrList, + const RowTypePtr& inputType) { + VELOX_CHECK( + singularOrList.options_size() > 0, "At least one option is expected."); + auto options = singularOrList.options(); + std::vector<::substrait::Expression::Literal> literals; + literals.reserve(options.size()); + for (const auto& option : options) { + VELOX_CHECK(option.has_literal(), "Literal is expected as option."); + literals.emplace_back(option.literal()); + } + + std::vector> params; + params.reserve(2); + // First param is the value, second param is the list. + params.emplace_back(toVeloxExpr(singularOrList.value(), inputType)); + params.emplace_back(literalsToConstantExpr(literals)); + return std::make_shared( + BOOLEAN(), std::move(params), "in"); +} + std::shared_ptr SubstraitVeloxExprConverter::toVeloxExpr( const ::substrait::Expression::Literal& substraitLit) { @@ -234,12 +375,13 @@ SubstraitVeloxExprConverter::toVeloxExpr( case ::substrait::Expression_Literal::LiteralTypeCase::kString: return std::make_shared( VARCHAR(), variant(substraitLit.string())); - case ::substrait::Expression_Literal::LiteralTypeCase::kNull: { - auto veloxType = - toVeloxType(substraitParser_.parseType(substraitLit.null())->type); + case ::substrait::Expression_Literal::LiteralTypeCase::kDate: return std::make_shared( - veloxType, variant::null(veloxType->kind())); - } + DATE(), variant(Date(substraitLit.date()))); + case ::substrait::Expression_Literal::LiteralTypeCase::kTimestamp: + return std::make_shared( + TIMESTAMP(), + variant(Timestamp::fromMicros(substraitLit.timestamp()))); case ::substrait::Expression_Literal::LiteralTypeCase::kVarChar: 
return std::make_shared( VARCHAR(), variant(substraitLit.var_char().value())); @@ -248,9 +390,47 @@ SubstraitVeloxExprConverter::toVeloxExpr( BaseVector::wrapInConstant(1, 0, literalsToArrayVector(substraitLit)); return std::make_shared(constantVector); } - case ::substrait::Expression_Literal::LiteralTypeCase::kDate: + case ::substrait::Expression_Literal::LiteralTypeCase::kBinary: return std::make_shared( - DATE(), variant(Date(substraitLit.date()))); + VARBINARY(), variant::binary(substraitLit.binary())); + case ::substrait::Expression_Literal::LiteralTypeCase::kStruct: { + auto constantVector = + BaseVector::wrapInConstant(1, 0, literalsToRowVector(substraitLit)); + return std::make_shared(constantVector); + } + case ::substrait::Expression_Literal::LiteralTypeCase::kDecimal: { + auto decimal = substraitLit.decimal().value(); + auto precision = substraitLit.decimal().precision(); + auto scale = substraitLit.decimal().scale(); + int128_t decimalValue; + memcpy(&decimalValue, decimal.c_str(), 16); + if (precision <= 18) { + auto type = DECIMAL(precision, scale); + return std::make_shared( + type, variant(static_cast(decimalValue))); + } else { + auto type = DECIMAL(precision, scale); + return std::make_shared( + type, + variant(HugeInt::build( + static_cast(decimalValue >> 64), + static_cast(decimalValue)))); + } + } + case ::substrait::Expression_Literal::LiteralTypeCase::kNull: { + auto veloxType = + toVeloxType(subParser_->parseType(substraitLit.null())->type); + if (veloxType->isShortDecimal()) { + return std::make_shared( + veloxType, variant::null(TypeKind::BIGINT)); + } else if (veloxType->isLongDecimal()) { + return std::make_shared( + veloxType, variant::null(TypeKind::HUGEINT)); + } else { + return std::make_shared( + veloxType, variant::null(veloxType->kind())); + } + } default: VELOX_NYI( "Substrait conversion not supported for type case '{}'", typeCase); @@ -292,7 +472,7 @@ ArrayVectorPtr SubstraitVeloxExprConverter::literalsToArrayVector( 
listLiteral, childSize, VARCHAR(), pool_)); case ::substrait::Expression_Literal::LiteralTypeCase::kNull: { auto veloxType = - toVeloxType(substraitParser_.parseType(listLiteral.null())->type); + toVeloxType(subParser_->parseType(listLiteral.null())->type); auto kind = veloxType->kind(); return makeArrayVector(VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( constructFlatVector, kind, listLiteral, childSize, veloxType, pool_)); @@ -324,19 +504,77 @@ ArrayVectorPtr SubstraitVeloxExprConverter::literalsToArrayVector( } } +RowVectorPtr SubstraitVeloxExprConverter::literalsToRowVector( + const ::substrait::Expression::Literal& structLiteral) { + auto childSize = structLiteral.struct_().fields().size(); + if (childSize == 0) { + return makeEmptyRowVector(pool_); + } + auto typeCase = structLiteral.struct_().fields(0).literal_type_case(); + switch (typeCase) { + case ::substrait::Expression_Literal::LiteralTypeCase::kBinary: { + std::vector vectors; + vectors.reserve(structLiteral.struct_().fields().size()); + for (auto& child : structLiteral.struct_().fields()) { + vectors.emplace_back(constructFlatVectorForStruct( + child, 1, VARBINARY(), pool_)); + } + return makeRowVector(vectors); + } + default: + VELOX_NYI( + "literalsToRowVector not supported for type case '{}'", typeCase); + } +} + core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr( const ::substrait::Expression::Cast& castExpr, const RowTypePtr& inputType) { - auto substraitType = substraitParser_.parseType(castExpr.type()); + auto substraitType = subParser_->parseType(castExpr.type()); auto type = toVeloxType(substraitType->type); bool nullOnFailure = isNullOnFailure(castExpr.failure_behavior()); std::vector inputs{ toVeloxExpr(castExpr.input(), inputType)}; - return std::make_shared(type, inputs, nullOnFailure); } +core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr( + const ::substrait::Expression::IfThen& ifThenExpr, + const RowTypePtr& inputType) { + VELOX_CHECK(ifThenExpr.ifs().size() > 0, "If clause 
expected."); + + // Params are concatenated conditions and results with an optional "else" at + // the end, e.g. {condition1, result1, condition2, result2,..else} + std::vector params; + // If and then expressions are in pairs. + params.reserve(ifThenExpr.ifs().size() * 2); + std::optional outputType = std::nullopt; + for (const auto& ifThen : ifThenExpr.ifs()) { + params.emplace_back(toVeloxExpr(ifThen.if_(), inputType)); + const auto& thenExpr = toVeloxExpr(ifThen.then(), inputType); + // Get output type from the first then expression. + if (!outputType.has_value()) { + outputType = thenExpr->type(); + } + params.emplace_back(thenExpr); + } + + if (ifThenExpr.has_else_()) { + params.reserve(1); + params.emplace_back(toVeloxExpr(ifThenExpr.else_(), inputType)); + } + + VELOX_CHECK(outputType.has_value(), "Output type should be set."); + if (ifThenExpr.ifs().size() == 1) { + // If there is only one if-then clause, use if expression. + return std::make_shared( + outputType.value(), std::move(params), "if"); + } + return std::make_shared( + outputType.value(), std::move(params), "switch"); +} + core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr( const ::substrait::Expression& substraitExpr, const RowTypePtr& inputType) { @@ -353,6 +591,8 @@ core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr( return toVeloxExpr(substraitExpr.cast(), inputType); case ::substrait::Expression::RexTypeCase::kIfThen: return toVeloxExpr(substraitExpr.if_then(), inputType); + case ::substrait::Expression::RexTypeCase::kSingularOrList: + return toVeloxExpr(substraitExpr.singular_or_list(), inputType); default: VELOX_NYI( "Substrait conversion not supported for Expression '{}'", typeCase); diff --git a/velox/substrait/SubstraitToVeloxExpr.h b/velox/substrait/SubstraitToVeloxExpr.h index dd5957e8e785..9e9c5b8420ec 100644 --- a/velox/substrait/SubstraitToVeloxExpr.h +++ b/velox/substrait/SubstraitToVeloxExpr.h @@ -18,7 +18,9 @@ #include "velox/core/Expressions.h" #include 
"velox/substrait/SubstraitParser.h" +#include "velox/type/StringView.h" #include "velox/vector/ComplexVector.h" +#include "velox/vector/FlatVector.h" namespace facebook::velox::substrait { @@ -34,6 +36,12 @@ class SubstraitVeloxExprConverter { const std::unordered_map& functionMap) : pool_(pool), functionMap_(functionMap) {} + /// Stores the variant and its type. + struct TypedVariant { + variant veloxVariant; + TypePtr variantType; + }; + /// Convert Substrait Field into Velox Field Expression. std::shared_ptr toVeloxExpr( const ::substrait::Expression::FieldReference& substraitField, @@ -44,12 +52,22 @@ class SubstraitVeloxExprConverter { const ::substrait::Expression::ScalarFunction& substraitFunc, const RowTypePtr& inputType); + /// Convert Substrait SingularOrList into Velox Expression. + core::TypedExprPtr toVeloxExpr( + const ::substrait::Expression::SingularOrList& singularOrList, + const RowTypePtr& inputType); + /// Convert Substrait CastExpression to Velox Expression. core::TypedExprPtr toVeloxExpr( const ::substrait::Expression::Cast& castExpr, const RowTypePtr& inputType); - /// Convert Substrait Literal into Velox Expression. + /// Create expression for extract. + std::shared_ptr toExtractExpr( + const std::vector>& params, + const TypePtr& outputType); + + /// Used to convert Substrait Literal into Velox Expression. std::shared_ptr toVeloxExpr( const ::substrait::Expression::Literal& substraitLit); @@ -63,21 +81,45 @@ class SubstraitVeloxExprConverter { const ::substrait::Expression::IfThen& substraitIfThen, const RowTypePtr& inputType); + /// Wrap a constant vector from literals with an array vector inside to create + /// the constant expression. + std::shared_ptr literalsToConstantExpr( + const std::vector<::substrait::Expression::Literal>& literals); + private: /// Convert list literal to ArrayVector. 
ArrayVectorPtr literalsToArrayVector( const ::substrait::Expression::Literal& listLiteral); + RowVectorPtr literalsToRowVector( + const ::substrait::Expression::Literal& structLiteral); + /// Memory pool. memory::MemoryPool* pool_; /// The Substrait parser used to convert Substrait representations into /// recognizable representations. - SubstraitParser substraitParser_; + std::shared_ptr subParser_ = + std::make_shared(); /// The map storing the relations between the function id and the function /// name. std::unordered_map functionMap_; + + // The map storing the Substrait extract function input field and velox + // function name. + std::unordered_map extractDatetimeFunctionMap_ = { + {"MILLISECOND", "millisecond"}, + {"SECOND", "second"}, + {"MINUTE", "minute"}, + {"HOUR", "hour"}, + {"DAY", "day"}, + {"DAY_OF_WEEK", "day_of_week"}, + {"DAY_OF_YEAR", "day_of_year"}, + {"MONTH", "month"}, + {"QUARTER", "quarter"}, + {"YEAR", "year"}, + {"YEAR_OF_WEEK", "year_of_week"}}; }; } // namespace facebook::velox::substrait diff --git a/velox/substrait/SubstraitToVeloxPlan.cpp b/velox/substrait/SubstraitToVeloxPlan.cpp index fc605420f737..71214a7d42f3 100644 --- a/velox/substrait/SubstraitToVeloxPlan.cpp +++ b/velox/substrait/SubstraitToVeloxPlan.cpp @@ -21,29 +21,6 @@ namespace facebook::velox::substrait { namespace { -core::AggregationNode::Step toAggregationStep( - const ::substrait::AggregateRel& sAgg) { - if (sAgg.measures().size() == 0) { - // When only groupings exist, set the phase to be Single. - return core::AggregationNode::Step::kSingle; - } - - // Use the first measure to set aggregation phase. 
- const auto& firstMeasure = sAgg.measures()[0]; - const auto& aggFunction = firstMeasure.measure(); - switch (aggFunction.phase()) { - case ::substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE: - return core::AggregationNode::Step::kPartial; - case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE: - return core::AggregationNode::Step::kIntermediate; - case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT: - return core::AggregationNode::Step::kFinal; - case ::substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT: - return core::AggregationNode::Step::kSingle; - default: - VELOX_FAIL("Aggregate phase is not supported."); - } -} core::SortOrder toSortOrder(const ::substrait::SortField& sortField) { switch (sortField.direction()) { @@ -88,6 +65,123 @@ EmitInfo getEmitInfo( return emitInfo; } +template +// Get the lowest value for numeric type. +T getLowest() { + return std::numeric_limits::lowest(); +}; + +// Get the lowest value for string. +template <> +std::string getLowest() { + return ""; +}; + +// Get the max value for numeric type. +template +T getMax() { + return std::numeric_limits::max(); +}; + +// The max value will be used in BytesRange. Return empty string here instead. +template <> +std::string getMax() { + return ""; +}; + +// Substrait function names. +const std::string sIsNotNull = "is_not_null"; +const std::string sGte = "gte"; +const std::string sGt = "gt"; +const std::string sLte = "lte"; +const std::string sLt = "lt"; +const std::string sEqual = "equal"; +const std::string sOr = "or"; +const std::string sNot = "not"; + +// Substrait types. +const std::string sI32 = "i32"; +const std::string sI64 = "i64"; + +/// @brief Get the input type from both sides of join. +/// @param leftNode the plan node of left side. +/// @param rightNode the plan node of right side. +/// @return the input type. 
+RowTypePtr getJoinInputType( + const core::PlanNodePtr& leftNode, + const core::PlanNodePtr& rightNode) { + auto outputSize = + leftNode->outputType()->size() + rightNode->outputType()->size(); + std::vector outputNames; + std::vector> outputTypes; + outputNames.reserve(outputSize); + outputTypes.reserve(outputSize); + for (const auto& node : {leftNode, rightNode}) { + const auto& names = node->outputType()->names(); + outputNames.insert(outputNames.end(), names.begin(), names.end()); + const auto& types = node->outputType()->children(); + outputTypes.insert(outputTypes.end(), types.begin(), types.end()); + } + return std::make_shared( + std::move(outputNames), std::move(outputTypes)); +} + +/// @brief Get the direct output type of join. +/// @param leftNode the plan node of left side. +/// @param rightNode the plan node of right side. +/// @param joinType the join type. +/// @return the output type. +RowTypePtr getJoinOutputType( + const core::PlanNodePtr& leftNode, + const core::PlanNodePtr& rightNode, + const core::JoinType& joinType) { + // Decide output type. + // Output of right semi join cannot include columns from the left side. + bool outputMayIncludeLeftColumns = + !(core::isRightSemiFilterJoin(joinType) || + core::isRightSemiProjectJoin(joinType)); + + // Output of left semi and anti joins cannot include columns from the right + // side. 
+ bool outputMayIncludeRightColumns = + !(core::isLeftSemiFilterJoin(joinType) || + core::isLeftSemiProjectJoin(joinType) || core::isAntiJoin(joinType)); + + if (outputMayIncludeLeftColumns && outputMayIncludeRightColumns) { + return getJoinInputType(leftNode, rightNode); + } + + if (outputMayIncludeLeftColumns) { + if (core::isLeftSemiProjectJoin(joinType)) { + auto outputSize = leftNode->outputType()->size() + 1; + std::vector outputNames = leftNode->outputType()->names(); + std::vector> outputTypes = + leftNode->outputType()->children(); + outputNames.emplace_back("exists"); + outputTypes.emplace_back(BOOLEAN()); + return std::make_shared( + std::move(outputNames), std::move(outputTypes)); + } else { + return leftNode->outputType(); + } + } + + if (outputMayIncludeRightColumns) { + if (core::isRightSemiProjectJoin(joinType)) { + auto outputSize = rightNode->outputType()->size() + 1; + std::vector outputNames = rightNode->outputType()->names(); + std::vector> outputTypes = + rightNode->outputType()->children(); + outputNames.emplace_back("exists"); + outputTypes.emplace_back(BOOLEAN()); + return std::make_shared( + std::move(outputNames), std::move(outputTypes)); + } else { + return rightNode->outputType(); + } + } + VELOX_FAIL("Output should include left or right columns."); +} } // namespace core::PlanNodePtr SubstraitVeloxPlanConverter::processEmit( @@ -109,6 +203,146 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::processEmit( } } +core::AggregationNode::Step SubstraitVeloxPlanConverter::toAggregationStep( + const ::substrait::AggregateRel& aggRel) { + if (aggRel.measures().size() == 0) { + // When only groupings exist, set the phase to be Single. + return core::AggregationNode::Step::kSingle; + } + + // Use the first measure to set aggregation phase. 
+ const auto& firstMeasure = aggRel.measures()[0]; + const auto& aggFunction = firstMeasure.measure(); + switch (aggFunction.phase()) { + case ::substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE: + return core::AggregationNode::Step::kPartial; + case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE: + return core::AggregationNode::Step::kIntermediate; + case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT: + return core::AggregationNode::Step::kFinal; + case ::substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT: + return core::AggregationNode::Step::kSingle; + default: + VELOX_FAIL("Aggregate phase is not supported."); + } +} + +core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( + const ::substrait::JoinRel& sJoin) { + if (!sJoin.has_left()) { + VELOX_FAIL("Left Rel is expected in JoinRel."); + } + if (!sJoin.has_right()) { + VELOX_FAIL("Right Rel is expected in JoinRel."); + } + + auto leftNode = toVeloxPlan(sJoin.left()); + auto rightNode = toVeloxPlan(sJoin.right()); + + // Map join type. + core::JoinType joinType; + bool isNullAwareAntiJoin = false; + switch (sJoin.type()) { + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_INNER: + joinType = core::JoinType::kInner; + break; + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_OUTER: + joinType = core::JoinType::kFull; + break; + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT: + joinType = core::JoinType::kLeft; + break; + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT: + joinType = core::JoinType::kRight; + break; + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: + // Determine the semi join type based on extracted information. 
+ if (sJoin.has_advanced_extension() && + subParser_->configSetInOptimization( + sJoin.advanced_extension(), "isExistenceJoin=")) { + joinType = core::JoinType::kLeftSemiProject; + } else { + joinType = core::JoinType::kLeftSemiFilter; + } + break; + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI: + // Determine the semi join type based on extracted information. + if (sJoin.has_advanced_extension() && + subParser_->configSetInOptimization( + sJoin.advanced_extension(), "isExistenceJoin=")) { + joinType = core::JoinType::kRightSemiProject; + } else { + joinType = core::JoinType::kRightSemiFilter; + } + break; + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_ANTI: { + // Determine the anti join type based on extracted information. + if (sJoin.has_advanced_extension() && + subParser_->configSetInOptimization( + sJoin.advanced_extension(), "isNullAwareAntiJoin=")) { + isNullAwareAntiJoin = true; + } + joinType = core::JoinType::kAnti; + break; + } + default: + VELOX_NYI("Unsupported Join type: {}", sJoin.type()); + } + + // extract join keys from join expression + std::vector leftExprs, + rightExprs; + extractJoinKeys(sJoin.expression(), leftExprs, rightExprs); + VELOX_CHECK_EQ(leftExprs.size(), rightExprs.size()); + size_t numKeys = leftExprs.size(); + + std::vector> leftKeys, + rightKeys; + leftKeys.reserve(numKeys); + rightKeys.reserve(numKeys); + auto inputRowType = getJoinInputType(leftNode, rightNode); + for (size_t i = 0; i < numKeys; ++i) { + leftKeys.emplace_back( + exprConverter_->toVeloxExpr(*leftExprs[i], inputRowType)); + rightKeys.emplace_back( + exprConverter_->toVeloxExpr(*rightExprs[i], inputRowType)); + } + + std::shared_ptr filter; + if (sJoin.has_post_join_filter()) { + filter = + exprConverter_->toVeloxExpr(sJoin.post_join_filter(), inputRowType); + } + + if (sJoin.has_advanced_extension() && + subParser_->configSetInOptimization( + sJoin.advanced_extension(), "isSMJ=")) { + // Create MergeJoinNode node + 
return std::make_shared( + nextPlanNodeId(), + joinType, + leftKeys, + rightKeys, + filter, + leftNode, + rightNode, + getJoinOutputType(leftNode, rightNode, joinType)); + + } else { + // Create HashJoinNode node + return std::make_shared( + nextPlanNodeId(), + joinType, + isNullAwareAntiJoin, + leftKeys, + rightKeys, + filter, + leftNode, + rightNode, + getJoinOutputType(leftNode, rightNode, joinType)); + } +} + core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( const ::substrait::AggregateRel& aggRel) { auto childNode = convertSingleInput<::substrait::AggregateRel>(aggRel); @@ -131,7 +365,7 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( aggregates.reserve(aggRel.measures().size()); for (const auto& measure : aggRel.measures()) { - core::FieldAccessTypedExprPtr mask; + core::FieldAccessTypedExprPtr mask = {}; ::substrait::Expression substraitAggMask = measure.filter(); // Get Aggregation Masks. if (measure.has_filter()) { @@ -140,9 +374,9 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( exprConverter_->toVeloxExpr(substraitAggMask, inputType)); } } - + aggregateMasks.push_back(mask); const auto& aggFunction = measure.measure(); - auto funcName = substraitParser_->findVeloxFunction( + auto funcName = subParser_->findVeloxFunction( functionMap_, aggFunction.function_reference()); std::vector aggParams; aggParams.reserve(aggFunction.arguments().size()); @@ -150,8 +384,8 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( aggParams.emplace_back( exprConverter_->toVeloxExpr(arg.value(), inputType)); } - auto aggVeloxType = toVeloxType( - substraitParser_->parseType(aggFunction.output_type())->type); + auto aggVeloxType = + toVeloxType(subParser_->parseType(aggFunction.output_type())->type); auto aggExpr = std::make_shared( aggVeloxType, std::move(aggParams), funcName); aggregates.emplace_back( @@ -167,7 +401,7 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( for (int idx = veloxGroupingExprs.size(); idx < 
veloxGroupingExprs.size() + aggRel.measures().size(); idx++) { - aggOutNames.emplace_back(substraitParser_->makeNodeName(planNodeId_, idx)); + aggOutNames.emplace_back(subParser_->makeNodeName(planNodeId_, idx)); } auto aggregationNode = std::make_shared( @@ -190,9 +424,8 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( const ::substrait::ProjectRel& projectRel) { auto childNode = convertSingleInput<::substrait::ProjectRel>(projectRel); - // Construct Velox Expressions. - auto projectExprs = projectRel.expressions(); + const auto& projectExprs = projectRel.expressions(); std::vector projectNames; std::vector expressions; projectNames.reserve(projectExprs.size()); @@ -217,8 +450,7 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( // Then, adding project expression related project names and expressions. for (const auto& expr : projectExprs) { expressions.emplace_back(exprConverter_->toVeloxExpr(expr, inputType)); - projectNames.emplace_back( - substraitParser_->makeNodeName(planNodeId_, colIdx)); + projectNames.emplace_back(subParser_->makeNodeName(planNodeId_, colIdx)); colIdx += 1; } @@ -247,6 +479,219 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( } } +core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( + const ::substrait::ExpandRel& expandRel) { + core::PlanNodePtr childNode; + if (expandRel.has_input()) { + childNode = toVeloxPlan(expandRel.input()); + } else { + VELOX_FAIL("Child Rel is expected in ExpandRel."); + } + + const auto& inputType = childNode->outputType(); + + std::vector> projectSetExprs; + projectSetExprs.reserve(expandRel.fields_size()); + + for (const auto& projections : expandRel.fields()) { + std::vector projectExprs; + projectExprs.reserve(projections.switching_field().duplicates_size()); + + for (const auto& projectExpr : projections.switching_field().duplicates()) { + if (projectExpr.has_selection()) { + auto expression = + 
exprConverter_->toVeloxExpr(projectExpr.selection(), inputType); + projectExprs.emplace_back(expression); + } else if (projectExpr.has_literal()) { + auto expression = exprConverter_->toVeloxExpr(projectExpr.literal()); + projectExprs.emplace_back(expression); + } else { + VELOX_FAIL( + "The project in Expand Operator only support field or literal."); + } + } + projectSetExprs.emplace_back(projectExprs); + } + + auto projectSize = expandRel.fields()[0].switching_field().duplicates_size(); + std::vector names; + names.reserve(projectSize); + for (int idx = 0; idx < projectSize; idx++) { + names.push_back(subParser_->makeNodeName(planNodeId_, idx)); + } + + return std::make_shared( + nextPlanNodeId(), projectSetExprs, std::move(names), childNode); +} + +const core::WindowNode::Frame createWindowFrame( + const ::substrait::Expression_WindowFunction_Bound& lower_bound, + const ::substrait::Expression_WindowFunction_Bound& upper_bound, + const ::substrait::WindowType& type) { + core::WindowNode::Frame frame; + switch (type) { + case ::substrait::WindowType::ROWS: + frame.type = core::WindowNode::WindowType::kRows; + break; + case ::substrait::WindowType::RANGE: + frame.type = core::WindowNode::WindowType::kRange; + break; + default: + VELOX_FAIL( + "the window type only support ROWS and RANGE, and the input type is ", + type); + } + + auto boundTypeConversion = + [](::substrait::Expression_WindowFunction_Bound boundType) + -> core::WindowNode::BoundType { + if (boundType.has_current_row()) { + return core::WindowNode::BoundType::kCurrentRow; + } else if (boundType.has_unbounded_following()) { + return core::WindowNode::BoundType::kUnboundedFollowing; + } else if (boundType.has_unbounded_preceding()) { + return core::WindowNode::BoundType::kUnboundedPreceding; + } else if (boundType.has_following()) { + return core::WindowNode::BoundType::kFollowing; + } else if (boundType.has_preceding()) { + return core::WindowNode::BoundType::kPreceding; + } else { + VELOX_FAIL("The 
BoundType is not supported."); + } + }; + frame.startType = boundTypeConversion(lower_bound); + switch (frame.startType) { + case core::WindowNode::BoundType::kPreceding: + // TODO: support non-literal expression. + frame.startValue = std::make_shared( + BIGINT(), variant(lower_bound.preceding().offset())); + break; + default: + frame.startValue = nullptr; + } + frame.endType = boundTypeConversion(upper_bound); + switch (frame.endType) { + // TODO: support non-literal expression. + case core::WindowNode::BoundType::kFollowing: + frame.endValue = std::make_shared( + BIGINT(), variant(upper_bound.following().offset())); + break; + default: + frame.endValue = nullptr; + } + return frame; +} + +core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( + const ::substrait::WindowRel& windowRel) { + core::PlanNodePtr childNode; + if (windowRel.has_input()) { + childNode = toVeloxPlan(windowRel.input()); + } else { + VELOX_FAIL("Child Rel is expected in WindowRel."); + } + + const auto& inputType = childNode->outputType(); + + // Parse measures and get the window expressions. + // Each measure represents one window expression. 
+ bool ignoreNullKeys = false; + std::vector windowNodeFunctions; + std::vector windowColumnNames; + + windowNodeFunctions.reserve(windowRel.measures().size()); + for (const auto& smea : windowRel.measures()) { + const auto& windowFunction = smea.measure(); + std::string funcName = subParser_->findVeloxFunction( + functionMap_, windowFunction.function_reference()); + std::vector> windowParams; + windowParams.reserve(windowFunction.arguments().size()); + for (const auto& arg : windowFunction.arguments()) { + windowParams.emplace_back( + exprConverter_->toVeloxExpr(arg.value(), inputType)); + } + auto windowVeloxType = + toVeloxType(subParser_->parseType(windowFunction.output_type())->type); + auto windowCall = std::make_shared( + windowVeloxType, std::move(windowParams), funcName); + auto upperBound = windowFunction.upper_bound(); + auto lowerBound = windowFunction.lower_bound(); + auto type = windowFunction.window_type(); + + windowColumnNames.push_back(windowFunction.column_name()); + + windowNodeFunctions.push_back( + {std::move(windowCall), + createWindowFrame(lowerBound, upperBound, type), + ignoreNullKeys}); + } + + // Construct partitionKeys + std::vector partitionKeys; + const auto& partitions = windowRel.partition_expressions(); + partitionKeys.reserve(partitions.size()); + for (const auto& partition : partitions) { + auto expression = exprConverter_->toVeloxExpr(partition, inputType); + auto expr_field = + dynamic_cast(expression.get()); + VELOX_CHECK( + expr_field != nullptr, + " the partition key in Window Operator only support field") + + partitionKeys.emplace_back( + std::dynamic_pointer_cast( + expression)); + } + + std::vector sortingKeys; + std::vector sortingOrders; + + const auto& sorts = windowRel.sorts(); + sortingKeys.reserve(sorts.size()); + sortingOrders.reserve(sorts.size()); + + for (const auto& sort : sorts) { + switch (sort.direction()) { + case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST: + 
sortingOrders.emplace_back(core::kAscNullsFirst); + break; + case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST: + sortingOrders.emplace_back(core::kAscNullsLast); + break; + case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST: + sortingOrders.emplace_back(core::kDescNullsFirst); + break; + case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST: + sortingOrders.emplace_back(core::kDescNullsLast); + break; + default: + VELOX_FAIL("Sort direction is not support in WindowRel"); + } + + if (sort.has_expr()) { + auto expression = exprConverter_->toVeloxExpr(sort.expr(), inputType); + auto expr_field = + dynamic_cast(expression.get()); + VELOX_CHECK( + expr_field != nullptr, + " the sorting key in Window Operator only support field") + + sortingKeys.emplace_back( + std::dynamic_pointer_cast( + expression)); + } + } + + return std::make_shared( + nextPlanNodeId(), + partitionKeys, + sortingKeys, + sortingOrders, + windowColumnNames, + windowNodeFunctions, + childNode); +} + core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( const ::substrait::SortRel& sortRel) { auto childNode = convertSingleInput<::substrait::SortRel>(sortRel); @@ -293,7 +738,6 @@ SubstraitVeloxPlanConverter::processSortField( core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( const ::substrait::FilterRel& filterRel) { auto childNode = convertSingleInput<::substrait::FilterRel>(filterRel); - auto filterNode = std::make_shared( nextPlanNodeId(), exprConverter_->toVeloxExpr( @@ -307,6 +751,26 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( } } +bool isPushDownSupportedByFormat( + const dwio::common::FileFormat& format, + connector::hive::SubfieldFilters& subfieldFilters) { + switch (format) { + case dwio::common::FileFormat::PARQUET: + case dwio::common::FileFormat::ORC: + case dwio::common::FileFormat::DWRF: + case dwio::common::FileFormat::RC: + case dwio::common::FileFormat::RC_TEXT: + case 
dwio::common::FileFormat::RC_BINARY: + case dwio::common::FileFormat::TEXT: + case dwio::common::FileFormat::JSON: + case dwio::common::FileFormat::ALPHA: + case dwio::common::FileFormat::UNKNOWN: + default: + break; + } + return true; +} + core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( const ::substrait::FetchRel& fetchRel) { core::PlanNodePtr childNode; @@ -351,8 +815,7 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( } core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( - const ::substrait::ReadRel& readRel, - std::shared_ptr& splitInfo) { + const ::substrait::ReadRel& readRel) { // emit is not allowed in TableScanNode and ValuesNode related // outputs if (readRel.has_common()) { @@ -360,23 +823,42 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( !readRel.common().has_emit(), "Emit not supported for ValuesNode and TableScanNode related Substrait plans."); } + + // Check if the ReadRel specifies an input of stream. If yes, the pre-built + // input node will be used as the data source. + auto splitInfo = std::make_shared(); + auto streamIdx = streamIsInput(readRel); + if (streamIdx >= 0) { + if (inputNodesMap_.find(streamIdx) == inputNodesMap_.end()) { + VELOX_FAIL( + "Could not find source index {} in input nodes map.", streamIdx); + } + auto streamNode = inputNodesMap_[streamIdx]; + splitInfo->isStream = true; + splitInfoMap_[streamNode->id()] = splitInfo; + return streamNode; + } + + // Otherwise, will create TableScan node for ReadRel. // Get output names and types. 
std::vector colNameList; std::vector veloxTypeList; + std::vector isPartitionColumns; if (readRel.has_base_schema()) { const auto& baseSchema = readRel.base_schema(); colNameList.reserve(baseSchema.names().size()); for (const auto& name : baseSchema.names()) { colNameList.emplace_back(name); } - auto substraitTypeList = substraitParser_->parseNamedStruct(baseSchema); + auto substraitTypeList = subParser_->parseNamedStruct(baseSchema); + isPartitionColumns = subParser_->parsePartitionColumns(baseSchema); veloxTypeList.reserve(substraitTypeList.size()); for (const auto& substraitType : substraitTypeList) { veloxTypeList.emplace_back(toVeloxType(substraitType->type)); } } - // Parse local files + // Parse local files and construct split info. if (readRel.has_local_files()) { using SubstraitFileFormatCase = ::substrait::ReadRel_LocalFiles_FileOrFiles::FileFormatCase; @@ -385,13 +867,16 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( splitInfo->starts.reserve(fileList.size()); splitInfo->lengths.reserve(fileList.size()); for (const auto& file : fileList) { - // Expect all files to share the same index. + // Expect all Partitions share the same index. splitInfo->partitionIndex = file.partition_index(); splitInfo->paths.emplace_back(file.uri_file()); splitInfo->starts.emplace_back(file.start()); splitInfo->lengths.emplace_back(file.length()); switch (file.file_format_case()) { case SubstraitFileFormatCase::kOrc: + splitInfo->format = dwio::common::FileFormat::ORC; + break; + case SubstraitFileFormatCase::kDwrf: splitInfo->format = dwio::common::FileFormat::DWRF; break; case SubstraitFileFormatCase::kParquet: @@ -402,7 +887,6 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( } } } - // Do not hard-code connector ID and allow for connectors other than Hive. 
static const std::string kHiveConnectorId = "test-hive"; @@ -417,14 +901,66 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( connector::hive::SubfieldFilters{}, nullptr); } else { - connector::hive::SubfieldFilters filters = - toVeloxFilter(colNameList, veloxTypeList, readRel.filter()); + // Flatten the conditions connected with 'and'. + std::vector<::substrait::Expression_ScalarFunction> scalarFunctions; + std::vector<::substrait::Expression_SingularOrList> singularOrLists; + std::vector<::substrait::Expression_IfThen> ifThens; + flattenConditions( + readRel.filter(), scalarFunctions, singularOrLists, ifThens); + + std::unordered_map> rangeRecorders; + for (uint32_t idx = 0; idx < veloxTypeList.size(); idx++) { + rangeRecorders[idx] = std::make_shared(); + } + + // Separate the filters to be two parts. The subfield part can be + // pushed down. + std::vector<::substrait::Expression_ScalarFunction> subfieldFunctions; + std::vector<::substrait::Expression_SingularOrList> subfieldrOrLists; + + std::vector<::substrait::Expression_ScalarFunction> remainingFunctions; + std::vector<::substrait::Expression_SingularOrList> remainingrOrLists; + + separateFilters( + rangeRecorders, + scalarFunctions, + subfieldFunctions, + remainingFunctions, + singularOrLists, + subfieldrOrLists, + remainingrOrLists); + + // Create subfield filters based on the constructed filter info map. + connector::hive::SubfieldFilters subfieldFilters = toSubfieldFilters( + colNameList, veloxTypeList, subfieldFunctions, subfieldrOrLists); + // Connect the remaining filters with 'and'. + std::shared_ptr remainingFilter; + + if (!isPushDownSupportedByFormat(splitInfo->format, subfieldFilters)) { + // A subfieldFilter is not supported by the format, + // mark all filter as remaining filters. 
+ subfieldFilters.clear(); + remainingFilter = connectWithAnd( + colNameList, + veloxTypeList, + scalarFunctions, + singularOrLists, + ifThens); + } else { + remainingFilter = connectWithAnd( + colNameList, + veloxTypeList, + remainingFunctions, + remainingrOrLists, + ifThens); + } + tableHandle = std::make_shared( kHiveConnectorId, "hive_table", filterPushdownEnabled, - std::move(filters), - nullptr); + std::move(subfieldFilters), + remainingFilter); } // Get assignments and out names. @@ -433,11 +969,12 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( std::unordered_map> assignments; for (int idx = 0; idx < colNameList.size(); idx++) { - auto outName = substraitParser_->makeNodeName(planNodeId_, idx); + auto outName = subParser_->makeNodeName(planNodeId_, idx); + auto columnType = isPartitionColumns[idx] + ? connector::hive::HiveColumnHandle::ColumnType::kPartitionKey + : connector::hive::HiveColumnHandle::ColumnType::kRegular; assignments[outName] = std::make_shared( - colNameList[idx], - connector::hive::HiveColumnHandle::ColumnType::kRegular, - veloxTypeList[idx]); + colNameList[idx], columnType, veloxTypeList[idx]); outNames.emplace_back(outName); } auto outputType = ROW(std::move(outNames), std::move(veloxTypeList)); @@ -445,11 +982,14 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( if (readRel.has_virtual_table()) { return toVeloxPlan(readRel, outputType); } else { - return std::make_shared( + auto tableScanNode = std::make_shared( nextPlanNodeId(), std::move(outputType), std::move(tableHandle), std::move(assignments)); + // Set split info map. 
+ splitInfoMap_[tableScanNode->id()] = splitInfo; + return tableScanNode; } } @@ -525,12 +1065,14 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( if (rel.has_filter()) { return toVeloxPlan(rel.filter()); } + if (rel.has_join()) { + return toVeloxPlan(rel.join()); + } if (rel.has_read()) { - auto splitInfo = std::make_shared(); - - auto planNode = toVeloxPlan(rel.read(), splitInfo); - splitInfoMap_[planNode->id()] = splitInfo; - return planNode; + return toVeloxPlan(rel.read()); + } + if (rel.has_sort()) { + return toVeloxPlan(rel.sort()); } if (rel.has_fetch()) { return toVeloxPlan(rel.fetch()); @@ -538,6 +1080,15 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( if (rel.has_sort()) { return toVeloxPlan(rel.sort()); } + if (rel.has_expand()) { + return toVeloxPlan(rel.expand()); + } + if (rel.has_fetch()) { + return toVeloxPlan(rel.fetch()); + } + if (rel.has_window()) { + return toVeloxPlan(rel.window()); + } VELOX_NYI("Substrait conversion not supported for Rel."); } @@ -557,13 +1108,10 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( VELOX_CHECK( checkTypeExtension(substraitPlan), "The type extension only have unknown type.") - // Construct the function map based on the Substrait representation. + // Construct the function map based on the Substrait representation, + // and initialize the expression converter with it. constructFunctionMap(substraitPlan); - // Construct the expression converter. - exprConverter_ = - std::make_shared(pool_, functionMap_); - // In fact, only one RelRoot or Rel is expected here. VELOX_CHECK_EQ(substraitPlan.relations_size(), 1); const auto& rel = substraitPlan.relations(0); @@ -582,188 +1130,940 @@ std::string SubstraitVeloxPlanConverter::nextPlanNodeId() { planNodeId_++; return id; } - -// This class contains the needed infos for Filter Pushdown. -// TODO: Support different types here. -class FilterInfo { - public: - // Used to set the left bound. 
- void setLeft(double left, bool isExclusive) { - left_ = left; - leftExclusive_ = isExclusive; - if (!isInitialized_) { - isInitialized_ = true; +void SubstraitVeloxPlanConverter::constructFunctionMap( + const ::substrait::Plan& substraitPlan) { + // Construct the function map based on the Substrait representation. + for (const auto& sExtension : substraitPlan.extensions()) { + if (!sExtension.has_extension_function()) { + continue; } + const auto& sFmap = sExtension.extension_function(); + auto id = sFmap.function_anchor(); + auto name = sFmap.name(); + functionMap_[id] = name; } + exprConverter_ = + std::make_shared(pool_, functionMap_); +} - // Used to set the right bound. - void setRight(double right, bool isExclusive) { - right_ = right; - rightExclusive_ = isExclusive; - if (!isInitialized_) { - isInitialized_ = true; +void SubstraitVeloxPlanConverter::flattenConditions( + const ::substrait::Expression& substraitFilter, + std::vector<::substrait::Expression_ScalarFunction>& scalarFunctions, + std::vector<::substrait::Expression_SingularOrList>& singularOrLists, + std::vector<::substrait::Expression_IfThen>& ifThens) { + auto typeCase = substraitFilter.rex_type_case(); + switch (typeCase) { + case ::substrait::Expression::RexTypeCase::kScalarFunction: { + auto sFunc = substraitFilter.scalar_function(); + auto filterNameSpec = subParser_->findFunctionSpec( + functionMap_, sFunc.function_reference()); + // TODO: Only and relation is supported here. 
+ if (subParser_->getSubFunctionName(filterNameSpec) == "and") { + for (const auto& sCondition : sFunc.arguments()) { + flattenConditions( + sCondition.value(), scalarFunctions, singularOrLists, ifThens); + } + } else { + scalarFunctions.emplace_back(sFunc); + } + break; + } + case ::substrait::Expression::RexTypeCase::kSingularOrList: { + singularOrLists.emplace_back(substraitFilter.singular_or_list()); + break; } + case ::substrait::Expression::RexTypeCase::kIfThen: { + ifThens.emplace_back(substraitFilter.if_then()); + break; + } + default: + VELOX_NYI("GetFlatConditions not supported for type '{}'", typeCase); } +} + +std::string SubstraitVeloxPlanConverter::findFuncSpec(uint64_t id) { + return subParser_->findFunctionSpec(functionMap_, id); +} - // Will fordis Null value if called once. - void forbidsNull() { - nullAllowed_ = false; - if (!isInitialized_) { - isInitialized_ = true; +int32_t SubstraitVeloxPlanConverter::streamIsInput( + const ::substrait::ReadRel& sRead) { + if (sRead.has_local_files()) { + const auto& fileList = sRead.local_files().items(); + if (fileList.size() == 0) { + VELOX_FAIL("At least one file path is expected."); } - } - // Return the initialization status. - bool isInitialized() { - return isInitialized_ ? true : false; - } + // The stream input will be specified with the format of + // "iterator:${index}". + std::string filePath = fileList[0].uri_file(); + std::string prefix = "iterator:"; + std::size_t pos = filePath.find(prefix); + if (pos == std::string::npos) { + return -1; + } - // The left bound. - std::optional left_ = std::nullopt; - // The right bound. - std::optional right_ = std::nullopt; - // The Null allowing. - bool nullAllowed_ = true; - // If true, left bound will be exclusive. - bool leftExclusive_ = false; - // If true, right bound will be exclusive. - bool rightExclusive_ = false; + // Get the index. 
+ std::string idxStr = filePath.substr(pos + prefix.size(), filePath.size()); + try { + return stoi(idxStr); + } catch (const std::exception& err) { + VELOX_FAIL(err.what()); + } + } + if (validationMode_) { + return -1; + } + VELOX_FAIL("Local file is expected."); +} - private: - bool isInitialized_ = false; -}; +void SubstraitVeloxPlanConverter::extractJoinKeys( + const ::substrait::Expression& joinExpression, + std::vector& leftExprs, + std::vector& rightExprs) { + std::vector expressions; + expressions.push_back(&joinExpression); + while (!expressions.empty()) { + auto visited = expressions.back(); + expressions.pop_back(); + if (visited->rex_type_case() == + ::substrait::Expression::RexTypeCase::kScalarFunction) { + const auto& funcName = + subParser_->getSubFunctionName(subParser_->findVeloxFunction( + functionMap_, visited->scalar_function().function_reference())); + const auto& args = visited->scalar_function().arguments(); + if (funcName == "and") { + expressions.push_back(&args[0].value()); + expressions.push_back(&args[1].value()); + } else if (funcName == "eq" || funcName == "equalto") { + VELOX_CHECK(std::all_of( + args.cbegin(), + args.cend(), + [](const ::substrait::FunctionArgument& arg) { + return arg.value().has_selection(); + })); + leftExprs.push_back(&args[0].value().selection()); + rightExprs.push_back(&args[1].value().selection()); + } else { + VELOX_NYI("Join condition {} not supported.", funcName); + } + } else { + VELOX_FAIL( + "Unable to parse from join expression: {}", + joinExpression.DebugString()); + } + } +} -connector::hive::SubfieldFilters SubstraitVeloxPlanConverter::toVeloxFilter( +connector::hive::SubfieldFilters SubstraitVeloxPlanConverter::toSubfieldFilters( const std::vector& inputNameList, const std::vector& inputTypeList, - const ::substrait::Expression& substraitFilter) { - connector::hive::SubfieldFilters filters; - // A map between the column index and the FilterInfo for that column. 
- std::unordered_map> colInfoMap; - for (int idx = 0; idx < inputNameList.size(); idx++) { + const std::vector<::substrait::Expression_ScalarFunction>& scalarFunctions, + const std::vector<::substrait::Expression_SingularOrList>& + singularOrLists) { + std::unordered_map> colInfoMap; + // A map between the column index and the FilterInfo. + for (uint32_t idx = 0; idx < inputTypeList.size(); idx++) { colInfoMap[idx] = std::make_shared(); } - std::vector<::substrait::Expression_ScalarFunction> scalarFunctions; - flattenConditions(substraitFilter, scalarFunctions); // Construct the FilterInfo for the related column. for (const auto& scalarFunction : scalarFunctions) { - auto filterNameSpec = substraitParser_->findFunctionSpec( + auto filterNameSpec = subParser_->findFunctionSpec( functionMap_, scalarFunction.function_reference()); - auto filterName = getNameBeforeDelimiter(filterNameSpec, ":"); - int32_t colIdx; - // TODO: Add different types' support here. - double val; - for (auto& arg : scalarFunction.arguments()) { - auto argExpr = arg.value(); - auto typeCase = argExpr.rex_type_case(); - switch (typeCase) { - case ::substrait::Expression::RexTypeCase::kSelection: { - auto sel = argExpr.selection(); - // TODO: Only direct reference is considered here. - auto dRef = sel.direct_reference(); - colIdx = substraitParser_->parseReferenceSegment(dRef); - break; - } - case ::substrait::Expression::RexTypeCase::kLiteral: { - auto sLit = argExpr.literal(); - // TODO: Only double is considered here. - val = sLit.fp64(); - break; + auto filterName = subParser_->getSubFunctionName(filterNameSpec); + if (filterName == sNot) { + VELOX_CHECK(scalarFunction.arguments().size() == 1); + auto expr = scalarFunction.arguments()[0].value(); + if (expr.has_scalar_function()) { + // Set its chid to filter info with reverse enabled. 
+ setFilterMap( + scalarFunction.arguments()[0].value().scalar_function(), + inputTypeList, + colInfoMap, + true); + } else { + // TODO: support push down of Not In. + VELOX_NYI("Scalar function expected."); + } + continue; + } + + if (filterName == sOr) { + VELOX_CHECK(scalarFunction.arguments().size() == 2); + VELOX_CHECK(std::all_of( + scalarFunction.arguments().cbegin(), + scalarFunction.arguments().cend(), + [](const ::substrait::FunctionArgument& arg) { + return arg.value().has_scalar_function() || + arg.value().has_singular_or_list(); + })); + // Set the chidren functions to filter info. They should be + // effective to the same field. + for (const auto& arg : scalarFunction.arguments()) { + auto expr = arg.value(); + if (expr.has_scalar_function()) { + setFilterMap( + arg.value().scalar_function(), inputTypeList, colInfoMap); + } else if (expr.has_singular_or_list()) { + setSingularListValues(expr.singular_or_list(), colInfoMap); + } else { + VELOX_NYI("Scalar function or SingularOrList expected."); } - default: - VELOX_NYI( - "Substrait conversion not supported for arg type '{}'", typeCase); } + continue; } - if (filterName == "is_not_null") { - colInfoMap[colIdx]->forbidsNull(); - } else if (filterName == "gte") { - colInfoMap[colIdx]->setLeft(val, false); - } else if (filterName == "gt") { - colInfoMap[colIdx]->setLeft(val, true); - } else if (filterName == "lte") { - colInfoMap[colIdx]->setRight(val, false); - } else if (filterName == "lt") { - colInfoMap[colIdx]->setRight(val, true); + + setFilterMap(scalarFunction, inputTypeList, colInfoMap); + } + + for (const auto& list : singularOrLists) { + setSingularListValues(list, colInfoMap); + } + return mapToFilters(inputNameList, inputTypeList, colInfoMap); +} + +bool SubstraitVeloxPlanConverter::fieldOrWithLiteral( + const ::google::protobuf::RepeatedPtrField<::substrait::FunctionArgument>& + arguments, + uint32_t& fieldIndex) { + if (arguments.size() == 1) { + if (arguments[0].value().has_selection()) { 
+ // Only field exists. + fieldIndex = subParser_->parseReferenceSegment( + arguments[0].value().selection().direct_reference()); + return true; } else { - VELOX_NYI( - "Substrait conversion not supported for filter name '{}'", - filterName); + return false; } } - // Construct the Filters. - for (int idx = 0; idx < inputNameList.size(); idx++) { - auto filterInfo = colInfoMap[idx]; - double leftBound; - double rightBound; - bool leftUnbounded = true; - bool rightUnbounded = true; - bool leftExclusive = false; - bool rightExclusive = false; - if (filterInfo->isInitialized()) { - if (filterInfo->left_) { - leftUnbounded = false; - leftBound = filterInfo->left_.value(); - leftExclusive = filterInfo->leftExclusive_; + if (arguments.size() != 2) { + // Not the field and literal combination. + return false; + } + bool fieldExists = false; + bool literalExists = false; + for (const auto& param : arguments) { + auto typeCase = param.value().rex_type_case(); + switch (typeCase) { + case ::substrait::Expression::RexTypeCase::kSelection: + fieldIndex = subParser_->parseReferenceSegment( + param.value().selection().direct_reference()); + fieldExists = true; + break; + case ::substrait::Expression::RexTypeCase::kLiteral: + literalExists = true; + break; + default: + break; + } + } + // Whether the field and literal both exist. + return fieldExists && literalExists; +} + +bool SubstraitVeloxPlanConverter::chidrenFunctionsOnSameField( + const ::substrait::Expression_ScalarFunction& function) { + // Get the column indices of the chidren functions. 
+ std::vector colIndices; + for (const auto& arg : function.arguments()) { + if (arg.value().has_scalar_function()) { + auto scalarFunction = arg.value().scalar_function(); + for (const auto& param : scalarFunction.arguments()) { + if (param.value().has_selection()) { + auto field = param.value().selection(); + VELOX_CHECK(field.has_direct_reference()); + int32_t colIdx = + subParser_->parseReferenceSegment(field.direct_reference()); + colIndices.emplace_back(colIdx); + } } - if (filterInfo->right_) { - rightUnbounded = false; - rightBound = filterInfo->right_.value(); - rightExclusive = filterInfo->rightExclusive_; + } else if (arg.value().has_singular_or_list()) { + auto singularOrList = arg.value().singular_or_list(); + int32_t colIdx = getColumnIndexFromSingularOrList(singularOrList); + colIndices.emplace_back(colIdx); + } else { + return false; + } + } + + if (std::all_of(colIndices.begin(), colIndices.end(), [&](uint32_t idx) { + return idx == colIndices[0]; + })) { + // All indices are the same. + return true; + } + return false; +} + +bool SubstraitVeloxPlanConverter::canPushdownCommonFunction( + const ::substrait::Expression_ScalarFunction& scalarFunction, + const std::string& filterName, + uint32_t& fieldIdx) { + // Condtions can be pushed down. + std::unordered_set supportedCommonFunctions = { + sIsNotNull, sGte, sGt, sLte, sLt, sEqual}; + + bool canPushdown = false; + if (supportedCommonFunctions.find(filterName) != + supportedCommonFunctions.end() && + fieldOrWithLiteral(scalarFunction.arguments(), fieldIdx)) { + // The arg should be field or field with literal. 
+ canPushdown = true; + } + return canPushdown; +} + +bool SubstraitVeloxPlanConverter::canPushdownNot( + const ::substrait::Expression_ScalarFunction& scalarFunction, + const std::unordered_map>& + rangeRecorders) { + VELOX_CHECK( + scalarFunction.arguments().size() == 1, + "Only one arg is expected for Not."); + auto notArg = scalarFunction.arguments()[0]; + if (!notArg.value().has_scalar_function()) { + // Not for a Boolean Literal or Or List is not supported curretly. + // It can be pushed down with an AlwaysTrue or AlwaysFalse Range. + return false; + } + + auto argFunction = subParser_->findFunctionSpec( + functionMap_, notArg.value().scalar_function().function_reference()); + auto functionName = subParser_->getSubFunctionName(argFunction); + + std::unordered_set supportedNotFunctions = { + sGte, sGt, sLte, sLt, sEqual}; + + uint32_t fieldIdx; + bool isFieldOrWithLiteral = fieldOrWithLiteral( + notArg.value().scalar_function().arguments(), fieldIdx); + + if (supportedNotFunctions.find(functionName) != supportedNotFunctions.end() && + isFieldOrWithLiteral && + rangeRecorders.at(fieldIdx)->setCertainRangeForFunction( + functionName, true /*reverse*/)) { + return true; + } + return false; +} + +bool SubstraitVeloxPlanConverter::canPushdownOr( + const ::substrait::Expression_ScalarFunction& scalarFunction, + const std::unordered_map>& + rangeRecorders) { + // OR Conditon whose chidren functions are on different columns is not + // supported to be pushed down. 
+ if (!chidrenFunctionsOnSameField(scalarFunction)) { + return false; + } + + std::unordered_set supportedOrFunctions = { + sIsNotNull, sGte, sGt, sLte, sLt, sEqual}; + + for (const auto& arg : scalarFunction.arguments()) { + if (arg.value().has_scalar_function()) { + auto nameSpec = subParser_->findFunctionSpec( + functionMap_, arg.value().scalar_function().function_reference()); + auto functionName = subParser_->getSubFunctionName(nameSpec); + + uint32_t fieldIdx; + bool isFieldOrWithLiteral = fieldOrWithLiteral( + arg.value().scalar_function().arguments(), fieldIdx); + if (supportedOrFunctions.find(functionName) == + supportedOrFunctions.end() || + !isFieldOrWithLiteral || + !rangeRecorders.at(fieldIdx)->setCertainRangeForFunction( + functionName, false /*reverse*/, true /*forOrRelation*/)) { + // The arg should be field or field with literal. + return false; + } + } else if (arg.value().has_singular_or_list()) { + auto singularOrList = arg.value().singular_or_list(); + if (!canPushdownSingularOrList(singularOrList, true)) { + return false; + } + uint32_t fieldIdx = getColumnIndexFromSingularOrList(singularOrList); + // Disable IN pushdown for int-like types. + if (!rangeRecorders.at(fieldIdx)->setInRange(true /*forOrRelation*/)) { + return false; } - bool nullAllowed = filterInfo->nullAllowed_; - filters[common::Subfield(inputNameList[idx])] = - std::make_unique( - leftBound, - leftUnbounded, - leftExclusive, - rightBound, - rightUnbounded, - rightExclusive, - nullAllowed); + } else { + // Or relation betweeen other expressions is not supported to be pushded + // down currently. 
+ return false; } } - return filters; + return true; } -void SubstraitVeloxPlanConverter::flattenConditions( - const ::substrait::Expression& substraitFilter, - std::vector<::substrait::Expression_ScalarFunction>& scalarFunctions) { - auto typeCase = substraitFilter.rex_type_case(); - switch (typeCase) { - case ::substrait::Expression::RexTypeCase::kScalarFunction: { - auto sFunc = substraitFilter.scalar_function(); - auto filterNameSpec = substraitParser_->findFunctionSpec( - functionMap_, sFunc.function_reference()); - // TODO: Only and relation is supported here. - if (getNameBeforeDelimiter(filterNameSpec, ":") == "and") { - for (const auto& sCondition : sFunc.arguments()) { - flattenConditions(sCondition.value(), scalarFunctions); - } +void SubstraitVeloxPlanConverter::separateFilters( + const std::unordered_map>& + rangeRecorders, + const std::vector<::substrait::Expression_ScalarFunction>& scalarFunctions, + std::vector<::substrait::Expression_ScalarFunction>& subfieldFunctions, + std::vector<::substrait::Expression_ScalarFunction>& remainingFunctions, + const std::vector<::substrait::Expression_SingularOrList>& singularOrLists, + std::vector<::substrait::Expression_SingularOrList>& subfieldOrLists, + std::vector<::substrait::Expression_SingularOrList>& remainingOrLists) { + for (const auto& singularOrList : singularOrLists) { + if (!canPushdownSingularOrList(singularOrList)) { + remainingOrLists.emplace_back(singularOrList); + continue; + } + uint32_t colIdx = getColumnIndexFromSingularOrList(singularOrList); + if (rangeRecorders.at(colIdx)->setInRange()) { + subfieldOrLists.emplace_back(singularOrList); + } else { + remainingOrLists.emplace_back(singularOrList); + } + } + + for (const auto& scalarFunction : scalarFunctions) { + auto filterNameSpec = subParser_->findFunctionSpec( + functionMap_, scalarFunction.function_reference()); + auto filterName = subParser_->getSubFunctionName(filterNameSpec); + if (filterName != sNot && filterName != sOr) { + // 
Check if the condition is supported to be pushed down. + uint32_t fieldIdx; + if (canPushdownCommonFunction(scalarFunction, filterName, fieldIdx) && + rangeRecorders.at(fieldIdx)->setCertainRangeForFunction(filterName)) { + subfieldFunctions.emplace_back(scalarFunction); } else { - scalarFunctions.emplace_back(sFunc); + remainingFunctions.emplace_back(scalarFunction); } - break; + continue; + } + + // Check whether NOT and OR functions can be pushed down. + // If yes, the scalar function will be added into the subfield functions. + bool supported = false; + if (filterName == sNot) { + supported = canPushdownNot(scalarFunction, rangeRecorders); + } else if (filterName == sOr) { + supported = canPushdownOr(scalarFunction, rangeRecorders); } + + if (supported) { + subfieldFunctions.emplace_back(scalarFunction); + } else { + remainingFunctions.emplace_back(scalarFunction); + } + } +} + +bool SubstraitVeloxPlanConverter::RangeRecorder::setCertainRangeForFunction( + const std::string& functionName, + bool reverse, + bool forOrRelation) { + if (functionName == sLt || functionName == sLte) { + if (reverse) { + return setLeftBound(forOrRelation); + } else { + return setRightBound(forOrRelation); + } + } + if (functionName == sGt || functionName == sGte) { + if (reverse) { + return setRightBound(forOrRelation); + } else { + return setLeftBound(forOrRelation); + } + } + if (functionName == sEqual) { + if (reverse) { + // Not equal means lt or gt. + return setMultiRange(); + } else { + return setLeftBound(forOrRelation) && setRightBound(forOrRelation); + } + } + if (functionName == sOr) { + if (reverse) { + // Not supported. + return false; + } else { + return setMultiRange(); + } + } + if (functionName == sIsNotNull) { + if (reverse) { + // Not supported. + return false; + } else { + // Is not null can always coexist with the other range. 
+ return true; + } + } + return false; +} + +template +void SubstraitVeloxPlanConverter::setColInfoMap( + const std::string& filterName, + uint32_t colIdx, + std::optional literalVariant, + bool reverse, + std::unordered_map>& colInfoMap) { + if (filterName == sIsNotNull) { + if (reverse) { + VELOX_NYI("Reverse not supported for filter name '{}'", filterName); + } + colInfoMap[colIdx]->forbidsNull(); + return; + } + + if (filterName == sGte) { + if (reverse) { + colInfoMap[colIdx]->setUpper(literalVariant, true); + } else { + colInfoMap[colIdx]->setLower(literalVariant, false); + } + return; + } + + if (filterName == sGt) { + if (reverse) { + colInfoMap[colIdx]->setUpper(literalVariant, false); + } else { + colInfoMap[colIdx]->setLower(literalVariant, true); + } + return; + } + + if (filterName == sLte) { + if (reverse) { + colInfoMap[colIdx]->setLower(literalVariant, true); + } else { + colInfoMap[colIdx]->setUpper(literalVariant, false); + } + return; + } + + if (filterName == sLt) { + if (reverse) { + colInfoMap[colIdx]->setLower(literalVariant, false); + } else { + colInfoMap[colIdx]->setUpper(literalVariant, true); + } + return; + } + + if (filterName == sEqual) { + if (reverse) { + colInfoMap[colIdx]->setNotValue(literalVariant); + } else { + colInfoMap[colIdx]->setLower(literalVariant, false); + colInfoMap[colIdx]->setUpper(literalVariant, false); + } + return; + } + VELOX_NYI("SetColInfoMap not supported for filter name '{}'", filterName); +} + +void SubstraitVeloxPlanConverter::setFilterMap( + const ::substrait::Expression_ScalarFunction& scalarFunction, + const std::vector& inputTypeList, + std::unordered_map>& colInfoMap, + bool reverse) { + auto nameSpec = subParser_->findFunctionSpec( + functionMap_, scalarFunction.function_reference()); + auto functionName = subParser_->getSubFunctionName(nameSpec); + + // Extract the column index and column bound from the scalar function. 
+ std::optional colIdx; + std::optional<::substrait::Expression_Literal> substraitLit; + std::vector typeCases; + for (const auto& param : scalarFunction.arguments()) { + auto typeCase = param.value().rex_type_case(); + switch (typeCase) { + case ::substrait::Expression::RexTypeCase::kSelection: + typeCases.emplace_back("kSelection"); + colIdx = subParser_->parseReferenceSegment( + param.value().selection().direct_reference()); + break; + case ::substrait::Expression::RexTypeCase::kLiteral: + typeCases.emplace_back("kLiteral"); + substraitLit = param.value().literal(); + break; + default: + VELOX_NYI( + "Substrait conversion not supported for arg type '{}'", typeCase); + } + } + + std::unordered_map functionRevertMap = { + {sLt, sGt}, {sGt, sLt}, {sGte, sLte}, {sLte, sGte}}; + + // Handle "123 < q1" type expression case + if (typeCases.size() > 1 && + (typeCases[0] == "kLiteral" && typeCases[1] == "kSelection") && + functionRevertMap.find(functionName) != functionRevertMap.end()) { + // change the function name: lt => gt, gt => lt, gte => lte, lte => gte + functionName = functionRevertMap[functionName]; + } + + if (!colIdx.has_value()) { + VELOX_NYI("Column index is expected in subfield filters creation."); + } + + // Set the extracted bound to the specific column. 
+ uint32_t colIdxVal = colIdx.value(); + auto inputType = inputTypeList[colIdxVal]; + std::optional val; + if (inputType->isShortDecimal()) { + if (substraitLit) { + auto decimal = substraitLit.value().decimal().value(); + auto precision = substraitLit.value().decimal().precision(); + auto scale = substraitLit.value().decimal().scale(); + int128_t decimalValue; + memcpy(&decimalValue, decimal.c_str(), 16); + auto type = DECIMAL(precision, scale); + val = variant(static_cast(decimalValue)); + } + setColInfoMap(functionName, colIdxVal, val, reverse, colInfoMap); + } + switch (inputType->kind()) { + case TypeKind::INTEGER: + if (substraitLit) { + val = variant(substraitLit.value().i32()); + } + setColInfoMap(functionName, colIdxVal, val, reverse, colInfoMap); + break; + case TypeKind::BIGINT: + if (substraitLit) { + val = variant(substraitLit.value().i64()); + } + setColInfoMap(functionName, colIdxVal, val, reverse, colInfoMap); + break; + case TypeKind::DOUBLE: + if (substraitLit) { + val = variant(substraitLit.value().fp64()); + } + setColInfoMap(functionName, colIdxVal, val, reverse, colInfoMap); + break; + case TypeKind::BOOLEAN: + if (substraitLit) { + val = variant(substraitLit.value().boolean()); + } + setColInfoMap(functionName, colIdxVal, val, reverse, colInfoMap); + break; + case TypeKind::VARCHAR: + if (substraitLit) { + val = variant(substraitLit.value().string()); + } + setColInfoMap( + functionName, colIdxVal, val, reverse, colInfoMap); + break; + case TypeKind::DATE: + if (substraitLit) { + val = variant(Date(substraitLit.value().date())); + } + setColInfoMap(functionName, colIdxVal, val, reverse, colInfoMap); + break; default: - VELOX_NYI("GetFlatConditions not supported for type '{}'", typeCase); + VELOX_NYI( + "Subfield filters creation not supported for input type '{}'", + inputType); } } -void SubstraitVeloxPlanConverter::constructFunctionMap( - const ::substrait::Plan& substraitPlan) { - // Construct the function map based on the Substrait 
representation. - for (const auto& sExtension : substraitPlan.extensions()) { - if (!sExtension.has_extension_function()) { - continue; +template +void SubstraitVeloxPlanConverter::createNotEqualFilter( + variant notVariant, + bool nullAllowed, + std::vector>& colFilters) { + using NativeType = typename RangeTraits::NativeType; + using RangeType = typename RangeTraits::RangeType; + // Value > lower + std::unique_ptr lowerFilter = std::make_unique( + notVariant.value(), /*lower*/ + false, /*lowerUnbounded*/ + true, /*lowerExclusive*/ + getMax(), /*upper*/ + true, /*upperUnbounded*/ + false, /*upperExclusive*/ + nullAllowed); /*nullAllowed*/ + colFilters.emplace_back(std::move(lowerFilter)); + + // Value < upper + std::unique_ptr upperFilter = std::make_unique( + getLowest(), /*lower*/ + true, /*lowerUnbounded*/ + false, /*lowerExclusive*/ + notVariant.value(), /*upper*/ + false, /*upperUnbounded*/ + true, /*upperExclusive*/ + nullAllowed); /*nullAllowed*/ + colFilters.emplace_back(std::move(upperFilter)); +} + +template +void SubstraitVeloxPlanConverter::setInFilter( + const std::vector& variants, + bool nullAllowed, + const std::string& inputName, + connector::hive::SubfieldFilters& filters) {} + +template <> +void SubstraitVeloxPlanConverter::setInFilter( + const std::vector& variants, + bool nullAllowed, + const std::string& inputName, + connector::hive::SubfieldFilters& filters) { + std::vector values; + values.reserve(variants.size()); + for (const auto& variant : variants) { + double value = variant.value(); + values.emplace_back(value); + } + filters[common::Subfield(inputName, true)] = + common::createDoubleValues(values, nullAllowed); +} + +template <> +void SubstraitVeloxPlanConverter::setInFilter( + const std::vector& variants, + bool nullAllowed, + const std::string& inputName, + connector::hive::SubfieldFilters& filters) { + std::vector values; + values.reserve(variants.size()); + for (const auto& variant : variants) { + int64_t value = variant.value(); 
+ values.emplace_back(value); + } + filters[common::Subfield(inputName, true)] = + common::createBigintValues(values, nullAllowed); +} + +template <> +void SubstraitVeloxPlanConverter::setInFilter( + const std::vector& variants, + bool nullAllowed, + const std::string& inputName, + connector::hive::SubfieldFilters& filters) { + // Use bigint values for int type. + std::vector values; + values.reserve(variants.size()); + for (const auto& variant : variants) { + // Use the matched type to get value from variant. + int64_t value = variant.value(); + values.emplace_back(value); + } + filters[common::Subfield(inputName, true)] = + common::createBigintValues(values, nullAllowed); +} + +template <> +void SubstraitVeloxPlanConverter::setInFilter( + const std::vector& variants, + bool nullAllowed, + const std::string& inputName, + connector::hive::SubfieldFilters& filters) { + // Use bigint values for small int type. + std::vector values; + values.reserve(variants.size()); + for (const auto& variant : variants) { + // Use the matched type to get value from variant. + int64_t value = variant.value(); + values.emplace_back(value); + } + filters[common::Subfield(inputName, true)] = + common::createBigintValues(values, nullAllowed); +} + +template <> +void SubstraitVeloxPlanConverter::setInFilter( + const std::vector& variants, + bool nullAllowed, + const std::string& inputName, + connector::hive::SubfieldFilters& filters) { + // Use bigint values for tiny int type. + std::vector values; + values.reserve(variants.size()); + for (const auto& variant : variants) { + // Use the matched type to get value from variant. 
+ int64_t value = variant.value(); + values.emplace_back(value); + } + filters[common::Subfield(inputName, true)] = + common::createBigintValues(values, nullAllowed); +} + +template <> +void SubstraitVeloxPlanConverter::setInFilter( + const std::vector& variants, + bool nullAllowed, + const std::string& inputName, + connector::hive::SubfieldFilters& filters) { + // Use bigint values for int type. + std::vector values; + values.reserve(variants.size()); + for (const auto& variant : variants) { + // Use int32 to get value from date variant. + int64_t value = variant.value(); + values.emplace_back(value); + } + filters[common::Subfield(inputName, true)] = + common::createBigintValues(values, nullAllowed); +} + +template <> +void SubstraitVeloxPlanConverter::setInFilter( + const std::vector& variants, + bool nullAllowed, + const std::string& inputName, + connector::hive::SubfieldFilters& filters) { + std::vector values; + values.reserve(variants.size()); + for (const auto& variant : variants) { + std::string value = variant.value(); + values.emplace_back(value); + } + filters[common::Subfield(inputName, true)] = + std::make_unique(values, nullAllowed); +} + +template +void SubstraitVeloxPlanConverter::setSubfieldFilter( + std::vector> colFilters, + const std::string& inputName, + bool nullAllowed, + connector::hive::SubfieldFilters& filters) { + using MultiRangeType = typename RangeTraits::MultiRangeType; + + if (colFilters.size() == 1) { + filters[common::Subfield(inputName, true)] = std::move(colFilters[0]); + } else if (colFilters.size() > 1) { + // BigintMultiRange should have been sorted + if (colFilters[0]->kind() == common::FilterKind::kBigintRange) { + std::sort( + colFilters.begin(), + colFilters.end(), + [](const auto& a, const auto& b) { + return dynamic_cast(a.get())->lower() < + dynamic_cast(b.get())->lower(); + }); } - const auto& sFmap = sExtension.extension_function(); - auto id = sFmap.function_anchor(); - auto name = sFmap.name(); - functionMap_[id] = 
name; + filters[common::Subfield(inputName, true)] = + std::make_unique(std::move(colFilters), nullAllowed); + } +} + +template +void SubstraitVeloxPlanConverter::constructSubfieldFilters( + uint32_t colIdx, + const std::string& inputName, + const std::shared_ptr& filterInfo, + connector::hive::SubfieldFilters& filters) { + using NativeType = typename RangeTraits::NativeType; + using RangeType = typename RangeTraits::RangeType; + using MultiRangeType = typename RangeTraits::MultiRangeType; + + if (!filterInfo->isInitialized()) { + return; + } + + uint32_t rangeSize = std::max( + filterInfo->lowerBounds_.size(), filterInfo->upperBounds_.size()); + bool nullAllowed = filterInfo->nullAllowed_; + + // Handle 'in' filter. + if (filterInfo->valuesVector_.size() > 0) { + // To filter out null is a default behaviour of Spark IN expression. + nullAllowed = false; + setInFilter( + filterInfo->valuesVector_, nullAllowed, inputName, filters); + // Currently, In cannot coexist with other filter conditions + // due to multirange is in 'OR' relation but 'AND' is needed. + VELOX_CHECK( + rangeSize == 0, + "LowerBounds or upperBounds conditons cannot be supported after IN filter."); + VELOX_CHECK( + !filterInfo->notValue_.has_value(), + "Not equal cannot be supported after IN filter."); + return; + } + + // Construct the Filters. + std::vector> colFilters; + + // Handle not(equal) filter. + if (filterInfo->notValue_) { + variant notVariant = filterInfo->notValue_.value(); + createNotEqualFilter( + notVariant, filterInfo->nullAllowed_, colFilters); + // Currently, Not-equal cannot coexist with other filter conditions + // due to multirange is in 'OR' relation but 'AND' is needed. + VELOX_CHECK( + rangeSize == 0, + "LowerBounds or upperBounds conditons cannot be supported after not-equal filter."); + filters[common::Subfield(inputName, true)] = + std::make_unique(std::move(colFilters), nullAllowed); + return; + } + + // Handle null filtering. 
+ if (rangeSize == 0 && !nullAllowed) { + std::unique_ptr filter = + std::make_unique(); + filters[common::Subfield(inputName, true)] = std::move(filter); + return; + } + + // Handle other filter ranges. + NativeType lowerBound; + if constexpr (KIND == facebook::velox::TypeKind::BIGINT) { + lowerBound = DecimalUtil::kShortDecimalMin; + } else { + lowerBound = getLowest(); + } + + NativeType upperBound; + if constexpr (KIND == facebook::velox::TypeKind::BIGINT) { + upperBound = DecimalUtil::kShortDecimalMax; + } else { + upperBound = getMax(); } + + bool lowerUnbounded = true; + bool upperUnbounded = true; + bool lowerExclusive = false; + bool upperExclusive = false; + + for (uint32_t idx = 0; idx < rangeSize; idx++) { + if (idx < filterInfo->lowerBounds_.size() && + filterInfo->lowerBounds_[idx]) { + lowerUnbounded = false; + variant lowerVariant = filterInfo->lowerBounds_[idx].value(); + + lowerBound = lowerVariant.value(); + + lowerExclusive = filterInfo->lowerExclusives_[idx]; + } + if (idx < filterInfo->upperBounds_.size() && + filterInfo->upperBounds_[idx]) { + upperUnbounded = false; + variant upperVariant = filterInfo->upperBounds_[idx].value(); + upperBound = upperVariant.value(); + + upperExclusive = filterInfo->upperExclusives_[idx]; + } + std::unique_ptr filter = std::make_unique( + lowerBound, + lowerUnbounded, + lowerExclusive, + upperBound, + upperUnbounded, + upperExclusive, + nullAllowed); + colFilters.emplace_back(std::move(filter)); + } + + // Set the SubfieldFilter. 
+ setSubfieldFilter( + std::move(colFilters), inputName, filterInfo->nullAllowed_, filters); } bool SubstraitVeloxPlanConverter::checkTypeExtension( @@ -781,9 +2081,172 @@ bool SubstraitVeloxPlanConverter::checkTypeExtension( return true; } -const std::string& SubstraitVeloxPlanConverter::findFunction( - uint64_t id) const { - return substraitParser_->findFunctionSpec(functionMap_, id); +connector::hive::SubfieldFilters SubstraitVeloxPlanConverter::mapToFilters( + const std::vector& inputNameList, + const std::vector& inputTypeList, + std::unordered_map> colInfoMap) { + // Construct the subfield filters based on the filter info map. + connector::hive::SubfieldFilters filters; + for (uint32_t colIdx = 0; colIdx < inputNameList.size(); colIdx++) { + auto inputType = inputTypeList[colIdx]; + switch (inputType->kind()) { + case TypeKind::TINYINT: + constructSubfieldFilters( + colIdx, inputNameList[colIdx], colInfoMap[colIdx], filters); + break; + case TypeKind::SMALLINT: + constructSubfieldFilters( + colIdx, inputNameList[colIdx], colInfoMap[colIdx], filters); + break; + case TypeKind::INTEGER: + constructSubfieldFilters( + colIdx, inputNameList[colIdx], colInfoMap[colIdx], filters); + break; + case TypeKind::BIGINT: + constructSubfieldFilters( + colIdx, inputNameList[colIdx], colInfoMap[colIdx], filters); + break; + case TypeKind::DOUBLE: + constructSubfieldFilters( + colIdx, inputNameList[colIdx], colInfoMap[colIdx], filters); + break; + case TypeKind::BOOLEAN: + constructSubfieldFilters( + colIdx, inputNameList[colIdx], colInfoMap[colIdx], filters); + break; + case TypeKind::VARCHAR: + constructSubfieldFilters( + colIdx, inputNameList[colIdx], colInfoMap[colIdx], filters); + break; + case TypeKind::DATE: + constructSubfieldFilters( + colIdx, inputNameList[colIdx], colInfoMap[colIdx], filters); + break; + default: + VELOX_NYI( + "Subfield filters creation not supported for input type '{}'", + inputType); + } + } + return filters; +} + +core::TypedExprPtr 
SubstraitVeloxPlanConverter::connectWithAnd( + std::vector inputNameList, + std::vector inputTypeList, + const std::vector<::substrait::Expression_ScalarFunction>& scalarFunctions, + const std::vector<::substrait::Expression_SingularOrList>& singularOrLists, + const std::vector<::substrait::Expression_IfThen>& ifThens) { + if (scalarFunctions.size() == 0 && singularOrLists.size() == 0 && + ifThens.size() == 0) { + return nullptr; + } + auto inputType = ROW(std::move(inputNameList), std::move(inputTypeList)); + + // Filter for scalar functions. + std::vector> allFilters; + for (auto scalar : scalarFunctions) { + auto filter = exprConverter_->toVeloxExpr(scalar, inputType); + if (filter != nullptr) { + allFilters.emplace_back(filter); + } + } + for (auto orList : singularOrLists) { + auto filter = exprConverter_->toVeloxExpr(orList, inputType); + if (filter != nullptr) { + allFilters.emplace_back(filter); + } + } + for (auto ifThen : ifThens) { + auto filter = exprConverter_->toVeloxExpr(ifThen, inputType); + if (filter != nullptr) { + allFilters.emplace_back(filter); + } + } + VELOX_CHECK_GT(allFilters.size(), 0, "One filter should be valid.") + std::shared_ptr andFilter = allFilters[0]; + for (auto i = 1; i < allFilters.size(); i++) { + andFilter = connectWithAnd(andFilter, allFilters[i]); + } + return andFilter; +} + +core::TypedExprPtr SubstraitVeloxPlanConverter::connectWithAnd( + core::TypedExprPtr leftExpr, + core::TypedExprPtr rightExpr) { + std::vector params; + params.reserve(2); + params.emplace_back(leftExpr); + params.emplace_back(rightExpr); + return std::make_shared( + BOOLEAN(), std::move(params), "and"); +} + +bool SubstraitVeloxPlanConverter::canPushdownSingularOrList( + const ::substrait::Expression_SingularOrList& singularOrList, + bool disableIntLike) { + VELOX_CHECK( + singularOrList.options_size() > 0, "At least one option is expected."); + // Check whether the value is field. 
+ bool hasField = singularOrList.value().has_selection(); + auto options = singularOrList.options(); + for (const auto& option : options) { + VELOX_CHECK(option.has_literal(), "Literal is expected as option."); + auto type = option.literal().literal_type_case(); + // Only BigintValues and BytesValues are supported. + if (type != ::substrait::Expression_Literal::LiteralTypeCase::kI32 && + type != ::substrait::Expression_Literal::LiteralTypeCase::kI64 && + type != ::substrait::Expression_Literal::LiteralTypeCase::kString) { + return false; + } + // BigintMultiRange can only accept BigintRange, so disableIntLike is set to + // true for OR pushdown of int-like types. + if (disableIntLike && + (type == ::substrait::Expression_Literal::LiteralTypeCase::kI32 || + type == ::substrait::Expression_Literal::LiteralTypeCase::kI64)) { + return false; + } + } + return hasField; +} + +uint32_t SubstraitVeloxPlanConverter::getColumnIndexFromSingularOrList( + const ::substrait::Expression_SingularOrList& singularOrList) { + // Get the column index. + ::substrait::Expression_FieldReference selection; + if (singularOrList.value().has_scalar_function()) { + selection = singularOrList.value() + .scalar_function() + .arguments()[0] + .value() + .selection(); + } else if (singularOrList.value().has_selection()) { + selection = singularOrList.value().selection(); + } else { + VELOX_FAIL("Unsupported type in IN pushdown."); + } + return subParser_->parseReferenceSegment(selection.direct_reference()); +} + +void SubstraitVeloxPlanConverter::setSingularListValues( + const ::substrait::Expression_SingularOrList& singularOrList, + std::unordered_map>& colInfoMap) { + VELOX_CHECK( + singularOrList.options_size() > 0, "At least one option is expected."); + // Get the column index. + uint32_t colIdx = getColumnIndexFromSingularOrList(singularOrList); + + // Get the value list. 
+ auto options = singularOrList.options(); + std::vector variants; + variants.reserve(options.size()); + for (const auto& option : options) { + VELOX_CHECK(option.has_literal(), "Literal is expected as option."); + variants.emplace_back( + exprConverter_->toVeloxExpr(option.literal())->value()); + } + // Set the value list to filter info. + colInfoMap[colIdx]->setValues(variants); } } // namespace facebook::velox::substrait diff --git a/velox/substrait/SubstraitToVeloxPlan.h b/velox/substrait/SubstraitToVeloxPlan.h index d334fb343cc6..1ae9479e6b19 100644 --- a/velox/substrait/SubstraitToVeloxPlan.h +++ b/velox/substrait/SubstraitToVeloxPlan.h @@ -19,32 +19,46 @@ #include "velox/connectors/hive/HiveConnector.h" #include "velox/core/PlanNode.h" #include "velox/substrait/SubstraitToVeloxExpr.h" +#include "velox/substrait/TypeUtils.h" namespace facebook::velox::substrait { +struct SplitInfo { + /// Whether the split comes from arrow array stream node. + bool isStream = false; + + /// The Partition index. + u_int32_t partitionIndex; + + /// The file paths to be scanned. + std::vector paths; + + /// The file starts in the scan. + std::vector starts; + + /// The lengths to be scanned. + std::vector lengths; + + /// The file format of the files to be scanned. + dwio::common::FileFormat format; +}; /// This class is used to convert the Substrait plan into Velox plan. class SubstraitVeloxPlanConverter { public: - explicit SubstraitVeloxPlanConverter(memory::MemoryPool* pool) - : pool_(pool) {} - struct SplitInfo { - /// The Partition index. - u_int32_t partitionIndex; + SubstraitVeloxPlanConverter( + memory::MemoryPool* pool, + bool validationMode = false) + : pool_(pool), validationMode_(validationMode) {} + /// Used to convert Substrait ExpandRel into Velox PlanNode. + core::PlanNodePtr toVeloxPlan(const ::substrait::ExpandRel& expandRel); - /// The file paths to be scanned. - std::vector paths; + /// Used to convert Substrait SortRel into Velox PlanNode. 
+ core::PlanNodePtr toVeloxPlan(const ::substrait::WindowRel& windowRel); - /// The file starts in the scan. - std::vector starts; + /// Used to convert Substrait JoinRel into Velox PlanNode. + core::PlanNodePtr toVeloxPlan(const ::substrait::JoinRel& joinRel); - /// The lengths to be scanned. - std::vector lengths; - - /// The file format of the files to be scanned. - dwio::common::FileFormat format; - }; - - /// Convert Substrait AggregateRel into Velox PlanNode. + /// Used to convert Substrait AggregateRel into Velox PlanNode. core::PlanNodePtr toVeloxPlan(const ::substrait::AggregateRel& aggRel); /// Convert Substrait ProjectRel into Velox PlanNode. @@ -53,14 +67,6 @@ class SubstraitVeloxPlanConverter { /// Convert Substrait FilterRel into Velox PlanNode. core::PlanNodePtr toVeloxPlan(const ::substrait::FilterRel& filterRel); - /// Convert Substrait ReadRel into Velox PlanNode. - /// Index: the index of the partition this item belongs to. - /// Starts: the start positions in byte to read from the items. - /// Lengths: the lengths in byte to read from the items. - core::PlanNodePtr toVeloxPlan( - const ::substrait::ReadRel& readRel, - std::shared_ptr& splitInfo); - /// Convert Substrait FetchRel into Velox LimitNode or TopNNode according the /// different input of fetchRel. core::PlanNodePtr toVeloxPlan(const ::substrait::FetchRel& fetchRel); @@ -70,26 +76,33 @@ class SubstraitVeloxPlanConverter { const ::substrait::ReadRel& readRel, const RowTypePtr& type); - /// Convert Substrait Rel into Velox PlanNode. - core::PlanNodePtr toVeloxPlan(const ::substrait::Rel& rel); - - /// Convert Substrait RelRoot into Velox PlanNode. - core::PlanNodePtr toVeloxPlan(const ::substrait::RelRoot& root); - /// Convert Substrait SortRel into Velox OrderByNode. core::PlanNodePtr toVeloxPlan(const ::substrait::SortRel& sortRel); - /// Convert Substrait Plan into Velox PlanNode. 
- core::PlanNodePtr toVeloxPlan(const ::substrait::Plan& substraitPlan); - /// Check the Substrait type extension only has one unknown extension. bool checkTypeExtension(const ::substrait::Plan& substraitPlan); - /// Construct the function map between the index and the Substrait function - /// name. + /// Convert Substrait ReadRel into Velox PlanNode. + /// Index: the index of the partition this item belongs to. + /// Starts: the start positions in byte to read from the items. + /// Lengths: the lengths in byte to read from the items. + core::PlanNodePtr toVeloxPlan(const ::substrait::ReadRel& sRead); + + /// Used to convert Substrait Rel into Velox PlanNode. + core::PlanNodePtr toVeloxPlan(const ::substrait::Rel& sRel); + + /// Used to convert Substrait RelRoot into Velox PlanNode. + core::PlanNodePtr toVeloxPlan(const ::substrait::RelRoot& sRoot); + + /// Used to convert Substrait Plan into Velox PlanNode. + core::PlanNodePtr toVeloxPlan(const ::substrait::Plan& substraitPlan); + + /// Used to construct the function map between the index + /// and the Substrait function name. Initialize the expression + /// converter based on the constructed function map. void constructFunctionMap(const ::substrait::Plan& substraitPlan); - /// Return the function map used by this plan converter. + /// Will return the function map used by this plan converter. const std::unordered_map& getFunctionMap() const { return functionMap_; } @@ -100,13 +113,6 @@ class SubstraitVeloxPlanConverter { return splitInfoMap_; } - /// Looks up a function by ID and returns function name if found. Throws if - /// function with specified ID doesn't exist. Returns a compound - /// function specification consisting of the function name and the input - /// types. The format is as follows: :__..._ - const std::string& findFunction(uint64_t id) const; - /// Integrate Substrait emit feature. Here a given 'substrait::RelCommon' /// is passed and check if emit is defined for this relation. 
Basically a /// ProjectNode is added on top of 'noEmitNode' to represent output order @@ -116,30 +122,206 @@ class SubstraitVeloxPlanConverter { const ::substrait::RelCommon& relCommon, const core::PlanNodePtr& noEmitNode); - private: - /// Returns unique ID to use for plan node. Produces sequential numbers - /// starting from zero. - std::string nextPlanNodeId(); + /// Used to insert certain plan node as input. The plan node + /// id will start from the setted one. + void insertInputNode( + uint64_t inputIdx, + const std::shared_ptr& inputNode, + int planNodeId) { + inputNodesMap_[inputIdx] = inputNode; + planNodeId_ = planNodeId; + } - /// Used to convert Substrait Filter into Velox SubfieldFilters which will - /// be used in TableScan. - connector::hive::SubfieldFilters toVeloxFilter( - const std::vector& inputNameList, - const std::vector& inputTypeList, - const ::substrait::Expression& substraitFilter); + /// Used to check if ReadRel specifies an input of stream. + /// If yes, the index of input stream will be returned. + /// If not, -1 will be returned. + int32_t streamIsInput(const ::substrait::ReadRel& sRel); /// Multiple conditions are connected to a binary tree structure with /// the relation key words, including AND, OR, and etc. Currently, only /// AND is supported. This function is used to extract all the Substrait /// conditions in the binary tree structure into a vector. void flattenConditions( - const ::substrait::Expression& substraitFilter, - std::vector<::substrait::Expression_ScalarFunction>& scalarFunctions); + const ::substrait::Expression& sFilter, + std::vector<::substrait::Expression_ScalarFunction>& scalarFunctions, + std::vector<::substrait::Expression_SingularOrList>& singularOrLists, + std::vector<::substrait::Expression_IfThen>& ifThens); + + /// Used to find the function specification in the constructed function map. + std::string findFuncSpec(uint64_t id); + + /// Extract join keys from joinExpression. 
+ /// joinExpression is a boolean condition that describes whether each record + /// from the left set “match” the record from the right set. The condition + /// must only include the following operations: AND, ==, field references. + /// Field references correspond to the direct output order of the data. + void extractJoinKeys( + const ::substrait::Expression& joinExpression, + std::vector& leftExprs, + std::vector& rightExprs); + + /// Get aggregation step from AggregateRel. + core::AggregationNode::Step toAggregationStep( + const ::substrait::AggregateRel& sAgg); - /// The Substrait parser used to convert Substrait representations into - /// recognizable representations. - std::shared_ptr substraitParser_{ - std::make_shared()}; + private: + /// Range filter recorder for a field is used to make sure only the conditions + /// that can coexist for this field being pushed down with a range filter. + class RangeRecorder { + public: + /// Set the existence of values range and returns whether this condition can + /// coexist with existing conditions for one field. Conditions in OR + /// relation can coexist with each other. + bool setInRange(bool forOrRelation = false) { + if (forOrRelation) { + return true; + } + if (inRange_ || multiRange_ || leftBound_ || rightBound_) { + return false; + } + inRange_ = true; + return true; + } + + /// Set the existence of left bound and returns whether it can coexist with + /// existing conditions for this field. + bool setLeftBound(bool forOrRelation = false) { + if (forOrRelation) { + if (!rightBound_) + leftBound_ = true; + return !rightBound_; + } + if (leftBound_ || inRange_ || multiRange_) { + return false; + } + leftBound_ = true; + return true; + } + + /// Set the existence of right bound and returns whether it can coexist with + /// existing conditions for this field. 
+ bool setRightBound(bool forOrRelation = false) { + if (forOrRelation) { + if (!leftBound_) + rightBound_ = true; + return !leftBound_; + } + if (rightBound_ || inRange_ || multiRange_) { + return false; + } + rightBound_ = true; + return true; + } + + /// Set the multi-range and returns whether it can coexist with + /// existing conditions for this field. + bool setMultiRange() { + if (inRange_ || multiRange_ || leftBound_ || rightBound_) { + return false; + } + multiRange_ = true; + return true; + } + + /// Set certain existence according to function name and returns whether it + /// can coexist with existing conditions for this field. + bool setCertainRangeForFunction( + const std::string& functionName, + bool reverse = false, + bool forOrRelation = false); + + private: + /// The existence of values range. + bool inRange_ = false; + + /// The existence of left bound. + bool leftBound_ = false; + + /// The existence of right bound. + bool rightBound_ = false; + + /// The existence of multi-range. + bool multiRange_ = false; + }; + + /// Filter info for a column used in filter push down. + class FilterInfo { + public: + // Disable null allow. + void forbidsNull() { + nullAllowed_ = false; + if (!isInitialized_) { + isInitialized_ = true; + } + } + + // Return the initialization status. + bool isInitialized() { + return isInitialized_ ? true : false; + } + + // Add a lower bound to the range. Multiple lower bounds are + // regarded to be in 'or' relation. + void setLower(const std::optional& left, bool isExclusive) { + lowerBounds_.emplace_back(left); + lowerExclusives_.emplace_back(isExclusive); + if (!isInitialized_) { + isInitialized_ = true; + } + } + + // Add a upper bound to the range. Multiple upper bounds are + // regarded to be in 'or' relation. 
+ void setUpper(const std::optional& right, bool isExclusive) { + upperBounds_.emplace_back(right); + upperExclusives_.emplace_back(isExclusive); + if (!isInitialized_) { + isInitialized_ = true; + } + } + + // Set a list of values to be used in the push down of 'in' expression. + void setValues(const std::vector& values) { + for (const auto& value : values) { + valuesVector_.emplace_back(value); + } + if (!isInitialized_) { + isInitialized_ = true; + } + } + + // Set a value for the not(equal) condition. + void setNotValue(const std::optional& notValue) { + notValue_ = notValue; + if (!isInitialized_) { + isInitialized_ = true; + } + } + + // Whether this filter map is initialized. + bool isInitialized_ = false; + + // The null allow. + bool nullAllowed_ = false; + + // If true, left bound will be exclusive. + std::vector lowerExclusives_; + + // If true, right bound will be exclusive. + std::vector upperExclusives_; + + // A value should not be equal to. + std::optional notValue_ = std::nullopt; + + // The lower bounds in 'or' relation. + std::vector> lowerBounds_; + + // The upper bounds in 'or' relation. + std::vector> upperBounds_; + + // The list of values used in 'in' expression. + std::vector valuesVector_; + }; /// Helper Function to convert Substrait sortField to Velox sortingKeys and /// sortingOrders. @@ -151,9 +333,182 @@ class SubstraitVeloxPlanConverter { sortField, const RowTypePtr& inputType); - /// The Expression converter used to convert Substrait representations into - /// Velox expressions. - std::shared_ptr exprConverter_; + /// Returns unique ID to use for plan node. Produces sequential numbers + /// starting from zero. + std::string nextPlanNodeId(); + + /// Returns whether the args of a scalar function being field or + /// field with literal. If yes, extract and set the field index. 
+ bool fieldOrWithLiteral( + const ::google::protobuf::RepeatedPtrField<::substrait::FunctionArgument>& + arguments, + uint32_t& fieldIndex); + + /// Separate the functions to be two parts: + /// subfield functions to be handled by the subfieldFilters in HiveConnector, + /// and remaining functions to be handled by the remainingFilter in + /// HiveConnector. + void separateFilters( + const std::unordered_map>& + rangeRecorders, + const std::vector<::substrait::Expression_ScalarFunction>& + scalarFunctions, + std::vector<::substrait::Expression_ScalarFunction>& subfieldFunctions, + std::vector<::substrait::Expression_ScalarFunction>& remainingFunctions, + const std::vector<::substrait::Expression_SingularOrList>& + singularOrLists, + std::vector<::substrait::Expression_SingularOrList>& subfieldrOrLists, + std::vector<::substrait::Expression_SingularOrList>& remainingrOrLists); + + /// Returns whether a function can be pushed down. + bool canPushdownCommonFunction( + const ::substrait::Expression_ScalarFunction& scalarFunction, + const std::string& filterName, + uint32_t& fieldIdx); + + /// Returns whether a NOT function can be pushed down. + bool canPushdownNot( + const ::substrait::Expression_ScalarFunction& scalarFunction, + const std::unordered_map>& + rangeRecorders); + + /// Returns whether a OR function can be pushed down. + bool canPushdownOr( + const ::substrait::Expression_ScalarFunction& scalarFunction, + const std::unordered_map>& + rangeRecorders); + + /// Returns whether a SingularOrList can be pushed down. + bool canPushdownSingularOrList( + const ::substrait::Expression_SingularOrList& singularOrList, + bool disableIntLike = false); + + /// Returns a set of unique column indices for IN function to be pushed down. + std::unordered_set getInColIndices( + const std::vector<::substrait::Expression_SingularOrList>& + singularOrLists); + + /// Check whether the chidren functions of this scalar function have the same + /// column index. 
Curretly used to check whether the two chilren functions of + /// 'or' expression are effective on the same column. + bool chidrenFunctionsOnSameField( + const ::substrait::Expression_ScalarFunction& function); + + /// Extract the scalar function, and set the filter info for different types + /// of columns. If reverse is true, the opposite filter info will be set. + void setFilterMap( + const ::substrait::Expression_ScalarFunction& scalarFunction, + const std::vector& inputTypeList, + std::unordered_map>& colInfoMap, + bool reverse = false); + + /// Extract SingularOrList and returns the field index. + uint32_t getColumnIndexFromSingularOrList( + const ::substrait::Expression_SingularOrList& singularOrList); + + /// Extract SingularOrList and set it to the filter info map. + void setSingularListValues( + const ::substrait::Expression_SingularOrList& singularOrList, + std::unordered_map>& colInfoMap); + + /// Set the filter info for a column base on the information + /// extracted from filter condition. + template + void setColInfoMap( + const std::string& filterName, + uint32_t colIdx, + std::optional literalVariant, + bool reverse, + std::unordered_map>& colInfoMap); + + /// Create a multirange to specify the filter 'x != notValue' with: + /// x > notValue or x < notValue. + template + void createNotEqualFilter( + variant notVariant, + bool nullAllowed, + std::vector>& colFilters); + + /// Create a values range to handle in filter. + /// variants: the list of values extracted from the in expression. + /// inputName: the column input name. + template + void setInFilter( + const std::vector& variants, + bool nullAllowed, + const std::string& inputName, + connector::hive::SubfieldFilters& filters); + + /// Set the constructed filters into SubfieldFilters. + /// The FilterType is used to distinguish BigintRange and + /// Filter (the base class). This is needed because BigintMultiRange + /// can only accept the unique ptr of BigintRange as parameter. 
+ template + void setSubfieldFilter( + std::vector> colFilters, + const std::string& inputName, + bool nullAllowed, + connector::hive::SubfieldFilters& filters); + + /// Create the subfield filter based on the constructed filter info. + /// inputName: the input name of a column. + template + void constructSubfieldFilters( + uint32_t colIdx, + const std::string& inputName, + const std::shared_ptr& filterInfo, + connector::hive::SubfieldFilters& filters); + + /// Construct subfield filters according to the pre-set map of filter info. + connector::hive::SubfieldFilters mapToFilters( + const std::vector& inputNameList, + const std::vector& inputTypeList, + std::unordered_map> colInfoMap); + + /// Convert subfield functions into subfieldFilters to + /// be used in Hive Connector. + connector::hive::SubfieldFilters toSubfieldFilters( + const std::vector& inputNameList, + const std::vector& inputTypeList, + const std::vector<::substrait::Expression_ScalarFunction>& + subfieldFunctions, + const std::vector<::substrait::Expression_SingularOrList>& + singularOrLists); + + /// Connect all remaining functions with 'and' relation + /// for the use of remaingFilter in Hive Connector. + core::TypedExprPtr connectWithAnd( + std::vector inputNameList, + std::vector inputTypeList, + const std::vector<::substrait::Expression_ScalarFunction>& + remainingFunctions, + const std::vector<::substrait::Expression_SingularOrList>& + singularOrLists, + const std::vector<::substrait::Expression_IfThen>& ifThens); + + /// Connect the left and right expressions with 'and' relation. + core::TypedExprPtr connectWithAnd( + core::TypedExprPtr leftExpr, + core::TypedExprPtr rightExpr); + + /// Set the phase of Aggregation. + void setPhase( + const ::substrait::AggregateRel& sAgg, + core::AggregationNode::Step& aggStep); + + /// Used to convert AggregateRel into Velox plan node. + /// The output of child node will be used as the input of Aggregation. 
+ std::shared_ptr toVeloxAgg( + const ::substrait::AggregateRel& sAgg, + const std::shared_ptr& childNode, + const core::AggregationNode::Step& aggStep); + + /// Helper function to convert the input of Substrait Rel to Velox Node. + template + core::PlanNodePtr convertSingleInput(T rel) { + VELOX_CHECK(rel.has_input(), "Child Rel is expected here."); + return toVeloxPlan(rel.input()); + } /// The unique identification for each PlanNode. int planNodeId_ = 0; @@ -162,19 +517,30 @@ class SubstraitVeloxPlanConverter { /// name. Will be constructed based on the Substrait representation. std::unordered_map functionMap_; - /// Mapping from leaf plan node ID to splits. + /// The map storing the split stats for each PlanNode. std::unordered_map> splitInfoMap_; + /// The map storing the pre-built plan nodes which can be accessed through + /// index. This map is only used when the computation of a Substrait plan + /// depends on other input nodes. + std::unordered_map> + inputNodesMap_; + + /// The Substrait parser used to convert Substrait representations into + /// recognizable representations. + std::shared_ptr subParser_{ + std::make_shared()}; + + /// The Expression converter used to convert Substrait representations into + /// Velox expressions. + std::shared_ptr exprConverter_; + /// Memory pool. memory::MemoryPool* pool_; - /// Helper function to convert the input of Substrait Rel to Velox Node. - template - core::PlanNodePtr convertSingleInput(T rel) { - VELOX_CHECK(rel.has_input(), "Child Rel is expected here."); - return toVeloxPlan(rel.input()); - } + /// A flag used to specify validation. + bool validationMode_ = false; }; } // namespace facebook::velox::substrait diff --git a/velox/substrait/SubstraitToVeloxPlanValidator.cpp b/velox/substrait/SubstraitToVeloxPlanValidator.cpp new file mode 100644 index 000000000000..f3376e43fe99 --- /dev/null +++ b/velox/substrait/SubstraitToVeloxPlanValidator.cpp @@ -0,0 +1,1132 @@ +/* + * Copyright (c) Facebook, Inc. 
and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/substrait/SubstraitToVeloxPlanValidator.h" +#include +#include +#include "TypeUtils.h" +#include "velox/expression/SignatureBinder.h" +#include "velox/type/Tokenizer.h" + +namespace facebook::velox::substrait { +namespace { +bool validateColNames(const ::substrait::NamedStruct& schema) { + for (auto& name : schema.names()) { + common::Tokenizer token(name); + for (auto i = 0; i < name.size(); i++) { + auto c = name[i]; + if (!token.isUnquotedPathCharacter(c)) { + std::cout << "Illegal column charactor " << c << "in column " << name + << std::endl; + return false; + } + } + } + return true; +} +} // namespace +bool SubstraitToVeloxPlanValidator::validateInputTypes( + const ::substrait::extensions::AdvancedExtension& extension, + std::vector& types) { + // The input type is wrapped in enhancement. + if (!extension.has_enhancement()) { + return false; + } + const auto& enhancement = extension.enhancement(); + ::substrait::Type inputType; + if (!enhancement.UnpackTo(&inputType)) { + return false; + } + if (!inputType.has_struct_()) { + return false; + } + + // Get the input types. 
+ const auto& sTypes = inputType.struct_().types(); + for (const auto& sType : sTypes) { + try { + types.emplace_back(toVeloxType(subParser_->parseType(sType)->type)); + } catch (const VeloxException& err) { + std::cout << "Type is not supported due to:" << err.message() + << std::endl; + return false; + } + } + return true; +} + +bool SubstraitToVeloxPlanValidator::validateRound( + const ::substrait::Expression::ScalarFunction& scalarFunction, + const RowTypePtr& inputType) { + const auto& arguments = scalarFunction.arguments(); + if (arguments.size() < 2) { + return false; + } + if (!arguments[1].value().has_literal()) { + VELOX_FAIL("Round scale is expected."); + } + // Velox has different result with Spark on negative scale. + auto typeCase = arguments[1].value().literal().literal_type_case(); + switch (typeCase) { + case ::substrait::Expression_Literal::LiteralTypeCase::kI32: + return (arguments[1].value().literal().i32() >= 0); + case ::substrait::Expression_Literal::LiteralTypeCase::kI64: + return (arguments[1].value().literal().i64() >= 0); + default: + VELOX_NYI( + "Round scale validation is not supported for type case '{}'", + typeCase); + } +} + +bool SubstraitToVeloxPlanValidator::validateExtractExpr( + const std::vector>& params) { + VELOX_CHECK_EQ(params.size(), 2); + auto functionArg = + std::dynamic_pointer_cast(params[0]); + if (functionArg) { + // Get the function argument. + auto variant = functionArg->value(); + if (!variant.hasValue()) { + VELOX_FAIL("Value expected in variant."); + } + // The first parameter specifies extracting from which field. + std::string from = variant.value(); + // Hour causes incorrect result. 
+ if (from == "HOUR") { + return false; + } + return true; + } + VELOX_FAIL("Constant is expected to be the first parameter in extract."); +} + +bool SubstraitToVeloxPlanValidator::validateScalarFunction( + const ::substrait::Expression::ScalarFunction& scalarFunction, + const RowTypePtr& inputType) { + std::vector params; + params.reserve(scalarFunction.arguments().size()); + for (const auto& argument : scalarFunction.arguments()) { + if (argument.has_value() && + !validateExpression(argument.value(), inputType)) { + return false; + } + params.emplace_back( + exprConverter_->toVeloxExpr(argument.value(), inputType)); + } + + const auto& function = subParser_->findFunctionSpec( + planConverter_->getFunctionMap(), scalarFunction.function_reference()); + const auto& name = subParser_->getSubFunctionName(function); + std::vector types; + subParser_->getSubFunctionTypes(function, types); + if (name == "round") { + return validateRound(scalarFunction, inputType); + } + if (name == "extract") { + return validateExtractExpr(params); + } + if (name == "char_length") { + VELOX_CHECK(types.size() == 1); + if (types[0] == "vbin") { + VLOG(1) << "Binary type is not supported in " << name << "."; + return false; + } + } + + std::unordered_set functions = { + "regexp_replace", "split", "split_part", + "factorial", "concat_ws", "rand", + "json_array_length", "from_unixtime", "to_unix_timestamp", + "unix_timestamp", "repeat", "translate", + "add_months", "date_format", "trunc", + "sequence", "posexplode", "arrays_overlap", + "array_min", "array_max"}; + if (functions.find(name) != functions.end()) { + VLOG(1) << "Function is not supported: " << name << "."; + return false; + } + + return true; +} + +bool SubstraitToVeloxPlanValidator::validateLiteral( + const ::substrait::Expression_Literal& literal, + const RowTypePtr& inputType) { + if (literal.has_list() && literal.list().values_size() == 0) { + return false; + } + return true; +} + +bool 
SubstraitToVeloxPlanValidator::validateCast( + const ::substrait::Expression::Cast& castExpr, + const RowTypePtr& inputType) { + if (!validateExpression(castExpr.input(), inputType)) { + return false; + } + + const auto& toType = + toVeloxType(subParser_->parseType(castExpr.type())->type); + if (toType->kind() == TypeKind::TIMESTAMP) { + VLOG(1) << "Casting to TIMESTAMP is not supported"; + return false; + } + + core::TypedExprPtr input = + exprConverter_->toVeloxExpr(castExpr.input(), inputType); + + // Casting from some types is not supported. See CastExpr::applyCast. + switch (input->type()->kind()) { + case TypeKind::ARRAY: + case TypeKind::MAP: + case TypeKind::ROW: + case TypeKind::VARBINARY: + VLOG(1) << "Invalid input type in casting: " << input->type() << "."; + return false; + case TypeKind::DATE: { + if (toType->kind() == TypeKind::TIMESTAMP) { + VLOG(1) << "Casting from DATE to TIMESTAMP is not supported."; + return false; + } + } + case TypeKind::TIMESTAMP: { + VLOG(1) + << "Casting from TIMESTAMP is not supported or has incorrect result."; + return false; + } + default: { + } + } + return true; +} + +bool SubstraitToVeloxPlanValidator::validateExpression( + const ::substrait::Expression& expression, + const RowTypePtr& inputType) { + std::shared_ptr veloxExpr; + auto typeCase = expression.rex_type_case(); + switch (typeCase) { + case ::substrait::Expression::RexTypeCase::kScalarFunction: + return validateScalarFunction(expression.scalar_function(), inputType); + case ::substrait::Expression::RexTypeCase::kLiteral: + return validateLiteral(expression.literal(), inputType); + case ::substrait::Expression::RexTypeCase::kCast: + return validateCast(expression.cast(), inputType); + default: + return true; + } +} + +bool SubstraitToVeloxPlanValidator::validate( + const ::substrait::FetchRel& fetchRel) { + const auto& extension = fetchRel.advanced_extension(); + std::vector types; + if (!validateInputTypes(extension, types)) { + std::cout << "Validation 
failed for input types in FetchRel." << std::endl; + return false; + } + + if (fetchRel.offset() < 0 || fetchRel.count() < 0) { + std::cout << "Offset and count should be valid." << std::endl; + return false; + } + return true; +} + +bool SubstraitToVeloxPlanValidator::validate( + const ::substrait::ExpandRel& expandRel) { + if (expandRel.has_input() && !validate(expandRel.input())) { + return false; + } + RowTypePtr rowType = nullptr; + // Get and validate the input types from extension. + if (expandRel.has_advanced_extension()) { + const auto& extension = expandRel.advanced_extension(); + std::vector types; + if (!validateInputTypes(extension, types)) { + std::cout << "Validation failed for input types in ExpandRel." + << std::endl; + return false; + } + int32_t inputPlanNodeId = 0; + std::vector names; + names.reserve(types.size()); + for (auto colIdx = 0; colIdx < types.size(); colIdx++) { + names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx)); + } + rowType = std::make_shared(std::move(names), std::move(types)); + } + + int32_t projectSize = 0; + // Validate fields. + for (const auto& fields : expandRel.fields()) { + std::vector expressions; + if (fields.has_switching_field()) { + auto projectExprs = fields.switching_field().duplicates(); + expressions.reserve(projectExprs.size()); + if (projectSize == 0) { + projectSize = projectExprs.size(); + } else if (projectSize != projectExprs.size()) { + std::cout << "SwitchingField expressions size should be constant." + << std::endl; + return false; + } + + try { + for (const auto& projectExpr : projectExprs) { + const auto& typeCase = projectExpr.rex_type_case(); + switch (typeCase) { + case ::substrait::Expression::RexTypeCase::kSelection: + case ::substrait::Expression::RexTypeCase::kLiteral: + break; + default: + std::cout << "Only field or literal is supported." 
<< std::endl; + return false; + } + if (rowType) { + expressions.emplace_back( + exprConverter_->toVeloxExpr(projectExpr, rowType)); + } + } + + if (rowType) { + // Try to compile the expressions. If there is any unregistered + // function or mismatched type, exception will be thrown. + exec::ExprSet exprSet(std::move(expressions), execCtx_); + } + + } catch (const VeloxException& err) { + std::cout << "Validation failed for expressions in ExpandRel due to:" + << err.message() << std::endl; + return false; + } + } else { + std::cout << "Only SwitchingField is supported in ExpandRel." + << std::endl; + return false; + } + } + + return true; +} + +bool validateBoundType(::substrait::Expression_WindowFunction_Bound boundType) { + switch (boundType.kind_case()) { + case ::substrait::Expression_WindowFunction_Bound::kUnboundedFollowing: + case ::substrait::Expression_WindowFunction_Bound::kUnboundedPreceding: + case ::substrait::Expression_WindowFunction_Bound::kCurrentRow: + case ::substrait::Expression_WindowFunction_Bound::kFollowing: + case ::substrait::Expression_WindowFunction_Bound::kPreceding: + break; + default: + std::cout << "The Bound Type is not supported. " + << "\n"; + return false; + } + return true; +} + +bool SubstraitToVeloxPlanValidator::validate( + const ::substrait::WindowRel& windowRel) { + if (windowRel.has_input() && !validate(windowRel.input())) { + return false; + } + + // Get and validate the input types from extension. + if (!windowRel.has_advanced_extension()) { + std::cout << "Input types are expected in WindowRel." << std::endl; + return false; + } + const auto& extension = windowRel.advanced_extension(); + std::vector types; + if (!validateInputTypes(extension, types)) { + std::cout << "Validation failed for input types in WindowRel." 
<< std::endl; + return false; + } + + int32_t inputPlanNodeId = 0; + std::vector names; + names.reserve(types.size()); + for (auto colIdx = 0; colIdx < types.size(); colIdx++) { + names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx)); + } + auto rowType = std::make_shared(std::move(names), std::move(types)); + + // Validate WindowFunction + std::vector funcSpecs; + funcSpecs.reserve(windowRel.measures().size()); + for (const auto& smea : windowRel.measures()) { + try { + const auto& windowFunction = smea.measure(); + funcSpecs.emplace_back( + planConverter_->findFuncSpec(windowFunction.function_reference())); + toVeloxType(subParser_->parseType(windowFunction.output_type())->type); + for (const auto& arg : windowFunction.arguments()) { + auto typeCase = arg.value().rex_type_case(); + switch (typeCase) { + case ::substrait::Expression::RexTypeCase::kSelection: + case ::substrait::Expression::RexTypeCase::kLiteral: + break; + default: + std::cout << "Only field is supported in window functions." + << std::endl; + return false; + } + } + // Validate BoundType and Frame Type + switch (windowFunction.window_type()) { + case ::substrait::WindowType::ROWS: + case ::substrait::WindowType::RANGE: + break; + default: + VELOX_FAIL( + "the window type only support ROWS and RANGE, and the input type is ", + windowFunction.window_type()); + } + + bool boundTypeSupported = + validateBoundType(windowFunction.upper_bound()) && + validateBoundType(windowFunction.lower_bound()); + if (!boundTypeSupported) { + return false; + } + } catch (const VeloxException& err) { + std::cout << "Validation failed for window function due to: " + << err.message() << std::endl; + return false; + } + } + + // Validate supported aggregate functions. 
+ std::unordered_set unsupportedFuncs = {"collect_list"}; + for (const auto& funcSpec : funcSpecs) { + auto funcName = subParser_->getSubFunctionName(funcSpec); + if (unsupportedFuncs.find(funcName) != unsupportedFuncs.end()) { + std::cout << "Validation failed due to " << funcName + << " was not supported in WindowRel." << std::endl; + return false; + } + } + + // Validate groupby expression + const auto& groupByExprs = windowRel.partition_expressions(); + std::vector> expressions; + expressions.reserve(groupByExprs.size()); + try { + for (const auto& expr : groupByExprs) { + auto expression = exprConverter_->toVeloxExpr(expr, rowType); + auto expr_field = + dynamic_cast(expression.get()); + if (expr_field == nullptr) { + std::cout + << "Only field is supported for partition key in Window Operator!" + << std::endl; + return false; + } else { + expressions.emplace_back(expression); + } + } + // Try to compile the expressions. If there is any unregistered function or + // mismatched type, exception will be thrown. 
+ exec::ExprSet exprSet(std::move(expressions), execCtx_); + } catch (const VeloxException& err) { + std::cout << "Validation failed for expression in ProjectRel due to:" + << err.message() << std::endl; + return false; + } + + // Validate Sort expression + const auto& sorts = windowRel.sorts(); + for (const auto& sort : sorts) { + switch (sort.direction()) { + case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST: + case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST: + case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST: + case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST: + break; + default: + return false; + } + + if (sort.has_expr()) { + try { + auto expression = exprConverter_->toVeloxExpr(sort.expr(), rowType); + auto expr_field = + dynamic_cast(expression.get()); + VELOX_CHECK( + expr_field != nullptr, + " the sorting key in Sort Operator only support field") + + exec::ExprSet exprSet({std::move(expression)}, execCtx_); + } catch (const VeloxException& err) { + std::cout << "Validation failed for expression in SortRel due to:" + << err.message() << std::endl; + return false; + } + } + } + + return true; +} + +bool SubstraitToVeloxPlanValidator::validate( + const ::substrait::SortRel& sortRel) { + if (sortRel.has_input() && !validate(sortRel.input())) { + return false; + } + // Get and validate the input types from extension. + if (!sortRel.has_advanced_extension()) { + std::cout << "Input types are expected in SortRel." << std::endl; + return false; + } + const auto& extension = sortRel.advanced_extension(); + std::vector types; + if (!validateInputTypes(extension, types)) { + std::cout << "Validation failed for input types in SortRel." 
<< std::endl; + return false; + } + + int32_t inputPlanNodeId = 0; + std::vector names; + names.reserve(types.size()); + for (auto colIdx = 0; colIdx < types.size(); colIdx++) { + names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx)); + } + auto rowType = std::make_shared(std::move(names), std::move(types)); + + const auto& sorts = sortRel.sorts(); + for (const auto& sort : sorts) { + switch (sort.direction()) { + case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST: + case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST: + case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST: + case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST: + break; + default: + return false; + } + + if (sort.has_expr()) { + try { + auto expression = exprConverter_->toVeloxExpr(sort.expr(), rowType); + auto expr_field = + dynamic_cast(expression.get()); + VELOX_CHECK( + expr_field != nullptr, + " the sorting key in Sort Operator only support field") + + exec::ExprSet exprSet({std::move(expression)}, execCtx_); + } catch (const VeloxException& err) { + std::cout << "Validation failed for expression in SortRel due to:" + << err.message() << std::endl; + return false; + } + } + } + + return true; +} + +bool SubstraitToVeloxPlanValidator::validate( + const ::substrait::ProjectRel& projectRel) { + if (projectRel.has_input() && !validate(projectRel.input())) { + return false; + } + + // Get and validate the input types from extension. + if (!projectRel.has_advanced_extension()) { + std::cout << "Input types are expected in ProjectRel." << std::endl; + return false; + } + const auto& extension = projectRel.advanced_extension(); + std::vector types; + if (!validateInputTypes(extension, types)) { + std::cout << "Validation failed for input types in ProjectRel." 
+ << std::endl; + return false; + } + + for (auto i = 0; i < types.size(); i++) { + switch (types[i]->kind()) { + case TypeKind::ARRAY: + return false; + default:; + } + } + + int32_t inputPlanNodeId = 0; + // Create the fake input names to be used in row type. + std::vector names; + names.reserve(types.size()); + for (uint32_t colIdx = 0; colIdx < types.size(); colIdx++) { + names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx)); + } + auto rowType = std::make_shared(std::move(names), std::move(types)); + + // Validate the project expressions. + const auto& projectExprs = projectRel.expressions(); + std::vector> expressions; + expressions.reserve(projectExprs.size()); + try { + for (const auto& expr : projectExprs) { + if (!validateExpression(expr, rowType)) { + return false; + } + expressions.emplace_back(exprConverter_->toVeloxExpr(expr, rowType)); + } + // Try to compile the expressions. If there is any unregistered function or + // mismatched type, exception will be thrown. + exec::ExprSet exprSet(std::move(expressions), execCtx_); + } catch (const VeloxException& err) { + std::cout << "Validation failed for expression in ProjectRel due to:" + << err.message() << std::endl; + return false; + } + return true; +} + +bool SubstraitToVeloxPlanValidator::validate( + const ::substrait::FilterRel& filterRel) { + if (filterRel.has_input() && !validate(filterRel.input())) { + return false; + } + + // Get and validate the input types from extension. + if (!filterRel.has_advanced_extension()) { + std::cout << "Input types are expected in FilterRel." << std::endl; + return false; + } + const auto& extension = filterRel.advanced_extension(); + std::vector types; + if (!validateInputTypes(extension, types)) { + std::cout << "Validation failed for input types in FilterRel." 
<< std::endl; + return false; + } + for (const auto& type : types) { + if (type->kind() == TypeKind::TIMESTAMP) { + VLOG(1) << "Timestamp is not fully supported in Filter"; + return false; + } + } + + int32_t inputPlanNodeId = 0; + // Create the fake input names to be used in row type. + std::vector names; + names.reserve(types.size()); + for (uint32_t colIdx = 0; colIdx < types.size(); colIdx++) { + names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx)); + } + auto rowType = std::make_shared(std::move(names), std::move(types)); + + std::vector> expressions; + expressions.reserve(1); + try { + if (!validateExpression(filterRel.condition(), rowType)) { + return false; + } + expressions.emplace_back( + exprConverter_->toVeloxExpr(filterRel.condition(), rowType)); + // Try to compile the expressions. If there is any unregistered function + // or mismatched type, exception will be thrown. + exec::ExprSet exprSet(std::move(expressions), execCtx_); + } catch (const VeloxException& err) { + std::cout << "Validation failed for expression in FilterRel due to:" + << err.message() << std::endl; + return false; + } + return true; +} + +bool SubstraitToVeloxPlanValidator::validate( + const ::substrait::JoinRel& joinRel) { + if (joinRel.has_left() && !validate(joinRel.left())) { + return false; + } + if (joinRel.has_right() && !validate(joinRel.right())) { + return false; + } + + if (joinRel.has_advanced_extension() && + subParser_->configSetInOptimization( + joinRel.advanced_extension(), "isSMJ=")) { + switch (joinRel.type()) { + case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER: + case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT: + break; + default: + std::cout << "Sort merge join only support inner and left join" + << std::endl; + return false; + } + } + switch (joinRel.type()) { + case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER: + case ::substrait::JoinRel_JoinType_JOIN_TYPE_OUTER: + case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT: + case 
::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT: + case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: + case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI: + case ::substrait::JoinRel_JoinType_JOIN_TYPE_ANTI: + break; + default: + return false; + } + + // Validate input types. + if (!joinRel.has_advanced_extension()) { + std::cout << "Input types are expected in JoinRel." << std::endl; + return false; + } + + const auto& extension = joinRel.advanced_extension(); + std::vector types; + if (!validateInputTypes(extension, types)) { + std::cout << "Validation failed for input types in JoinRel" << std::endl; + return false; + } + + int32_t inputPlanNodeId = 0; + std::vector names; + names.reserve(types.size()); + for (auto colIdx = 0; colIdx < types.size(); colIdx++) { + names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx)); + } + auto rowType = std::make_shared(std::move(names), std::move(types)); + + if (joinRel.has_expression()) { + std::vector leftExprs, + rightExprs; + try { + planConverter_->extractJoinKeys( + joinRel.expression(), leftExprs, rightExprs); + } catch (const VeloxException& err) { + std::cout << "Validation failed for expression in JoinRel due to:" + << err.message() << std::endl; + return false; + } + } + + if (joinRel.has_post_join_filter()) { + try { + auto expression = + exprConverter_->toVeloxExpr(joinRel.post_join_filter(), rowType); + exec::ExprSet exprSet({std::move(expression)}, execCtx_); + } catch (const VeloxException& err) { + std::cout << "Validation failed for expression in ProjectRel due to:" + << err.message() << std::endl; + return false; + } + } + return true; +} + +TypePtr SubstraitToVeloxPlanValidator::getDecimalType( + const std::string& decimalType) { + // Decimal info is in the format of dec. 
+ auto precisionStart = decimalType.find_first_of('<'); + auto tokenIndex = decimalType.find_first_of(','); + auto scaleStart = decimalType.find_first_of('>'); + auto precision = stoi(decimalType.substr( + precisionStart + 1, (tokenIndex - precisionStart - 1))); + auto scale = + stoi(decimalType.substr(tokenIndex + 1, (scaleStart - tokenIndex - 1))); + return DECIMAL(precision, scale); +} + +TypePtr SubstraitToVeloxPlanValidator::getRowType( + const std::string& structType) { + // Struct info is in the format of struct. + // TODO: nested struct is not supported. + auto structStart = structType.find_first_of('<'); + auto structEnd = structType.find_last_of('>'); + VELOX_CHECK( + structEnd - structStart > 1, + "More information is needed to create RowType"); + std::string childrenTypes = + structType.substr(structStart + 1, structEnd - structStart - 1); + + // Split the types with delimiter. + std::string delimiter = ","; + std::size_t pos; + std::vector types; + std::vector names; + while ((pos = childrenTypes.find(delimiter)) != std::string::npos) { + const auto& typeStr = childrenTypes.substr(0, pos); + std::string decDelimiter = ">"; + if (typeStr.find("dec") != std::string::npos) { + std::size_t endPos = childrenTypes.find(decDelimiter); + VELOX_CHECK(endPos >= pos + 1, "Decimal scale is expected."); + const auto& decimalStr = + typeStr + childrenTypes.substr(pos, endPos - pos) + decDelimiter; + types.emplace_back(getDecimalType(decimalStr)); + names.emplace_back(""); + childrenTypes.erase( + 0, endPos + delimiter.length() + decDelimiter.length()); + continue; + } + + types.emplace_back(toVeloxType(subParser_->parseType(typeStr))); + names.emplace_back(""); + childrenTypes.erase(0, pos + delimiter.length()); + } + types.emplace_back(toVeloxType(subParser_->parseType(childrenTypes))); + names.emplace_back(""); + return std::make_shared(std::move(names), std::move(types)); +} + +bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType( + const 
::substrait::AggregateRel& aggRel) { + if (aggRel.measures_size() == 0) { + return true; + } + + for (const auto& smea : aggRel.measures()) { + const auto& aggFunction = smea.measure(); + auto funcSpec = + planConverter_->findFuncSpec(aggFunction.function_reference()); + std::vector types; + bool isDecimal = false; + try { + std::vector funcTypes; + subParser_->getSubFunctionTypes(funcSpec, funcTypes); + types.reserve(funcTypes.size()); + for (auto& type : funcTypes) { + if (!isDecimal && type.find("dec") != std::string::npos) { + isDecimal = true; + } + if (type.find("struct") != std::string::npos) { + types.emplace_back(getRowType(type)); + } else if (type.find("dec") != std::string::npos) { + types.emplace_back(getDecimalType(type)); + } else { + types.emplace_back(toVeloxType(subParser_->parseType(type))); + } + } + } catch (const VeloxException& err) { + std::cout + << "Validation failed for input type in AggregateRel function due to:" + << err.message() << std::endl; + return false; + } + auto funcName = subParser_->mapToVeloxFunction( + subParser_->getSubFunctionName(funcSpec), isDecimal); + if (auto signatures = exec::getAggregateFunctionSignatures(funcName)) { + for (const auto& signature : signatures.value()) { + exec::SignatureBinder binder(*signature, types); + if (binder.tryBind()) { + auto resolveType = binder.tryResolveType( + exec::isPartialOutput(planConverter_->toAggregationStep(aggRel)) + ? signature->intermediateType() + : signature->returnType()); + if (resolveType == nullptr) { + std::cout + << fmt::format( + "Validation failed for function {} resolve type in AggregateRel.", + funcName) + << std::endl; + return false; + } + return true; + } + } + std::cout + << fmt::format( + "Validation failed for function {} bind in AggregateRel.", + funcName) + << std::endl; + return false; + } + } + std::cout << "Validation failed for function resolve in AggregateRel." 
+ << std::endl; + return false; +} + +bool SubstraitToVeloxPlanValidator::validate( + const ::substrait::AggregateRel& aggRel) { + if (aggRel.has_input() && !validate(aggRel.input())) { + return false; + } + + // Validate input types. + if (aggRel.has_advanced_extension()) { + std::vector types; + const auto& extension = aggRel.advanced_extension(); + if (!validateInputTypes(extension, types)) { + std::cout << "Validation failed for input types in AggregateRel." + << std::endl; + return false; + } + } + + // Validate groupings. + for (const auto& grouping : aggRel.groupings()) { + for (const auto& groupingExpr : grouping.grouping_expressions()) { + const auto& typeCase = groupingExpr.rex_type_case(); + switch (typeCase) { + case ::substrait::Expression::RexTypeCase::kSelection: + break; + default: + std::cout << "Only field is supported in groupings." << std::endl; + return false; + } + } + } + + // Validate aggregate functions. + std::vector funcSpecs; + funcSpecs.reserve(aggRel.measures().size()); + for (const auto& smea : aggRel.measures()) { + try { + // Validate the filter expression + if (smea.has_filter()) { + ::substrait::Expression aggRelMask = smea.filter(); + if (aggRelMask.ByteSizeLong() > 0) { + auto typeCase = aggRelMask.rex_type_case(); + switch (typeCase) { + case ::substrait::Expression::RexTypeCase::kSelection: + break; + default: + std::cout + << "Only field is supported in aggregate filter expression." + << std::endl; + return false; + } + } + } + + const auto& aggFunction = smea.measure(); + const auto& functionSpec = + planConverter_->findFuncSpec(aggFunction.function_reference()); + funcSpecs.emplace_back(functionSpec); + toVeloxType(subParser_->parseType(aggFunction.output_type())->type); + // Validate the size of arguments. + if (subParser_->getSubFunctionName(functionSpec) == "count" && + aggFunction.arguments().size() > 1) { + // Count accepts only one argument. 
+ return false; + } + for (const auto& arg : aggFunction.arguments()) { + auto typeCase = arg.value().rex_type_case(); + switch (typeCase) { + case ::substrait::Expression::RexTypeCase::kSelection: + case ::substrait::Expression::RexTypeCase::kLiteral: + break; + default: + std::cout << "Only field is supported in aggregate functions." + << std::endl; + return false; + } + } + } catch (const VeloxException& err) { + std::cout << "Validation failed for aggregate function due to: " + << err.message() << std::endl; + return false; + } + } + + std::unordered_set supportedFuncs = { + "sum", + "sum_merge", + "count", + "count_merge", + "avg", + "avg_merge", + "min", + "min_merge", + "max", + "max_merge", + "stddev_samp", + "stddev_samp_merge", + "stddev_pop", + "stddev_pop_merge", + "bloom_filter_agg", + "var_samp", + "var_samp_merge", + "var_pop", + "var_pop_merge", + "bit_and", + "bit_and_merge", + "bit_or", + "bit_or_merge", + "bit_xor", + "bit_xor_merge", + "first", + "first_merge", + "first_ignore_null", + "first_ignore_null_merge", + "last", + "last_merge", + "last_ignore_null", + "last_ignore_null_merge", + "corr", + "corr_merge", + "covar_pop", + "covar_pop_merge", + "covar_samp", + "covar_samp_merge", + "approx_distinct"}; + for (const auto& funcSpec : funcSpecs) { + auto funcName = subParser_->getSubFunctionName(funcSpec); + if (supportedFuncs.find(funcName) == supportedFuncs.end()) { + std::cout << "Validation failed due to " << funcName + << " was not supported in AggregateRel." << std::endl; + return false; + } + } + + if (!validateAggRelFunctionType(aggRel)) { + return false; + } + + // Validate both groupby and aggregates input are empty, which is corner case. 
+ if (aggRel.measures_size() == 0) { + bool hasExpr = false; + for (const auto& grouping : aggRel.groupings()) { + for (const auto& groupingExpr : grouping.grouping_expressions()) { + hasExpr = true; + break; + } + if (hasExpr) { + break; + } + } + if (!hasExpr) { + std::cout + << "Validation failed due to aggregation must specify either grouping keys or aggregates." + << std::endl; + return false; + } + } + return true; +} + +bool SubstraitToVeloxPlanValidator::validate( + const ::substrait::ReadRel& readRel) { + try { + planConverter_->toVeloxPlan(readRel); + } catch (const VeloxException& err) { + std::cout << "ReadRel validation failed due to:" << err.message() + << std::endl; + return false; + } + + // Validate filter in ReadRel. + if (readRel.has_filter()) { + std::vector> expressions; + expressions.reserve(1); + + std::vector veloxTypeList; + if (readRel.has_base_schema()) { + const auto& baseSchema = readRel.base_schema(); + auto substraitTypeList = subParser_->parseNamedStruct(baseSchema); + veloxTypeList.reserve(substraitTypeList.size()); + for (const auto& substraitType : substraitTypeList) { + veloxTypeList.emplace_back(toVeloxType(substraitType->type)); + } + } + std::vector names; + int32_t inputPlanNodeId = 0; + names.reserve(veloxTypeList.size()); + for (auto colIdx = 0; colIdx < veloxTypeList.size(); colIdx++) { + names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx)); + } + auto rowType = + std::make_shared(std::move(names), std::move(veloxTypeList)); + + try { + expressions.emplace_back( + exprConverter_->toVeloxExpr(readRel.filter(), rowType)); + // Try to compile the expressions. If there is any unregistered function + // or mismatched type, exception will be thrown. 
+ exec::ExprSet exprSet(std::move(expressions), execCtx_); + } catch (const VeloxException& err) { + std::cout << "Validation failed for filter expression in ReadRel due to:" + << err.message() << std::endl; + return false; + } + } + if (readRel.has_base_schema()) { + const auto& baseSchema = readRel.base_schema(); + if (!validateColNames(baseSchema)) { + std::cout + << "Validation failed for column name contains illegal charactor." + << std::endl; + return false; + } + } + return true; +} + +bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Rel& rel) { + if (rel.has_aggregate()) { + return validate(rel.aggregate()); + } + if (rel.has_project()) { + return validate(rel.project()); + } + if (rel.has_filter()) { + return validate(rel.filter()); + } + if (rel.has_join()) { + return validate(rel.join()); + } + if (rel.has_read()) { + return validate(rel.read()); + } + if (rel.has_sort()) { + return validate(rel.sort()); + } + if (rel.has_expand()) { + return validate(rel.expand()); + } + if (rel.has_fetch()) { + return validate(rel.fetch()); + } + if (rel.has_window()) { + return validate(rel.window()); + } + return false; +} + +bool SubstraitToVeloxPlanValidator::validate( + const ::substrait::RelRoot& relRoot) { + if (relRoot.has_input()) { + const auto& rel = relRoot.input(); + return validate(rel); + } + return false; +} + +bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Plan& plan) { + // Create plan converter and expression converter to help the validation. 
+ planConverter_->constructFunctionMap(plan); + exprConverter_ = std::make_shared( + pool_, planConverter_->getFunctionMap()); + + for (const auto& rel : plan.relations()) { + if (rel.has_root()) { + return validate(rel.root()); + } + if (rel.has_rel()) { + return validate(rel.rel()); + } + } + return false; +} + +} // namespace facebook::velox::substrait diff --git a/velox/substrait/SubstraitToVeloxPlanValidator.h b/velox/substrait/SubstraitToVeloxPlanValidator.h new file mode 100644 index 000000000000..66fb0c718428 --- /dev/null +++ b/velox/substrait/SubstraitToVeloxPlanValidator.h @@ -0,0 +1,131 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/substrait/SubstraitToVeloxPlan.h" + +namespace facebook::velox::substrait { + +/// This class is used to validate whether the computing of +/// a Substrait plan is supported in Velox. +class SubstraitToVeloxPlanValidator { + public: + SubstraitToVeloxPlanValidator( + memory::MemoryPool* pool, + core::ExecCtx* execCtx) + : pool_(pool), execCtx_(execCtx) {} + + /// Used to validate whether the computing of this Limit is supported. + bool validate(const ::substrait::FetchRel& fetchRel); + + /// Used to validate whether the computing of this Expand is supported. + bool validate(const ::substrait::ExpandRel& expandRel); + + /// Used to validate whether the computing of this Sort is supported. 
+ bool validate(const ::substrait::SortRel& sortRel); + + /// Used to validate whether the computing of this Window is supported. + bool validate(const ::substrait::WindowRel& windowRel); + + /// Used to validate whether the computing of this Aggregation is supported. + bool validate(const ::substrait::AggregateRel& aggRel); + + /// Used to validate whether the computing of this Project is supported. + bool validate(const ::substrait::ProjectRel& projectRel); + + /// Used to validate whether the computing of this Filter is supported. + bool validate(const ::substrait::FilterRel& filterRel); + + /// Used to validate Join. + bool validate(const ::substrait::JoinRel& joinRel); + + /// Used to validate whether the computing of this Read is supported. + bool validate(const ::substrait::ReadRel& readRel); + + /// Used to validate whether the computing of this Rel is supported. + bool validate(const ::substrait::Rel& rel); + + /// Used to validate whether the computing of this RelRoot is supported. + bool validate(const ::substrait::RelRoot& relRoot); + + /// Used to validate whether the computing of this Plan is supported. + bool validate(const ::substrait::Plan& plan); + + private: + /// A memory pool used for function validation. + memory::MemoryPool* pool_; + + /// An execution context used for function validation. + core::ExecCtx* execCtx_; + + /// A converter used to convert Substrait plan into Velox's plan node. + std::shared_ptr planConverter_ = + std::make_shared(pool_, true); + + /// A parser used to convert Substrait plan into recognizable representations. + std::shared_ptr subParser_ = + std::make_shared(); + + /// An expression converter used to convert Substrait representations into + /// Velox expressions. + std::shared_ptr exprConverter_; + + /// Used to get types from advanced extension and validate them. 
+ bool validateInputTypes( + const ::substrait::extensions::AdvancedExtension& extension, + std::vector& types); + + bool validateAggRelFunctionType( + const ::substrait::AggregateRel& substraitAgg); + + /// Validate the round scalar function. + bool validateRound( + const ::substrait::Expression::ScalarFunction& scalarFunction, + const RowTypePtr& inputType); + + /// Validate extract function. + bool validateExtractExpr( + const std::vector>& params); + + /// Validate Substrait scalar function. + bool validateScalarFunction( + const ::substrait::Expression::ScalarFunction& scalarFunction, + const RowTypePtr& inputType); + + /// Validate Substrait Cast expression. + bool validateCast( + const ::substrait::Expression::Cast& castExpr, + const RowTypePtr& inputType); + + /// Validate Substrait expression. + bool validateExpression( + const ::substrait::Expression& expression, + const RowTypePtr& inputType); + + /// Validate Substrait literal. + bool validateLiteral( + const ::substrait::Expression_Literal& literal, + const RowTypePtr& inputType); + + /// Create RowType based on the type information in string. + TypePtr getRowType(const std::string& structType); + + /// Create DecimalType based on the type information in string. 
+ TypePtr getDecimalType(const std::string& decimalType); +}; + +} // namespace facebook::velox::substrait diff --git a/velox/substrait/TypeUtils.cpp b/velox/substrait/TypeUtils.cpp index 77cbd17d6dfe..a111ac4853ea 100644 --- a/velox/substrait/TypeUtils.cpp +++ b/velox/substrait/TypeUtils.cpp @@ -65,11 +65,30 @@ std::string_view getNameBeforeDelimiter( return std::string_view(compoundName.data(), pos); } +std::pair getPrecisionAndScale(const std::string& typeName) { + std::size_t start = typeName.find_first_of("<"); + std::size_t end = typeName.find_last_of(">"); + if (start == std::string::npos || end == std::string::npos) { + throw std::runtime_error("Invalid decimal type."); + } + + std::string decimalType = typeName.substr(start + 1, end - start - 1); + std::size_t token_pos = decimalType.find_first_of(","); + auto precision = stoi(decimalType.substr(0, token_pos)); + auto scale = + stoi(decimalType.substr(token_pos + 1, decimalType.length() - 1)); + return std::make_pair(precision, scale); +} + TypePtr toVeloxType(const std::string& typeName) { VELOX_CHECK(!typeName.empty(), "Cannot convert empty string to Velox type."); auto type = getNameBeforeDelimiter(typeName, "<"); auto typeKind = mapNameToTypeKind(std::string(type)); + if (isDecimalName(typeName)) { + auto decimal = getPrecisionAndScale(typeName); + return DECIMAL(decimal.first, decimal.second); + } switch (typeKind) { case TypeKind::BOOLEAN: return BOOLEAN(); @@ -120,6 +139,9 @@ TypePtr toVeloxType(const std::string& typeName) { case TypeKind::DATE: { return DATE(); } + case TypeKind::TIMESTAMP: { + return TIMESTAMP(); + } case TypeKind::UNKNOWN: return UNKNOWN(); default: diff --git a/velox/substrait/TypeUtils.h b/velox/substrait/TypeUtils.h index 3a649eef674c..854efbeed84d 100644 --- a/velox/substrait/TypeUtils.h +++ b/velox/substrait/TypeUtils.h @@ -14,6 +14,8 @@ * limitations under the License. 
*/ +#include "velox/substrait/SubstraitParser.h" +#include "velox/type/Filter.h" #include "velox/type/Type.h" namespace facebook::velox::substrait { @@ -24,4 +26,74 @@ TypePtr toVeloxType(const std::string& typeName); std::string_view getNameBeforeDelimiter( const std::string& compoundName, const std::string& delimiter); +#ifndef RANGETRAITS_H +#define RANGETRAITS_H + +// Traits used to map type kind to the range used in Filter. +template +struct RangeTraits {}; + +template <> +struct RangeTraits { + using RangeType = common::BigintRange; + using MultiRangeType = common::BigintMultiRange; + using NativeType = int8_t; +}; + +template <> +struct RangeTraits { + using RangeType = common::BigintRange; + using MultiRangeType = common::BigintMultiRange; + using NativeType = int16_t; +}; + +template <> +struct RangeTraits { + using RangeType = common::BigintRange; + using MultiRangeType = common::BigintMultiRange; + using NativeType = int32_t; +}; + +template <> +struct RangeTraits { + using RangeType = common::BigintRange; + using MultiRangeType = common::BigintMultiRange; + using NativeType = int64_t; +}; + +template <> +struct RangeTraits { + using RangeType = common::DoubleRange; + using MultiRangeType = common::MultiRange; + using NativeType = double; +}; + +template <> +struct RangeTraits { + using RangeType = common::BigintRange; + using MultiRangeType = common::BigintMultiRange; + using NativeType = bool; +}; + +template <> +struct RangeTraits { + using RangeType = common::BytesRange; + using MultiRangeType = common::MultiRange; + using NativeType = std::string; +}; + +template <> +struct RangeTraits { + using RangeType = common::BigintRange; + using MultiRangeType = common::BigintMultiRange; + using NativeType = int32_t; +}; + +template <> +struct RangeTraits { + using NativeType = int128_t; +}; + +#endif /* RANGETRAITS_H */ + } // namespace facebook::velox::substrait diff --git a/velox/substrait/proto/substrait/algebra.proto 
b/velox/substrait/proto/substrait/algebra.proto index 3913871aca09..98b1959d9e04 100644 --- a/velox/substrait/proto/substrait/algebra.proto +++ b/velox/substrait/proto/substrait/algebra.proto @@ -4,8 +4,8 @@ syntax = "proto3"; package substrait; import "google/protobuf/any.proto"; -import "velox/substrait/proto/substrait/extensions/extensions.proto"; -import "velox/substrait/proto/substrait/type.proto"; +import "substrait/extensions/extensions.proto"; +import "substrait/type.proto"; option csharp_namespace = "Substrait.Protobuf"; option go_package = "github.com/substrait-io/substrait-go/proto"; @@ -168,11 +168,12 @@ message JoinRel { JOIN_TYPE_OUTER = 2; JOIN_TYPE_LEFT = 3; JOIN_TYPE_RIGHT = 4; - JOIN_TYPE_SEMI = 5; - JOIN_TYPE_ANTI = 6; + JOIN_TYPE_LEFT_SEMI = 5; + JOIN_TYPE_RIGHT_SEMI = 6; + JOIN_TYPE_ANTI = 7; // This join is useful for nested sub-queries where we need exactly one tuple in output (or throw exception) // See Section 3.2 of https://15721.courses.cs.cmu.edu/spring2018/papers/16-optimizer2/hyperjoins-btw2017.pdf - JOIN_TYPE_SINGLE = 7; + JOIN_TYPE_SINGLE = 8; } substrait.extensions.AdvancedExtension advanced_extension = 10; @@ -236,6 +237,19 @@ message SortRel { substrait.extensions.AdvancedExtension advanced_extension = 10; } +message WindowRel { + RelCommon common = 1; + Rel input = 2; + repeated Measure measures = 3; + repeated Expression partition_expressions = 4; + repeated SortField sorts = 5; + substrait.extensions.AdvancedExtension advanced_extension = 10; + + message Measure { + Expression.WindowFunction measure = 1; + } +} + // The relational operator capturing simple FILTERs (as in the WHERE clause of SQL) message FilterRel { RelCommon common = 1; @@ -340,6 +354,35 @@ message ExchangeRel { } } +// Duplicates records, possibly switching output expressions between each duplicate. 
+// Default output is all of the fields declared followed by one int64 field that contains the +// duplicate_id which is a zero-index ordinal of which duplicate of the original record this +// corresponds to. +message ExpandRel { + RelCommon common = 1; + Rel input = 2; + repeated ExpandField fields = 4; + substrait.extensions.AdvancedExtension advanced_extension = 10; + + message ExpandField { + oneof field_type { + // Field that switches output based on which duplicate_id we're outputting + SwitchingField switching_field = 2; + + // Field that outputs the same value no matter which duplicate_id we're on. + Expression consistent_field = 3; + } + } + + message SwitchingField { + // Array that contains an expression to output per duplicate_id + // each `switching_field` must have the same number of expressions + // all expressions within a switching field must be the same type class but can differ in nullability. + // this column will be nullable if any of the expressions are nullable. + repeated Expression duplicates = 1; + } +} + // A relation with output field names. // // This is for use at the root of a `Rel` tree. @@ -369,6 +412,9 @@ message Rel { //Physical relations HashJoinRel hash_join = 13; MergeJoinRel merge_join = 14; + ExpandRel expand = 15; + WindowRel window = 16; + GenerateRel generate = 17; } } @@ -836,6 +882,9 @@ message Expression { // Optional; defaults to the start of the partition. Bound lower_bound = 5; + string column_name = 12; + WindowType window_type = 13; + // Defines the record relative to the current record up to which the window // extends. The bound is inclusive. If the upper bound indexes a record // less than the lower bound, TODO (null range/no records passed? @@ -867,10 +916,9 @@ message Expression { // Defines that the bound extends to or from the current record. 
message CurrentRow {} - // Defines an "unbounded bound": for lower bounds this means the start - // of the partition, and for upper bounds this means the end of the - // partition. - message Unbounded {} + message Unbounded_Preceding {} + + message Unbounded_Following {} oneof kind { // The bound extends some number of records behind the current record. @@ -883,10 +931,8 @@ message Expression { // The bound extends to the current record. CurrentRow current_row = 3; - // The bound extends to the start of the partition or the end of the - // partition, depending on whether this represents the upper or lower - // bound. - Unbounded unbounded = 4; + Unbounded_Preceding unbounded_preceding = 4; + Unbounded_Following unbounded_following = 5; } } } @@ -1176,6 +1222,17 @@ message Expression { } } +message GenerateRel { + RelCommon common = 1; + Rel input = 2; + + Expression generator = 3; + repeated Expression child_output = 4; + bool outer = 5; + + substrait.extensions.AdvancedExtension advanced_extension = 10; +} + // The description of a field to sort on (including the direction of sorting and null semantics) message SortField { Expression expr = 1; @@ -1224,6 +1281,11 @@ enum AggregationPhase { AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT = 4; } +enum WindowType { + ROWS = 0; + RANGE = 1; +} + // An aggregate function. 
message AggregateFunction { // Points to a function_anchor defined in this plan, which must refer diff --git a/velox/substrait/proto/substrait/ddl.proto b/velox/substrait/proto/substrait/ddl.proto new file mode 100644 index 000000000000..833ec87369ae --- /dev/null +++ b/velox/substrait/proto/substrait/ddl.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +package substrait; + +import "substrait/plan.proto"; +import "substrait/algebra.proto"; + +option java_multiple_files = true; +option java_package = "io.substrait.proto"; +option csharp_namespace = "Substrait.Protobuf"; + +message DllPlan { + oneof dll_type { + InsertPlan insert_plan = 1; + } +} + +message InsertPlan { + Plan input = 1; + ReadRel.ExtensionTable output = 2; +} + +message Dll { + repeated DllPlan dll_plan = 1; +} \ No newline at end of file diff --git a/velox/substrait/proto/substrait/extended_expression.proto b/velox/substrait/proto/substrait/extended_expression.proto index 1c88b384edca..5d1152055930 100644 --- a/velox/substrait/proto/substrait/extended_expression.proto +++ b/velox/substrait/proto/substrait/extended_expression.proto @@ -3,10 +3,10 @@ syntax = "proto3"; package substrait; -import "velox/substrait/proto/substrait/algebra.proto"; -import "velox/substrait/proto/substrait/extensions/extensions.proto"; -import "velox/substrait/proto/substrait/plan.proto"; -import "velox/substrait/proto/substrait/type.proto"; +import "substrait/algebra.proto"; +import "substrait/extensions/extensions.proto"; +import "substrait/plan.proto"; +import "substrait/type.proto"; option csharp_namespace = "Substrait.Protobuf"; option go_package = "github.com/substrait-io/substrait-go/proto"; diff --git a/velox/substrait/proto/substrait/function.proto b/velox/substrait/proto/substrait/function.proto index 1fc0376f6625..123f4a1bf749 100644 --- a/velox/substrait/proto/substrait/function.proto +++ b/velox/substrait/proto/substrait/function.proto @@ -3,9 +3,9 @@ syntax = "proto3"; package substrait; -import 
"velox/substrait/proto/substrait/parameterized_types.proto"; -import "velox/substrait/proto/substrait/type.proto"; -import "velox/substrait/proto/substrait/type_expressions.proto"; +import "substrait/parameterized_types.proto"; +import "substrait/type.proto"; +import "substrait/type_expressions.proto"; option csharp_namespace = "Substrait.Protobuf"; option go_package = "github.com/substrait-io/substrait-go/proto"; diff --git a/velox/substrait/proto/substrait/parameterized_types.proto b/velox/substrait/proto/substrait/parameterized_types.proto index 9408101c548e..db0669354fa2 100644 --- a/velox/substrait/proto/substrait/parameterized_types.proto +++ b/velox/substrait/proto/substrait/parameterized_types.proto @@ -3,7 +3,7 @@ syntax = "proto3"; package substrait; -import "velox/substrait/proto/substrait/type.proto"; +import "substrait/type.proto"; option csharp_namespace = "Substrait.Protobuf"; option go_package = "github.com/substrait-io/substrait-go/proto"; diff --git a/velox/substrait/proto/substrait/plan.proto b/velox/substrait/proto/substrait/plan.proto index 71590a63b531..e5657fb8f1ef 100644 --- a/velox/substrait/proto/substrait/plan.proto +++ b/velox/substrait/proto/substrait/plan.proto @@ -3,8 +3,8 @@ syntax = "proto3"; package substrait; -import "velox/substrait/proto/substrait/algebra.proto"; -import "velox/substrait/proto/substrait/extensions/extensions.proto"; +import "substrait/algebra.proto"; +import "substrait/extensions/extensions.proto"; option csharp_namespace = "Substrait.Protobuf"; option go_package = "github.com/substrait-io/substrait-go/proto"; diff --git a/velox/substrait/proto/substrait/type.proto b/velox/substrait/proto/substrait/type.proto index a7f1c665db30..37cc95ffd2a2 100644 --- a/velox/substrait/proto/substrait/type.proto +++ b/velox/substrait/proto/substrait/type.proto @@ -45,6 +45,8 @@ message Type { // encountered, treat it as being non-nullable and having the default // variation. 
uint32 user_defined_type_reference = 31 [deprecated = true]; + + Nothing nothing = 33; } enum Nullability { @@ -53,6 +55,10 @@ message Type { NULLABILITY_REQUIRED = 2; } + message Nothing { + uint32 type_variation_reference = 1; + } + message Boolean { uint32 type_variation_reference = 1; Nullability nullability = 2; @@ -226,4 +232,13 @@ message NamedStruct { // list of names in dfs order repeated string names = 1; Type.Struct struct = 2; + PartitionColumns partition_columns = 3; +} + +message PartitionColumns { + repeated ColumnType column_type = 1; + enum ColumnType { + NORMAL_COL = 0; + PARTITION_COL = 1; + } } diff --git a/velox/substrait/proto/substrait/type_expressions.proto b/velox/substrait/proto/substrait/type_expressions.proto index 59782d5230ff..4be4aab47d40 100644 --- a/velox/substrait/proto/substrait/type_expressions.proto +++ b/velox/substrait/proto/substrait/type_expressions.proto @@ -3,7 +3,7 @@ syntax = "proto3"; package substrait; -import "velox/substrait/proto/substrait/type.proto"; +import "substrait/type.proto"; option csharp_namespace = "Substrait.Protobuf"; option go_package = "github.com/substrait-io/substrait-go/proto"; diff --git a/velox/substrait/tests/CMakeLists.txt b/velox/substrait/tests/CMakeLists.txt index f453fb61b7c4..f8da0d101510 100644 --- a/velox/substrait/tests/CMakeLists.txt +++ b/velox/substrait/tests/CMakeLists.txt @@ -16,6 +16,7 @@ add_executable( velox_plan_conversion_test Substrait2VeloxPlanConversionTest.cpp Substrait2VeloxValuesNodeConversionTest.cpp + Substrait2VeloxPlanValidatorTest.cpp FunctionTest.cpp JsonToProtoConverter.cpp VeloxSubstraitRoundTripTest.cpp diff --git a/velox/substrait/tests/FunctionTest.cpp b/velox/substrait/tests/FunctionTest.cpp index 876b0954446f..41ed444082e9 100644 --- a/velox/substrait/tests/FunctionTest.cpp +++ b/velox/substrait/tests/FunctionTest.cpp @@ -95,31 +95,31 @@ TEST_F(FunctionTest, constructFunctionMap) { auto functionMap = planConverter_->getFunctionMap(); 
ASSERT_EQ(functionMap.size(), 9); - std::string function = planConverter_->findFunction(1); + std::string function = planConverter_->findFuncSpec(1); ASSERT_EQ(function, "lte:fp64_fp64"); - function = planConverter_->findFunction(2); + function = planConverter_->findFuncSpec(2); ASSERT_EQ(function, "and:bool_bool"); - function = planConverter_->findFunction(3); + function = planConverter_->findFuncSpec(3); ASSERT_EQ(function, "subtract:opt_fp64_fp64"); - function = planConverter_->findFunction(4); + function = planConverter_->findFuncSpec(4); ASSERT_EQ(function, "multiply:opt_fp64_fp64"); - function = planConverter_->findFunction(5); + function = planConverter_->findFuncSpec(5); ASSERT_EQ(function, "add:opt_fp64_fp64"); - function = planConverter_->findFunction(6); + function = planConverter_->findFuncSpec(6); ASSERT_EQ(function, "sum:opt_fp64"); - function = planConverter_->findFunction(7); + function = planConverter_->findFuncSpec(7); ASSERT_EQ(function, "count:opt_fp64"); - function = planConverter_->findFunction(8); + function = planConverter_->findFuncSpec(8); ASSERT_EQ(function, "count:opt_i32"); - function = planConverter_->findFunction(9); + function = planConverter_->findFuncSpec(9); ASSERT_EQ(function, "is_not_null:fp64"); } @@ -199,3 +199,19 @@ TEST_F(FunctionTest, setVectorFromVariants) { ASSERT_EQ(9020, resultVec->asFlatVector()->valueAt(0)); ASSERT_EQ(8875, resultVec->asFlatVector()->valueAt(1)); } + +TEST_F(FunctionTest, getFunctionType) { + std::vector types; + substraitParser_->getSubFunctionTypes("sum:opt_i32", types); + ASSERT_EQ("i32", types[0]); + + types.clear(); + substraitParser_->getSubFunctionTypes("sum:i32", types); + ASSERT_EQ("i32", types[0]); + + types.clear(); + substraitParser_->getSubFunctionTypes("sum:opt_str_str", types); + ASSERT_EQ(2, types.size()); + ASSERT_EQ("str", types[0]); + ASSERT_EQ("str", types[1]); +} diff --git a/velox/substrait/tests/PlanConversionTest.cpp b/velox/substrait/tests/PlanConversionTest.cpp new file mode 
100644 index 000000000000..014da7f5738d --- /dev/null +++ b/velox/substrait/tests/PlanConversionTest.cpp @@ -0,0 +1,599 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "velox/common/base/tests/Fs.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/connectors/hive/HiveConnector.h" +#include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/dwio/common/Options.h" +#include "velox/dwio/dwrf/test/utils/DataFiles.h" +#include "velox/exec/PartitionedOutputBufferManager.h" +#include "velox/exec/tests/utils/Cursor.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/functions/prestosql/aggregates/AverageAggregate.h" +#include "velox/functions/prestosql/aggregates/CountAggregate.h" +#include "velox/functions/sparksql/Register.h" +#include "velox/substrait/SubstraitToVeloxPlan.h" +#include "velox/type/Type.h" +#include "velox/type/tests/FilterBuilder.h" +#include "velox/type/tests/SubfieldFiltersBuilder.h" + +using namespace facebook::velox; +using namespace facebook::velox::connector::hive; +using namespace facebook::velox::exec; +using namespace facebook::velox::common::test; +using namespace facebook::velox::exec::test; + +class PlanConversionTest : public virtual HiveConnectorTestBase, + public 
testing::WithParamInterface { + protected: + class VeloxConverter { + public: + // This class is an iterator for Velox computing. + class WholeComputeResultIterator { + public: + WholeComputeResultIterator( + const std::shared_ptr& planNode, + u_int32_t index, + const std::vector& paths, + const std::vector& starts, + const std::vector& lengths, + const dwio::common::FileFormat& format) + : planNode_(planNode), + index_(index), + paths_(paths), + starts_(starts), + lengths_(lengths), + format_(format) { + // Construct the splits. + std::vector> + connectorSplits; + connectorSplits.reserve(paths.size()); + for (int idx = 0; idx < paths.size(); idx++) { + auto path = paths[idx]; + auto start = starts[idx]; + auto length = lengths[idx]; + auto split = std::make_shared< + facebook::velox::connector::hive::HiveConnectorSplit>( + facebook::velox::exec::test::kHiveConnectorId, + path, + format, + start, + length); + connectorSplits.emplace_back(split); + } + splits_.reserve(connectorSplits.size()); + for (const auto& connectorSplit : connectorSplits) { + splits_.emplace_back(exec::Split(folly::copy(connectorSplit), -1)); + } + + params_.planNode = planNode; + cursor_ = std::make_unique(params_); + addSplits_ = [&](Task* task) { + if (noMoreSplits_) { + return; + } + for (auto& split : splits_) { + task->addSplit("0", std::move(split)); + } + task->noMoreSplits("0"); + noMoreSplits_ = true; + }; + } + + bool HasNext() { + if (!mayHasNext_) { + return false; + } + if (numRows_ > 0) { + return true; + } else { + addSplits_(cursor_->task().get()); + if (cursor_->moveNext()) { + result_ = cursor_->current(); + numRows_ += result_->size(); + return true; + } else { + mayHasNext_ = false; + return false; + } + } + } + + RowVectorPtr Next() { + numRows_ = 0; + return result_; + } + + private: + const std::shared_ptr planNode_; + std::unique_ptr cursor_; + exec::test::CursorParameters params_; + std::vector splits_; + bool noMoreSplits_ = false; + std::function addSplits_; + 
u_int32_t index_; + std::vector paths_; + std::vector starts_; + std::vector lengths_; + dwio::common::FileFormat format_; + uint64_t numRows_ = 0; + bool mayHasNext_ = true; + RowVectorPtr result_; + }; + + // This method will resume the Substrait plan from Json file, + // and convert it into Velox PlanNode. A result iterator for + // Velox computing will be returned. + std::shared_ptr getResIter( + const std::string& subPlanPath) { + // Read json file and resume the Substrait plan. + std::ifstream subJson(subPlanPath); + std::stringstream buffer; + buffer << subJson.rdbuf(); + std::string subData = buffer.str(); + ::substrait::Plan subPlan; + google::protobuf::util::JsonStringToMessage(subData, &subPlan); + + auto planConverter = std::make_shared< + facebook::velox::substrait::SubstraitVeloxPlanConverter>( + memoryPool_.get()); + // Convert to Velox PlanNode. + auto planNode = planConverter->toVeloxPlan(subPlan); + + auto splitInfos = planConverter->splitInfos(); + auto leafPlanNodeIds = planNode->leafPlanNodeIds(); + // Only one leaf node is expected here. + EXPECT_EQ(1, leafPlanNodeIds.size()); + auto iter = leafPlanNodeIds.begin(); + auto splitInfo = splitInfos[*iter].get(); + + // Get the information for TableScan. + u_int32_t partitionIndex = splitInfo->partitionIndex; + std::vector paths = splitInfo->paths; + + // In test, need to get the absolute path of the generated ORC file. + auto tempPath = getTmpDirPath(); + std::vector absolutePaths; + absolutePaths.reserve(paths.size()); + + for (const auto& path : paths) { + absolutePaths.emplace_back(fmt::format("file://{}{}", tempPath, path)); + } + + std::vector starts = splitInfo->starts; + std::vector lengths = splitInfo->lengths; + auto format = splitInfo->format; + // Construct the result iterator. 
+ auto resIter = std::make_shared( + planNode, partitionIndex, absolutePaths, starts, lengths, format); + return resIter; + } + + std::string getTmpDirPath() const { + return tmpDir_->path; + } + + std::shared_ptr tmpDir_{ + exec::test::TempDirectoryPath::create()}; + + private: + std::unique_ptr memoryPool_{ + memory::getDefaultScopedMemoryPool()}; + }; + + void SetUp() override { + useAsyncCache_ = GetParam(); + HiveConnectorTestBase::SetUp(); + + aggregate::registerSumAggregate("sum"); + aggregate::registerAverageAggregate("avg"); + aggregate::registerCountAggregate("count"); + } + + static void SetUpTestCase() { + HiveConnectorTestBase::SetUpTestCase(); + } + + std::vector makeVectors( + int32_t count, + int32_t rowsPerVector, + const std::shared_ptr& rowType) { + return HiveConnectorTestBase::makeVectors(rowType, count, rowsPerVector); + } + + std::unique_ptr pool_{ + facebook::velox::memory::getDefaultScopedMemoryPool()}; + + // This method can be used to create a Fixed-width type of Vector without Null + // values. + template + VectorPtr createSpecificScalar( + size_t size, + std::vector vals, + memory::MemoryPool* pool) { + facebook::velox::BufferPtr values = AlignedBuffer::allocate(size, pool); + auto valuesPtr = values->asMutableRange(); + facebook::velox::BufferPtr nulls = nullptr; + for (size_t i = 0; i < size; ++i) { + valuesPtr[i] = vals[i]; + } + return std::make_shared>( + pool, nulls, size, values, std::vector{}); + } + + // This method can be used to create a String type of Vector without Null + // values. 
+ VectorPtr createSpecificStringVector( + size_t size, + std::vector vals, + memory::MemoryPool* pool) { + auto vector = BaseVector::create(VARCHAR(), size, pool); + auto flatVector = vector->asFlatVector(); + + size_t childSize = 0; + std::vector lengths(size); + size_t nullCount = 0; + for (size_t i = 0; i < size; ++i) { + auto notNull = true; + vector->setNull(i, !notNull); + auto len = vals[i].size(); + lengths[i] = len; + childSize += len; + } + vector->setNullCount(0); + + BufferPtr buf = AlignedBuffer::allocate(childSize, pool); + char* bufPtr = buf->asMutable(); + char* dest = bufPtr; + for (size_t i = 0; i < size; ++i) { + std::string str = vals[i]; + const char* chr = str.c_str(); + auto length = str.size(); + memcpy(dest, chr, length); + dest = dest + length; + } + size_t offset = 0; + for (size_t i = 0; i < size; ++i) { + if (!vector->isNullAt(i)) { + flatVector->set( + i, facebook::velox::StringView(bufPtr + offset, lengths[i])); + offset += lengths[i]; + } + } + return vector; + } + + void genLineitemORC(const std::shared_ptr& veloxConverter) { + auto type = + ROW({"l_orderkey", + "l_partkey", + "l_suppkey", + "l_linenumber", + "l_quantity", + "l_extendedprice", + "l_discount", + "l_tax", + "l_returnflag", + "l_linestatus", + "l_shipdate", + "l_commitdate", + "l_receiptdate", + "l_shipinstruct", + "l_shipmode", + "l_comment"}, + {BIGINT(), + BIGINT(), + BIGINT(), + INTEGER(), + DOUBLE(), + DOUBLE(), + DOUBLE(), + DOUBLE(), + VARCHAR(), + VARCHAR(), + DOUBLE(), + DOUBLE(), + DOUBLE(), + VARCHAR(), + VARCHAR(), + VARCHAR()}); + + std::vector vectors; + // TPC-H lineitem table has 16 columns. 
+ int colNum = 16; + vectors.reserve(colNum); + std::vector lOrderkeyData = { + 4636438147, + 2012485446, + 1635327427, + 8374290148, + 2972204230, + 8001568994, + 989963396, + 2142695974, + 6354246853, + 4141748419}; + vectors.emplace_back( + createSpecificScalar(10, lOrderkeyData, pool_.get())); + std::vector lPartkeyData = { + 263222018, + 255918298, + 143549509, + 96877642, + 201976875, + 196938305, + 100260625, + 273511608, + 112999357, + 299103530}; + vectors.emplace_back( + createSpecificScalar(10, lPartkeyData, pool_.get())); + std::vector lSuppkeyData = { + 2102019, + 13998315, + 12989528, + 4717643, + 9976902, + 12618306, + 11940632, + 871626, + 1639379, + 3423588}; + vectors.emplace_back( + createSpecificScalar(10, lSuppkeyData, pool_.get())); + std::vector lLinenumberData = {4, 6, 1, 5, 1, 2, 1, 5, 2, 6}; + vectors.emplace_back( + createSpecificScalar(10, lLinenumberData, pool_.get())); + std::vector lQuantityData = { + 6.0, 1.0, 19.0, 4.0, 6.0, 12.0, 23.0, 11.0, 16.0, 19.0}; + vectors.emplace_back( + createSpecificScalar(10, lQuantityData, pool_.get())); + std::vector lExtendedpriceData = { + 30586.05, + 7821.0, + 1551.33, + 30681.2, + 1941.78, + 66673.0, + 6322.44, + 41754.18, + 8704.26, + 63780.36}; + vectors.emplace_back( + createSpecificScalar(10, lExtendedpriceData, pool_.get())); + std::vector lDiscountData = { + 0.05, 0.06, 0.01, 0.07, 0.05, 0.06, 0.07, 0.05, 0.06, 0.07}; + vectors.emplace_back( + createSpecificScalar(10, lDiscountData, pool_.get())); + std::vector lTaxData = { + 0.02, 0.03, 0.01, 0.0, 0.01, 0.01, 0.03, 0.07, 0.01, 0.04}; + vectors.emplace_back( + createSpecificScalar(10, lTaxData, pool_.get())); + std::vector lReturnflagData = { + "N", "A", "A", "R", "A", "N", "A", "A", "N", "R"}; + vectors.emplace_back( + createSpecificStringVector(10, lReturnflagData, pool_.get())); + std::vector lLinestatusData = { + "O", "F", "F", "F", "F", "O", "F", "F", "O", "F"}; + vectors.emplace_back( + createSpecificStringVector(10, lLinestatusData, 
pool_.get())); + std::vector lShipdateNewData = { + 8953.666666666666, + 8773.666666666666, + 9034.666666666666, + 8558.666666666666, + 9072.666666666666, + 8864.666666666666, + 9004.666666666666, + 8778.666666666666, + 9013.666666666666, + 8832.666666666666}; + vectors.emplace_back( + createSpecificScalar(10, lShipdateNewData, pool_.get())); + std::vector lCommitdateNewData = { + 10447.666666666666, + 8953.666666666666, + 8325.666666666666, + 8527.666666666666, + 8438.666666666666, + 10049.666666666666, + 9036.666666666666, + 8666.666666666666, + 9519.666666666666, + 9138.666666666666}; + vectors.emplace_back( + createSpecificScalar(10, lCommitdateNewData, pool_.get())); + std::vector lReceiptdateNewData = { + 10456.666666666666, + 8979.666666666666, + 8299.666666666666, + 8474.666666666666, + 8525.666666666666, + 9996.666666666666, + 9103.666666666666, + 8726.666666666666, + 9593.666666666666, + 9178.666666666666}; + vectors.emplace_back( + createSpecificScalar(10, lReceiptdateNewData, pool_.get())); + std::vector lShipinstructData = { + "COLLECT COD", + "NONE", + "TAKE BACK RETURN", + "NONE", + "TAKE BACK RETURN", + "NONE", + "DELIVER IN PERSON", + "DELIVER IN PERSON", + "TAKE BACK RETURN", + "NONE"}; + vectors.emplace_back( + createSpecificStringVector(10, lShipinstructData, pool_.get())); + std::vector lShipmodeData = { + "FOB", + "REG AIR", + "MAIL", + "FOB", + "RAIL", + "SHIP", + "REG AIR", + "REG AIR", + "TRUCK", + "AIR"}; + vectors.emplace_back( + createSpecificStringVector(10, lShipmodeData, pool_.get())); + std::vector lCommentData = { + " the furiously final foxes. quickly final p", + "thely ironic", + "ate furiously. even, pending pinto bean", + "ackages af", + "odolites. 
slyl", + "ng the regular requests sleep above", + "lets above the slyly ironic theodolites sl", + "lyly regular excuses affi", + "lly unusual theodolites grow slyly above", + " the quickly ironic pains lose car"}; + vectors.emplace_back( + createSpecificStringVector(10, lCommentData, pool_.get())); + + // Batches has only one RowVector here. + uint64_t nullCount = 0; + std::vector batches{std::make_shared( + pool_.get(), type, nullptr, 10, vectors, nullCount)}; + + // Writes data into an ORC file. + auto sink = std::make_unique( + veloxConverter->getTmpDirPath() + "/mock_lineitem.orc"); + auto config = std::make_shared(); + const int64_t writerMemoryCap = std::numeric_limits::max(); + facebook::velox::dwrf::WriterOptions options; + options.config = config; + options.schema = type; + options.memoryBudget = writerMemoryCap; + options.flushPolicyFactory = nullptr; + options.layoutPlannerFactory = nullptr; + auto writer = std::make_unique( + options, + std::move(sink), + facebook::velox::memory::getProcessDefaultMemoryManager().getRoot()); + for (size_t i = 0; i < batches.size(); ++i) { + writer->write(batches[i]); + } + writer->close(); + } + + // Used to find the Velox path according to the current path. + std::string getVeloxPath() { + std::string veloxPath; + std::string currentPath = fs::current_path().c_str(); + size_t pos = 0; + + if ((pos = currentPath.find("project")) != std::string::npos) { + // In Github test, the Velox home is /root/project. + veloxPath = currentPath.substr(0, pos) + "project"; + } else if ((pos = currentPath.find("velox")) != std::string::npos) { + veloxPath = currentPath.substr(0, pos) + "velox"; + } else if ((pos = currentPath.find("fbcode")) != std::string::npos) { + veloxPath = currentPath; + } else { + throw std::runtime_error("Current path is not a valid Velox path."); + } + return veloxPath; + } +}; + +// This test will firstly generate mock TPC-H lineitem ORC file. 
Then, Velox's +// computing will be tested based on the generated ORC file. +// Input: Json file of the Substrait plan for the first stage of below modified +// TPC-H Q6 query: +// +// select sum(l_extendedprice*l_discount) as revenue from lineitem where +// l_shipdate >= 8766 and l_shipdate < 9131 and l_discount between .06 +// - 0.01 and .06 + 0.01 and l_quantity < 24 +// +// Tested Velox computings include: TableScan (Filter Pushdown) + Project + +// Aggregate +// Output: the Velox computed Aggregation result + +TEST_P(PlanConversionTest, q6FirstStageTest) { + auto veloxConverter = std::make_shared(); + std::string veloxPath = getVeloxPath(); + genLineitemORC(veloxConverter); + // Find and deserialize Substrait plan json file. + std::string subPlanPath = + veloxPath + "/velox/substrait/tests/q6_first_stage.json"; + auto resIter = veloxConverter->getResIter(subPlanPath); + while (resIter->HasNext()) { + auto rv = resIter->Next(); + auto size = rv->size(); + ASSERT_EQ(size, 1); + std::string res = rv->toString(0); + ASSERT_EQ(res, "{13613.1921}"); + } +} + +// This test will firstly generate mock TPC-H lineitem ORC file. Then, Velox's +// computing will be tested based on the generated ORC file. 
+// Input: Json file of the Substrait plan for the first stage of the below +// modified TPC-H Q1 query: +// +// select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, +// sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - +// l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + +// l_tax)) as sum_charge, avg(l_quantity) as avg_qty, avg(l_extendedprice) as +// avg_price, avg(l_discount) as avg_disc, count(*) as count_order from lineitem +// where l_shipdate <= 10471 group by l_returnflag, l_linestatus order by +// l_returnflag, l_linestatus +// +// Tested Velox computings include: TableScan (Filter Pushdown) + Project + +// Aggregate +// Output: the Velox computed Aggregation result + +TEST_P(PlanConversionTest, q1FirstStageTest) { + auto veloxConverter = std::make_shared(); + std::string veloxPath = getVeloxPath(); + genLineitemORC(veloxConverter); + // Find and deserialize Substrait plan json file. + std::string subPlanPath = + veloxPath + "/velox/substrait/tests/q1_first_stage.json"; + auto resIter = veloxConverter->getResIter(subPlanPath); + while (resIter->HasNext()) { + auto rv = resIter->Next(); + auto size = rv->size(); + ASSERT_EQ(size, 3); + ASSERT_EQ( + rv->toString(0), + "{N, O, 34, 105963.31, 99911.3719, 101201.05309399999, 34, 3, 105963.31, 3, 0.16999999999999998, 3, 3}"); + ASSERT_EQ( + rv->toString(1), + "{A, F, 60, 59390.729999999996, 56278.5879, 59485.994223, 60, 5, 59390.729999999996, 5, 0.24, 5, 5}"); + ASSERT_EQ( + rv->toString(2), + "{R, F, 23, 94461.56, 87849.2508, 90221.880192, 23, 2, 94461.56, 2, 0.14, 2, 2}"); + } +} + +VELOX_INSTANTIATE_TEST_SUITE_P( + PlanConversionTests, + PlanConversionTest, + testing::Values(true, false)); diff --git a/velox/substrait/tests/Substrait2VeloxPlanConversionTest.cpp b/velox/substrait/tests/Substrait2VeloxPlanConversionTest.cpp index bfd53a57adff..2dbf100d7906 100644 --- a/velox/substrait/tests/Substrait2VeloxPlanConversionTest.cpp +++ 
b/velox/substrait/tests/Substrait2VeloxPlanConversionTest.cpp @@ -18,17 +18,20 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/dwio/common/tests/utils/DataFiles.h" +#include "velox/dwio/dwrf/reader/DwrfReader.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/substrait/SubstraitToVeloxPlan.h" +#include "velox/substrait/SubstraitToVeloxPlanValidator.h" #include "velox/type/Type.h" using namespace facebook::velox; using namespace facebook::velox::test; using namespace facebook::velox::connector::hive; using namespace facebook::velox::exec; +namespace vestrait = facebook::velox::substrait; class Substrait2VeloxPlanConversionTest : public exec::test::HiveConnectorTestBase { @@ -68,6 +71,13 @@ class Substrait2VeloxPlanConversionTest std::shared_ptr tmpDir_{ exec::test::TempDirectoryPath::create()}; + // Declared first: planConverter_ below is built from memoryPool_.get(). + std::shared_ptr memoryPool_{ + memory::addDefaultLeafMemoryPool()}; + + std::shared_ptr planConverter_ = + std::make_shared( + memoryPool_.get()); }; // This test will firstly generate mock TPC-H lineitem ORC file. Then, Velox's @@ -260,29 +270,63 @@ TEST_F(Substrait2VeloxPlanConversionTest, q6) { " the quickly ironic pains lose car"}; vectors.emplace_back(makeFlatVector(lCommentData)); - // Write data into an ORC file. + // Write data into a DWRF file. writeToFile( - tmpDir_->path + "/mock_lineitem.orc", + tmpDir_->path + "/mock_lineitem.dwrf", {makeRowVector(type->names(), vectors)}); // Find and deserialize Substrait plan json file. - std::string planPath = + std::string subPlanPath = getDataFilePath("velox/substrait/tests", "data/q6_first_stage.json"); // Read q6_first_stage.json and resume the Substrait plan. 
::substrait::Plan substraitPlan; - JsonToProtoConverter::readFromFile(planPath, substraitPlan); + JsonToProtoConverter::readFromFile(subPlanPath, substraitPlan); // Convert to Velox PlanNode. - facebook::velox::substrait::SubstraitVeloxPlanConverter planConverter( - pool_.get()); - auto planNode = planConverter.toVeloxPlan(substraitPlan); + auto planNode = planConverter_->toVeloxPlan(substraitPlan); auto expectedResult = makeRowVector({ makeFlatVector(1, [](auto /*row*/) { return 13613.1921; }), }); exec::test::AssertQueryBuilder(planNode) - .splits(makeSplits(planConverter, planNode)) + .splits(makeSplits(*planConverter_, planNode)) .assertResults(expectedResult); } + +TEST_F(Substrait2VeloxPlanConversionTest, ifthenTest) { + std::string subPlanPath = + getDataFilePath("velox/substrait/tests", "data/if_then.json"); + + ::substrait::Plan substraitPlan; + JsonToProtoConverter::readFromFile(subPlanPath, substraitPlan); + + // Convert to Velox PlanNode. + auto planNode = planConverter_->toVeloxPlan(substraitPlan); + ASSERT_EQ( + "-- Project[expressions: ] -> \n " + " -- TableScan[table: hive_table, range filters: " + "[(hd_demo_sk, Filter(IsNotNull, deterministic, null not allowed)), " + "(hd_vehicle_count, BigintRange: [1, 999999999999999999] no nulls)], " + "remaining filter: (and(or(equalto(ROW[\"hd_buy_potential\"],\">10000\")," + "equalto(ROW[\"hd_buy_potential\"],\"unknown\")),if(greaterthan(ROW[\"hd_vehicle_count\"],0)," + "greaterthan(divide(cast ROW[\"hd_dep_count\"] as DOUBLE,cast ROW[\"hd_vehicle_count\"] as DOUBLE),1.2))))] " + "-> n0_0:BIGINT, n0_1:VARCHAR, n0_2:BIGINT, n0_3:BIGINT\n", + planNode->toString(true, true)); +} + +TEST_F(Substrait2VeloxPlanConversionTest, filterUpper) { + std::string subPlanPath = + getDataFilePath("velox/substrait/tests", "data/filter_upper.json"); + + ::substrait::Plan substraitPlan; + JsonToProtoConverter::readFromFile(subPlanPath, substraitPlan); + + // Convert to Velox PlanNode. 
+ auto planNode = planConverter_->toVeloxPlan(substraitPlan); + ASSERT_EQ( + "-- Project[expressions: ] -> \n -- TableScan[table: hive_table, range filters: " + "[(key, BigintRange: [-2147483648, 2] no nulls)]] -> n0_0:INTEGER\n", + planNode->toString(true, true)); +} diff --git a/velox/substrait/tests/Substrait2VeloxPlanValidatorTest.cpp b/velox/substrait/tests/Substrait2VeloxPlanValidatorTest.cpp new file mode 100644 index 000000000000..c442d2e6afc0 --- /dev/null +++ b/velox/substrait/tests/Substrait2VeloxPlanValidatorTest.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "velox/substrait/tests/JsonToProtoConverter.h" + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/dwio/common/tests/utils/DataFiles.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/substrait/SubstraitToVeloxPlan.h" +#include "velox/substrait/SubstraitToVeloxPlanValidator.h" +#include "velox/type/Type.h" + +using namespace facebook::velox; +using namespace facebook::velox::test; +using namespace facebook::velox::connector::hive; +using namespace facebook::velox::exec; +namespace vestrait = facebook::velox::substrait; + +class Substrait2VeloxPlanValidatorTest + : public exec::test::HiveConnectorTestBase { + protected: + std::shared_ptr planConverter_ = + std::make_shared( + memoryPool_.get()); + + bool validatePlan(std::string file) { + std::string subPlanPath = + getDataFilePath("velox/substrait/tests", "data/" + file); + + ::substrait::Plan substraitPlan; + JsonToProtoConverter::readFromFile(subPlanPath, substraitPlan); + return validatePlan(substraitPlan); + } + + bool validatePlan(::substrait::Plan& plan) { + std::shared_ptr queryCtx = + std::make_shared(); + + // An execution context used for function validation. 
+ std::unique_ptr execCtx = + std::make_unique(pool_.get(), queryCtx.get()); + + auto planValidator = std::make_shared< + facebook::velox::substrait::SubstraitToVeloxPlanValidator>( + pool_.get(), execCtx.get()); + return planValidator->validate(plan); + } + + private: + std::shared_ptr memoryPool_{ + memory::addDefaultLeafMemoryPool()}; +}; + +TEST_F(Substrait2VeloxPlanValidatorTest, group) { + std::string subPlanPath = + getDataFilePath("velox/substrait/tests", "data/group.json"); + + ::substrait::Plan substraitPlan; + JsonToProtoConverter::readFromFile(subPlanPath, substraitPlan); + + ASSERT_FALSE(validatePlan(substraitPlan)); + // Convert to Velox PlanNode. + EXPECT_ANY_THROW(planConverter_->toVeloxPlan(substraitPlan)); +} diff --git a/velox/substrait/tests/Substrait2VeloxValuesNodeConversionTest.cpp b/velox/substrait/tests/Substrait2VeloxValuesNodeConversionTest.cpp index 45927a2e7a00..733cdb5249bc 100644 --- a/velox/substrait/tests/Substrait2VeloxValuesNodeConversionTest.cpp +++ b/velox/substrait/tests/Substrait2VeloxValuesNodeConversionTest.cpp @@ -33,7 +33,7 @@ using namespace facebook::velox::substrait; class Substrait2VeloxValuesNodeConversionTest : public OperatorTestBase { protected: std::shared_ptr planConverter_ = - std::make_shared(pool_.get()); + std::make_shared(pool_.get(), true); }; // SELECT * FROM tmp diff --git a/velox/substrait/tests/VeloxSubstraitRoundTripTest.cpp b/velox/substrait/tests/VeloxSubstraitRoundTripTest.cpp index 149342bc065b..6c2f17d0bebb 100644 --- a/velox/substrait/tests/VeloxSubstraitRoundTripTest.cpp +++ b/velox/substrait/tests/VeloxSubstraitRoundTripTest.cpp @@ -21,10 +21,10 @@ #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/vector/tests/utils/VectorMaker.h" +#include "velox/functions/sparksql/Register.h" #include "velox/substrait/SubstraitToVeloxPlan.h" -#include "velox/substrait/VeloxToSubstraitPlan.h" - #include "velox/substrait/VariantToVectorConverter.h" +#include 
"velox/substrait/VeloxToSubstraitPlan.h" using namespace facebook::velox; using namespace facebook::velox::test; @@ -97,7 +97,7 @@ class VeloxSubstraitRoundTripTest : public OperatorTestBase { std::shared_ptr veloxConvertor_ = std::make_shared(); std::shared_ptr substraitConverter_ = - std::make_shared(pool_.get()); + std::make_shared(pool_.get(), true); }; TEST_F(VeloxSubstraitRoundTripTest, project) { @@ -508,6 +508,7 @@ TEST_F(VeloxSubstraitRoundTripTest, dateType) { } int main(int argc, char** argv) { + facebook::velox::functions::sparksql::registerFunctions(""); testing::InitGoogleTest(&argc, argv); folly::init(&argc, &argv, false); return RUN_ALL_TESTS(); diff --git a/velox/substrait/tests/data/filter_upper.json b/velox/substrait/tests/data/filter_upper.json new file mode 100644 index 000000000000..7b4211fbf008 --- /dev/null +++ b/velox/substrait/tests/data/filter_upper.json @@ -0,0 +1,137 @@ +{ + "extensions": [{ + "extensionFunction": { + "name": "is_not_null:opt_bool_i32" + } + }, { + "extensionFunction": { + "functionAnchor": 2, + "name": "and:opt_bool_bool" + } + }, { + "extensionFunction": { + "functionAnchor": 1, + "name": "lt:opt_i32_i32" + } + } + ], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": ["key"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ] + }, + "partitionColumns": { + "columnType": ["NORMAL_COL"] + } + }, + "filter": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": {} + } + } + } + } + ] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + 
"outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": {} + } + } + } + }, { + "value": { + "literal": { + "i32": 3 + } + } + } + ] + } + } + } + ] + } + }, + "localFiles": { + "items": [{ + "uriFile": "file:///tmp/file.parquet", + "length": "1486", + "parquet": {} + } + ] + } + } + }, + "expressions": [{ + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "selection": { + "directReference": { + "structField": {} + } + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_RETURN_NULL" + } + } + ] + } + }, + "names": ["key#173"] + } + } + ] +} diff --git a/velox/substrait/tests/data/group.json b/velox/substrait/tests/data/group.json new file mode 100644 index 000000000000..a1b77da5ba42 --- /dev/null +++ b/velox/substrait/tests/data/group.json @@ -0,0 +1,34 @@ +{ + "relations": [{ + "root": { + "input": { + "aggregate": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "struct": {} + }, + "localFiles": { + "items": [{ + "uriFile": "file:///tmp/tmp_file", + "length": "31979", + "parquet": {} + } + ] + } + } + }, + "groupings": [{} + ] + } + } + } + } + ] +} diff --git a/velox/substrait/tests/data/if_then.json b/velox/substrait/tests/data/if_then.json new file mode 100644 index 000000000000..c2b26365cb02 --- /dev/null +++ b/velox/substrait/tests/data/if_then.json @@ -0,0 +1,398 @@ +{ + "extensions": [{ + "extensionFunction": { + "functionAnchor": 4, + "name": "gt:i64_i64" + } + }, { + "extensionFunction": { + "functionAnchor": 2, + "name": "or:bool_bool" + } + }, { + "extensionFunction": { + "functionAnchor": 1, + "name": "equal:str_str" + } + }, { + "extensionFunction": { + "functionAnchor": 5, + "name": "divide:opt_fp64_fp64" + } + }, { + "extensionFunction": { + "name": "is_not_null:i64" + } + }, { + "extensionFunction": { + "functionAnchor": 3, + 
"name": "and:bool_bool" + } + }, { + "extensionFunction": { + "functionAnchor": 6, + "name": "gt:fp64_fp64" + } + } + ], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": ["hd_demo_sk", "hd_buy_potential", "hd_dep_count", "hd_vehicle_count"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ] + }, + "partitionColumns": { + "columnType": ["NORMAL_COL", "NORMAL_COL", "NORMAL_COL", "NORMAL_COL"] + } + }, + "filter": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + } + } + } + } + ] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { 
+ "directReference": { + "structField": { + "field": 1 + } + } + } + } + }, { + "value": { + "literal": { + "string": "\u003e10000" + } + } + } + ] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + } + }, { + "value": { + "literal": { + "string": "unknown" + } + } + } + ] + } + } + } + ] + } + } + } + ] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + } + } + } + }, { + "value": { + "literal": { + "i64": "0" + } + } + } + ] + } + } + } + ] + } + } + }, { + "value": { + "ifThen": { + "ifs": [{ + "if": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + } + } + } + }, { + "value": { + "literal": { + "i64": "0" + } + } + } + ] + } + }, + "then": { + "scalarFunction": { + "functionReference": 6, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + } + } + } + } + } + }, { + "value": { + "cast": { + "type": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + } + } + } + } + } 
+ } + ] + } + } + }, { + "value": { + "literal": { + "fp64": 1.2 + } + } + } + ] + } + } + } + ] + } + } + } + ] + } + } + }, { + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": {} + } + } + } + } + ] + } + } + } + ] + } + }, + "localFiles": { + "items": [{ + "uriFile": "file:///tmp/tmp_file", + "length": "31979", + "parquet": {} + } + ] + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": {} + } + } + } + ] + } + }, + "names": ["hd_demo_sk#1927"] + } + } + ] +} \ No newline at end of file diff --git a/velox/substrait/tests/data/q6_first_stage.json b/velox/substrait/tests/data/q6_first_stage.json index b6c2f535df84..24733dd0f185 100644 --- a/velox/substrait/tests/data/q6_first_stage.json +++ b/velox/substrait/tests/data/q6_first_stage.json @@ -67,19 +67,19 @@ "input": { "project": { "common": { - "emit":{ - "outputMapping":[ - 2 + "emit": { + "outputMapping": [ + 2 ] } }, "input": { "project": { "common": { - "emit":{ - "outputMapping":[ - 4, - 5 + "emit": { + "outputMapping": [ + 4, + 5 ] } }, @@ -491,8 +491,8 @@ "partition_index": "0", "start": "0", "length": "3719", - "uri_file": "/mock_lineitem.orc", - "orc": {} + "uri_file": "/mock_lineitem.dwrf", + "dwrf": {} } ] }