Skip to content

Commit

Permalink
add some logics about dealing with illegal queries
Browse files Browse the repository at this point in the history
  • Loading branch information
amamiya-len committed Sep 10, 2024
1 parent 405a138 commit 10e6f9b
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/Common/ErrorCodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@
M(2611, UNKNOWN_FORMAT_SCHEMA) \
M(2612, AMBIGUOUS_FORMAT_SCHEMA) \
M(2613, UNKNOWN_FORMAT_SCHEMA_TYPE) \
M(2614, DUPLICATE_KEY) \
M(2631, DUPLICATE_KEY) \
/* See END */

namespace DB
Expand Down
2 changes: 1 addition & 1 deletion src/Common/sendRequest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ std::pair<String, Int32> sendRequest(
const String & password,
const String & payload,
const std::vector<std::pair<String, String>> & headers,
/// One second for connect/send/receive
/// Timeout second for connect/send/receive
ConnectionTimeouts timeouts,
Poco::Logger * log)
{
Expand Down
2 changes: 1 addition & 1 deletion src/Coordination/MetaStoreConnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ std::pair<String, Int32> MetaStoreConnection::forwardRequest(
for (auto it = active_servers.begin(); it != active_servers.end();)
{
Poco::URI uri{fmt::format({METASTORE_URL}, it->getHost(), it->getPort(), uri_parameter)};
auto [response, http_status] = sendRequest(uri, method, query_id, user, password, body, {}, ConnectionTimeouts({2, 0}/*connection timeout */, {5, 0}/* send timeout */, {10, 0}/* receive timeout */) , log);
auto [response, http_status] = sendRequest(uri, method, query_id, user, password, body, {}, ConnectionTimeouts({2, 0}/* connect timeout */, {5, 0}/* send timeout */, {10, 0}/* receive timeout */) , log);

if (http_status != Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR)
{
Expand Down
2 changes: 1 addition & 1 deletion src/Functions/UserDefined/RemoteUserDefinedFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class RemoteUserDefinedFunction final : public UserDefinedFunctionBase
"",
out,
{{config.auth_context.key_name, config.auth_context.key_value}, {"", context->getCurrentQueryId()}},
ConnectionTimeouts({2, 0}, {5, 0}, {static_cast<long>(config.command_execution_timeout_milliseconds / 1000), static_cast<long>((config.command_execution_timeout_milliseconds % 1000u) * 1000u)}),
ConnectionTimeouts({2, 0}/* connect timeout */, {5, 0} /* send timeout */, {static_cast<long>(config.command_execution_timeout_milliseconds / 1000), static_cast<long>((config.command_execution_timeout_milliseconds % 1000u) * 1000u)/* receive timeout */}), /// timeout and limit for connect/send/receive ...
&Poco::Logger::get("UserDefinedFunction"));

if (http_status != Poco::Net::HTTPResponse::HTTP_OK)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ struct RemoteUserDefinedFunctionConfiguration : public UserDefinedFunctionConfig
/// Timeout for reading data from input format
size_t command_read_timeout_milliseconds = 10000;

/// Timeout for receiving response from remote endpoint
size_t command_execution_timeout_milliseconds = 10000;

/// url of remote endpoint, only available when 'type' is 'remote'
Expand Down
20 changes: 14 additions & 6 deletions src/Parsers/ASTCreateFunctionQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, I
/// proton: starts
if (is_remote)
{
settings.ostr << '\n';
settings.ostr << ' ';
function_core->formatImpl(settings, state, frame);
return;
}
Expand Down Expand Up @@ -173,12 +173,20 @@ Poco::JSON::Object::Ptr ASTCreateFunctionQuery::toJSON() const
}

inner_func->set("url", url.value());
if (auth_method.has_value() && auth_method.value() == "auth_header")
if (auth_method.has_value())
{
inner_func->set("auth_method", auth_method.value());
if (auth_method.value() == "auth_header")
{
Poco::JSON::Object::Ptr auth_context = new Poco::JSON::Object();
auth_context->set("key_name", auth_header.value_or(""));
auth_context->set("key_value", auth_key.value_or(""));
inner_func->set("auth_context", auth_context);
}
}
else
{
Poco::JSON::Object::Ptr auth_context = new Poco::JSON::Object();
auth_context->set("key_name", auth_header.value_or(""));
auth_context->set("key_value", auth_key.value_or(""));
inner_func->set("auth_context", auth_context);
inner_func->set("auth_method", "none");
}
if (execution_timeout.has_value())
{
Expand Down
64 changes: 64 additions & 0 deletions src/Parsers/ParserCreateFunctionQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,70 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
ParserKeyValuePairsSet kv_pairs_list;
if (!kv_pairs_list.parse(pos, kv_list, expected))
return false;

/// check if the parameters are valid and no unsupported or unknown parameters.
std::optional<String> ast_url;
std::optional<String> ast_auth_method;
std::optional<String> ast_auth_header;
std::optional<String> ast_auth_key;
std::optional<UInt64> ast_execution_timeout;
for (const auto & kv : kv_list->children)
{
auto * kv_pair = kv->as<ASTPair>();
auto key = kv_pair->first;
auto pair_value = kv_pair->second->as<ASTLiteral>()->value;
if (!kv_pair)
throw Exception("Key-value pair expected", ErrorCodes::UNKNOWN_FUNCTION);

if (key == "url")
{
ast_url = pair_value.safeGet<String>();
}
else if (key == "auth_method")
{
ast_auth_method = pair_value.safeGet<String>();
if (ast_auth_method.value() != "none" && ast_auth_method.value() != "auth_header")
throw Exception("Unknown auth method", ErrorCodes::UNKNOWN_FUNCTION);
}
else if (key == "auth_header")
{
ast_auth_header = pair_value.safeGet<String>();
}
else if (key == "auth_key")
{
ast_auth_key = pair_value.safeGet<String>();
}
else if (key == "execution_timeout")
{
ast_execution_timeout = pair_value.safeGet<UInt64>();
}
}
/// check if URL is set
if (!ast_url)
throw Exception("URL is required for remote function", ErrorCodes::UNKNOWN_FUNCTION);
/// check if auth_method is "auth_header" or "none"
if (ast_auth_method)
{
if (ast_auth_method.value() == "auth_header")
{
if (!ast_auth_header || !ast_auth_key)
throw Exception("Auth header and auth key are required for auth_header auth method", ErrorCodes::UNKNOWN_FUNCTION);
}
else if (ast_auth_method.value() == "none")
{
if (ast_auth_header || ast_auth_key)
throw Exception("Auth method is 'none', but auth header or auth key is set.", ErrorCodes::UNKNOWN_FUNCTION);
}
else
{
throw Exception("Unknown auth method " + ast_auth_method.value(), ErrorCodes::UNKNOWN_FUNCTION);
}
}
else
{
if (ast_auth_header || ast_auth_key)
throw Exception("Auth method is 'none', but auth header or auth key is set.", ErrorCodes::UNKNOWN_FUNCTION);
}
function_core = std::move(kv_list);
}
/// proton: ends
Expand Down
1 change: 0 additions & 1 deletion src/Parsers/ParserKeyValuePairsSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,3 @@ class ParserKeyValuePairsSet : public IParserBase

}


135 changes: 134 additions & 1 deletion src/Parsers/tests/gtest_create_remote_func_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,55 @@ TEST(ParserCreateRemoteFunctionQuery, UDFHeaderMethodIsOther)
"URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'"
"AUTH_METHOD 'token'";
ParserCreateFunctionQuery parser;
EXPECT_NO_THROW(parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0));
EXPECT_THROW(parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0), Exception);
}

