Skip to content

Commit

Permalink
[native] Fix bug when parsing SqlFunctionHandle
Browse files Browse the repository at this point in the history
  • Loading branch information
pdabre12 authored and Pratik Joseph Dabre committed Nov 21, 2024
1 parent 1fba459 commit 405d79b
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -842,27 +842,7 @@ void VeloxQueryPlanConverterBase::toAggregations(
auto sqlFunction =
std::dynamic_pointer_cast<protocol::SqlFunctionHandle>(
prestoAggregation.functionHandle)) {
const auto& functionId = sqlFunction->functionId;

// functionId format is function-name;arg-type1;arg-type2;...
// For example: foo;INTEGER;VARCHAR.
auto start = functionId.find(";");
if (start != std::string::npos) {
for (;;) {
auto pos = functionId.find(";", start + 1);
if (pos == std::string::npos) {
auto argumentType = functionId.substr(start + 1);
aggregate.rawInputTypes.push_back(
stringToType(argumentType, typeParser_));
break;
}

auto argumentType = functionId.substr(start + 1, pos - start - 1);
aggregate.rawInputTypes.push_back(
stringToType(argumentType, typeParser_));
pos = start + 1;
}
}
parseSqlFunctionHandle(sqlFunction, aggregate.rawInputTypes, typeParser_);
} else {
VELOX_USER_FAIL(
"Unsupported aggregate function handle: {}",
Expand Down Expand Up @@ -2106,4 +2086,30 @@ void registerPrestoPlanNodeSerDe() {
registry.Register(
"BroadcastWriteNode", presto::operators::BroadcastWriteNode::create);
}

void parseSqlFunctionHandle(
const std::shared_ptr<protocol::SqlFunctionHandle>& sqlFunction,
std::vector<velox::TypePtr>& rawInputTypes,
TypeParser& typeParser) {
const auto& functionId = sqlFunction->functionId;
// functionId format is function-name;arg-type1;arg-type2;...
// For example: foo;INTEGER;VARCHAR.
auto start = functionId.find(";");
if (start != std::string::npos) {
for (;;) {
auto pos = functionId.find(";", start + 1);
if (pos == std::string::npos) {
auto argumentType = functionId.substr(start + 1);
if (!argumentType.empty()) {
rawInputTypes.push_back(stringToType(argumentType, typeParser));
}
break;
}
auto argumentType = functionId.substr(start + 1, pos - start - 1);
VELOX_CHECK(!argumentType.empty());
rawInputTypes.push_back(stringToType(argumentType, typeParser));
start = pos;
}
}
}
} // namespace facebook::presto
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,8 @@ class VeloxBatchQueryPlanConverter : public VeloxQueryPlanConverterBase {
};

void registerPrestoPlanNodeSerDe();
void parseSqlFunctionHandle(
const std::shared_ptr<protocol::SqlFunctionHandle>& sqlFunction,
std::vector<velox::TypePtr>& rawInputTypes,
TypeParser& typeParser);
} // namespace facebook::presto
17 changes: 17 additions & 0 deletions presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,20 @@ target_link_libraries(
velox_window
GTest::gtest
GTest::gtest_main)

add_executable(presto_to_velox_query_plan_test PrestoToVeloxQueryPlanTest.cpp)

add_test(
NAME presto_to_velox_query_plan_test
COMMAND presto_to_velox_query_plan_test
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})

target_link_libraries(
presto_to_velox_query_plan_test
presto_operators
presto_protocol
presto_type_converter
presto_types
velox_exec_test_lib
GTest::gtest
GTest::gtest_main)
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h"
#include <gtest/gtest.h>
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/types/HyperLogLogType.h"
#include "velox/functions/prestosql/types/IPAddressType.h"
#include "velox/functions/prestosql/types/IPPrefixType.h"
#include "velox/functions/prestosql/types/JsonType.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
#include "velox/functions/prestosql/types/UuidType.h"

using namespace facebook::presto;
using namespace facebook::velox;

