Skip to content

Commit

Permalink
[native] Fix bug when parsing SqlFunctionHandle
Browse files Browse the repository at this point in the history
  • Loading branch information
pdabre12 authored and Pratik Joseph Dabre committed Nov 18, 2024
1 parent 90354c9 commit 9c31ea7
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -842,27 +842,7 @@ void VeloxQueryPlanConverterBase::toAggregations(
auto sqlFunction =
std::dynamic_pointer_cast<protocol::SqlFunctionHandle>(
prestoAggregation.functionHandle)) {
const auto& functionId = sqlFunction->functionId;

// functionId format is function-name;arg-type1;arg-type2;...
// For example: foo;INTEGER;VARCHAR.
auto start = functionId.find(";");
if (start != std::string::npos) {
for (;;) {
auto pos = functionId.find(";", start + 1);
if (pos == std::string::npos) {
auto argumentType = functionId.substr(start + 1);
aggregate.rawInputTypes.push_back(
stringToType(argumentType, typeParser_));
break;
}

auto argumentType = functionId.substr(start + 1, pos - start - 1);
aggregate.rawInputTypes.push_back(
stringToType(argumentType, typeParser_));
pos = start + 1;
}
}
parseSqlFunctionHandle(sqlFunction, aggregate.rawInputTypes, typeParser_);
} else {
VELOX_USER_FAIL(
"Unsupported aggregate function handle: {}",
Expand Down Expand Up @@ -2106,4 +2086,29 @@ void registerPrestoPlanNodeSerDe() {
registry.Register(
"BroadcastWriteNode", presto::operators::BroadcastWriteNode::create);
}

void parseSqlFunctionHandle(
std::shared_ptr<protocol::SqlFunctionHandle> sqlFunction,
std::vector<velox::TypePtr>& rawInputTypes,
TypeParser& typeParser) {
const auto& functionId = sqlFunction->functionId;
// functionId format is function-name;arg-type1;arg-type2;...
// For example: foo;INTEGER;VARCHAR.
auto start = functionId.find(";");
if (start != std::string::npos) {
for (;;) {
auto pos = functionId.find(";", start + 1);
if (pos == std::string::npos) {
auto argumentType = functionId.substr(start + 1);
if (!argumentType.empty()) {
rawInputTypes.push_back(stringToType(argumentType, typeParser));
}
break;
}
auto argumentType = functionId.substr(start + 1, pos - start - 1);
rawInputTypes.push_back(stringToType(argumentType, typeParser));
start = pos;
}
}
}
} // namespace facebook::presto
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,8 @@ class VeloxBatchQueryPlanConverter : public VeloxQueryPlanConverterBase {
};

void registerPrestoPlanNodeSerDe();
void parseSqlFunctionHandle(
std::shared_ptr<protocol::SqlFunctionHandle> sqlFunction,
std::vector<velox::TypePtr>& rawInputTypes,
TypeParser& typeParser);
} // namespace facebook::presto
17 changes: 17 additions & 0 deletions presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,20 @@ target_link_libraries(
velox_window
GTest::gtest
GTest::gtest_main)

add_executable(presto_to_velox_query_plan_test PrestoToVeloxQueryPlanTest.cpp)

add_test(
NAME presto_to_velox_query_plan_test
COMMAND presto_to_velox_query_plan_test
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})

target_link_libraries(
presto_to_velox_query_plan_test
presto_operators
presto_protocol
presto_type_converter
presto_types
velox_exec_test_lib
GTest::gtest
GTest::gtest_main)
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h"
#include <gtest/gtest.h>
#include "velox/common/base/tests/GTestUtils.h"

using namespace facebook::presto;
using namespace facebook::velox;

class PrestoToVeloxQueryPlanTest : public ::testing::Test {};