TEST(ParserCreateRemoteFunctionQuery, UDFNoURL)
{
String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string "
"AUTH_METHOD 'none'";
ParserCreateFunctionQuery parser;
EXPECT_THROW(parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0), Exception);
}

TEST(ParserCreateRemoteFunctionQuery, UDFURLWithExecutionTimeout)
{
String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string "
"EXECUTION_TIMEOUT 2000 "
"URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'";
ParserCreateFunctionQuery parser;
ASTPtr ast = parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0);
ASTCreateFunctionQuery * create = ast->as<ASTCreateFunctionQuery>();

EXPECT_EQ(create->getFunctionName(), "ip_lookup");
EXPECT_EQ(create->lang, "Remote");
EXPECT_NE(create->function_core, nullptr);
EXPECT_NE(create->arguments, nullptr);

/// Check arguments
String args = queryToString(*create->arguments.get(), true);
EXPECT_EQ(args, "(ip string)");

/// Check return type
String ret = queryToString(*create->return_type.get(), true);
EXPECT_EQ(ret, "string");

auto remote_func_settings = create->function_core;

EXPECT_EQ(remote_func_settings->children.size(), 2);
EXPECT_EQ(remote_func_settings->children.front()->as<ASTPair>()->first, "execution_timeout");
EXPECT_EQ(remote_func_settings->children.front()->as<ASTPair>()->second->as<ASTLiteral>()->value.safeGet<UInt64>(),2000u);
EXPECT_EQ(remote_func_settings->children.back()->as<ASTPair>()->first, "url");
EXPECT_EQ(remote_func_settings->children.back()->as<ASTPair>()->second->as<ASTLiteral>()->value.safeGet<String>(),
"https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/");
}