namespace {
inline void validateSqlFunctionHandleParsing(
const std::shared_ptr<facebook::presto::protocol::FunctionHandle>&
functionHandle,
std::vector<TypePtr> expectedRawInputTypes) {
std::vector<TypePtr> actualRawInputTypes;
TypeParser typeParser;
auto sqlFunctionHandle =
std::static_pointer_cast<protocol::SqlFunctionHandle>(functionHandle);
facebook::presto::parseSqlFunctionHandle(
sqlFunctionHandle, actualRawInputTypes, typeParser);
EXPECT_EQ(expectedRawInputTypes.size(), actualRawInputTypes.size());
for (int i = 0; i < expectedRawInputTypes.size(); i++) {
EXPECT_EQ(*expectedRawInputTypes[i], *actualRawInputTypes[i]);
}
}
} // namespace

class PrestoToVeloxQueryPlanTest : public ::testing::Test {
public:
PrestoToVeloxQueryPlanTest() {
registerHyperLogLogType();
registerIPAddressType();
registerIPPrefixType();
registerJsonType();
registerTimestampWithTimeZoneType();
registerUuidType();
}
};

TEST_F(PrestoToVeloxQueryPlanTest, parseSqlFunctionHandleWithZeroParam) {
std::string str = R"(
{
"@type": "json_file",
"functionId": "json_file.test.count;",
"version": "1"
}
)";

json j = json::parse(str);
std::shared_ptr<facebook::presto::protocol::FunctionHandle> functionHandle =
j;
ASSERT_NE(functionHandle, nullptr);
validateSqlFunctionHandleParsing(functionHandle, {});
}

TEST_F(PrestoToVeloxQueryPlanTest, parseSqlFunctionHandleWithOneParam) {
std::string str = R"(
{
"@type": "json_file",
"functionId": "json_file.test.sum;tinyint",
"version": "1"
}
)";

json j = json::parse(str);
std::shared_ptr<facebook::presto::protocol::FunctionHandle> functionHandle =
j;
ASSERT_NE(functionHandle, nullptr);

std::vector<TypePtr> expectedRawInputTypes{TINYINT()};
validateSqlFunctionHandleParsing(functionHandle, expectedRawInputTypes);
}

TEST_F(PrestoToVeloxQueryPlanTest, parseSqlFunctionHandleWithMultipleParam) {
std::string str = R"(
{
"@type": "json_file",
"functionId": "json_file.test.avg;array(decimal(15, 2));varchar",
"version": "1"
}
)";

json j = json::parse(str);
std::shared_ptr<facebook::presto::protocol::FunctionHandle> functionHandle =
j;
ASSERT_NE(functionHandle, nullptr);

std::vector<TypePtr> expectedRawInputTypes{ARRAY(DECIMAL(15, 2)), VARCHAR()};
validateSqlFunctionHandleParsing(functionHandle, expectedRawInputTypes);
}

TEST_F(PrestoToVeloxQueryPlanTest, parseSqlFunctionHandleAllComplexTypes) {
std::string str = R"(
{
"@type": "json_file",
"functionId": "json_file.test.all_complex_types;row(map(hugeint, ipaddress), ipprefix);row(array(varbinary), timestamp, date, json, hyperloglog, timestamp with time zone, interval year to month, interval day to second);function(double, boolean);uuid",
"version": "1"
}
)";

json j = json::parse(str);
std::shared_ptr<facebook::presto::protocol::FunctionHandle> functionHandle =
j;
ASSERT_NE(functionHandle, nullptr);

std::vector<TypePtr> expectedRawInputTypes{
ROW({MAP(HUGEINT(), IPADDRESS()), IPPREFIX()}),
ROW(
{ARRAY(VARBINARY()),
TIMESTAMP(),
DATE(),
JSON(),
HYPERLOGLOG(),
TIMESTAMP_WITH_TIME_ZONE(),
INTERVAL_YEAR_MONTH(),
INTERVAL_DAY_TIME()}),
FUNCTION({DOUBLE()}, BOOLEAN()),
UUID()};
validateSqlFunctionHandleParsing(functionHandle, expectedRawInputTypes);
}

0 comments on commit 405d79b

Please sign in to comment.