TEST_F(PrestoToVeloxQueryPlanTest, parseSqlFunctionHandleWithZeroParam) {
std::string str = R"(
{
"@type": "call",
"arguments": [
{
"@type": "variable",
"name": "event_based_revenue",
"type": "real"
}
],
"displayName": "sum",
"functionHandle": {
"@type": "json_file",
"functionId": "json.x4.sum;",
"version": "1"
},
"returnType": "double"
}
)";

json j = json::parse(str);
facebook::presto::protocol::CallExpression p = j;
ASSERT_NE(p.functionHandle, nullptr);

std::vector<TypePtr> actualRawInputTypes;
std::vector<TypePtr> expectedRawInputTypes;
TypeParser typeParser;
auto sqlFunctionHandle =
std::static_pointer_cast<protocol::SqlFunctionHandle>(p.functionHandle);
facebook::presto::parseSqlFunctionHandle(
sqlFunctionHandle, actualRawInputTypes, typeParser);
ASSERT_TRUE(actualRawInputTypes.empty());
EXPECT_EQ(expectedRawInputTypes.size(), actualRawInputTypes.size());
for (int i = 0; i < expectedRawInputTypes.size(); i++) {
EXPECT_EQ(expectedRawInputTypes[i], actualRawInputTypes[i]);
}
}

TEST_F(PrestoToVeloxQueryPlanTest, parseSqlFunctionHandleWithOneParam) {
std::string str = R"(
{
"@type": "call",
"arguments": [
{
"@type": "variable",
"name": "event_based_revenue",
"type": "real"
}
],
"displayName": "sum",
"functionHandle": {
"@type": "json_file",
"functionId": "json.x4.sum;INTEGER",
"version": "1"
},
"returnType": "double"
}
)";

json j = json::parse(str);
facebook::presto::protocol::CallExpression p = j;
ASSERT_NE(p.functionHandle, nullptr);

std::vector<TypePtr> actualRawInputTypes;
std::vector<TypePtr> expectedRawInputTypes{INTEGER()};
TypeParser typeParser;
auto sqlFunctionHandle =
std::static_pointer_cast<protocol::SqlFunctionHandle>(p.functionHandle);
facebook::presto::parseSqlFunctionHandle(
sqlFunctionHandle, actualRawInputTypes, typeParser);
ASSERT_FALSE(actualRawInputTypes.empty());
EXPECT_EQ(expectedRawInputTypes.size(), actualRawInputTypes.size());
for (int i = 0; i < expectedRawInputTypes.size(); i++) {
EXPECT_EQ(expectedRawInputTypes[i], actualRawInputTypes[i]);
}
}

TEST_F(PrestoToVeloxQueryPlanTest, parseSqlFunctionHandleWithMultipleParam) {
std::string str = R"(
{
"@type": "call",
"arguments": [
{
"@type": "variable",
"name": "event_based_revenue",
"type": "real"
}
],
"displayName": "sum",
"functionHandle": {
"@type": "json_file",
"functionId": "json.x4.sum;BIGINT;BIGINT",
"version": "1"
},
"returnType": "bigint"
}
)";

json j = json::parse(str);
facebook::presto::protocol::CallExpression p = j;
ASSERT_NE(p.functionHandle, nullptr);

std::vector<TypePtr> actualRawInputTypes;
std::vector<TypePtr> expectedRawInputTypes{BIGINT(), BIGINT()};
TypeParser typeParser;
auto sqlFunctionHandle =
std::static_pointer_cast<protocol::SqlFunctionHandle>(p.functionHandle);
facebook::presto::parseSqlFunctionHandle(
sqlFunctionHandle, actualRawInputTypes, typeParser);
ASSERT_FALSE(actualRawInputTypes.empty());
EXPECT_EQ(expectedRawInputTypes.size(), actualRawInputTypes.size());
for (int i = 0; i < expectedRawInputTypes.size(); i++) {
EXPECT_EQ(expectedRawInputTypes[i], actualRawInputTypes[i]);
}
}

0 comments on commit 9c31ea7

Please sign in to comment.