TEST(ParserCreateRemoteFunctionQuery, UDFOnlyExecutionTimeout)
{
String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string "
"EXECUTION_TIMEOUT 3000";
ParserCreateFunctionQuery parser;
EXPECT_THROW(parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0), Exception);
}

TEST(ParserCreateRemoteFunctionQuery, UDFMultipleKey)
Expand All @@ -133,3 +181,88 @@ TEST(ParserCreateRemoteFunctionQuery, UDFMultipleKey)
EXPECT_THROW(parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0), Exception);
}

TEST(ParserCreateRemoteFunctionQuery, UDFNoneButSetAUTHHEADER)
{
String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string "
"URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'"
"AUTH_METHOD 'none'"
"AUTH_HEADER 'auth'"
"AUTH_KEY 'proton'"
"EXECUTION_TIMEOUT 30000";
ParserCreateFunctionQuery parser;
EXPECT_THROW(parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0), Exception);
}

TEST(ParserCreateRemoteFunctionQuery, UDFAUTHHEADERButNotSetAUTHHEADER)
{
String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string "
"URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'"
"AUTH_METHOD 'auth_header'"
"AUTH_KEY 'proton'"
"EXECUTION_TIMEOUT 30000";
ParserCreateFunctionQuery parser;
EXPECT_THROW(parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0), Exception);
}

TEST(ParserCreateRemoteFunctionQuery, UDFAUTHHEADERButNotSetAUTHKEY)
{
String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string "
"URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'"
"AUTH_METHOD 'auth_header'"
"AUTH_HEADER 'proton'"
"EXECUTION_TIMEOUT 30000";
ParserCreateFunctionQuery parser;
EXPECT_THROW(parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0), Exception);
}

TEST(ParserCreateRemoteFunctionQuery, UDFFormat1)
{
String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string "
"EXECUTION_TIMEOUT 2000 "
"URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'";
ParserCreateFunctionQuery parser;
ASTPtr ast = parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0);
ASTCreateFunctionQuery * create = ast->as<ASTCreateFunctionQuery>();

auto str = serializeAST(*create, true);

EXPECT_EQ(str,
"CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string EXECUTION_TIMEOUT 2000 URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'"
);
}

TEST(ParserCreateRemoteFunctionQuery, UDFFormat2)
{
String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string "
"URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/' "
"EXECUTION_TIMEOUT 2000";
ParserCreateFunctionQuery parser;
ASTPtr ast = parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0);
ASTCreateFunctionQuery * create = ast->as<ASTCreateFunctionQuery>();

auto str = serializeAST(*create, true);

EXPECT_EQ(str,
"CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/' EXECUTION_TIMEOUT 2000"
);
}

TEST(ParserCreateRemoteFunctionQuery, UDFFormat3)
{
String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string "
"AUTH_METHOD 'auth_header' "
"URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/' "
"EXECUTION_TIMEOUT 2000 "
"AUTH_KEY 'proton' "
"AUTH_HEADER 'auth' ";
ParserCreateFunctionQuery parser;
ASTPtr ast = parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0);
ASTCreateFunctionQuery * create = ast->as<ASTCreateFunctionQuery>();

auto str = serializeAST(*create, true);

EXPECT_EQ(str,
"CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string AUTH_METHOD 'auth_header' URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/' EXECUTION_TIMEOUT 2000 AUTH_KEY 'proton' AUTH_HEADER 'auth'"
);
}

0 comments on commit 10e6f9b

Please sign in to comment.