From 09532f8a04a4a0bc1b8fe13c788759cd09a1e73c Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Mon, 27 May 2024 11:08:53 +0000 Subject: [PATCH] feat: support batchrequest in ProcessQuery --- src/tablet/tablet_impl.cc | 239 ++++++++++++++++++++++++-------------- 1 file changed, 151 insertions(+), 88 deletions(-) diff --git a/src/tablet/tablet_impl.cc b/src/tablet/tablet_impl.cc index 1545b96c9d4..230b5c46a09 100644 --- a/src/tablet/tablet_impl.cc +++ b/src/tablet/tablet_impl.cc @@ -20,6 +20,7 @@ #include #include #include +#include "vm/sql_compiler.h" #ifdef DISALLOW_COPY_AND_ASSIGN #undef DISALLOW_COPY_AND_ASSIGN #endif @@ -1691,98 +1692,127 @@ void TabletImpl::ProcessQuery(bool is_sub, RpcController* ctrl, const openmldb:: auto mode = hybridse::vm::Engine::TryDetermineEngineMode(request->sql(), default_mode); ::hybridse::base::Status status; - // FIXME(someone): it does not handles batchrequest - if (mode == hybridse::vm::EngineMode::kBatchMode) { - // convert repeated openmldb:type::DataType into hybridse::codec::Schema - hybridse::codec::Schema parameter_schema; - for (int i = 0; i < request->parameter_types().size(); i++) { - auto column = parameter_schema.Add(); - hybridse::type::Type hybridse_type; - - if (!openmldb::schema::SchemaAdapter::ConvertType(request->parameter_types(i), &hybridse_type)) { - response->set_msg("Invalid parameter type: " + - openmldb::type::DataType_Name(request->parameter_types(i))); - response->set_code(::openmldb::base::kSQLCompileError); - return; + switch (mode) { + case hybridse::vm::EngineMode::kBatchMode: { + // convert repeated openmldb:type::DataType into hybridse::codec::Schema + hybridse::codec::Schema parameter_schema; + for (int i = 0; i < request->parameter_types().size(); i++) { + auto column = parameter_schema.Add(); + hybridse::type::Type hybridse_type; + + if (!openmldb::schema::SchemaAdapter::ConvertType(request->parameter_types(i), &hybridse_type)) { + response->set_msg("Invalid parameter type: " + + openmldb::type::DataType_Name(request->parameter_types(i))); + response->set_code(::openmldb::base::kSQLCompileError); + return; + } + column->set_type(hybridse_type); } - column->set_type(hybridse_type); - } - ::hybridse::vm::BatchRunSession session; - if (request->is_debug()) { - session.EnableDebug(); - } - session.SetParameterSchema(parameter_schema); - { - bool ok = engine_->Get(request->sql(), request->db(), session, status); - if (!ok) { - response->set_msg(status.msg); - response->set_code(::openmldb::base::kSQLCompileError); - DLOG(WARNING) << "fail to compile sql " << request->sql() << ", message: " << status.msg; - return; + ::hybridse::vm::BatchRunSession session; + if (request->is_debug()) { + session.EnableDebug(); + } + session.SetParameterSchema(parameter_schema); + { + bool ok = engine_->Get(request->sql(), request->db(), session, status); + if (!ok) { + response->set_msg(status.msg); + response->set_code(::openmldb::base::kSQLCompileError); + DLOG(WARNING) << "fail to compile sql " << request->sql() << ", message: " << status.msg; + return; + } } - } - ::hybridse::codec::Row parameter_row; - auto& request_buf = static_cast(ctrl)->request_attachment(); - if (request->parameter_row_size() > 0 && - !codec::DecodeRpcRow(request_buf, 0, request->parameter_row_size(), request->parameter_row_slices(), - ¶meter_row)) { - response->set_code(::openmldb::base::kSQLRunError); - response->set_msg("fail to decode parameter row"); - return; - } - std::vector<::hybridse::codec::Row> output_rows; - int32_t run_ret = session.Run(parameter_row, output_rows); - if (run_ret != 0) { - response->set_msg(status.msg); - response->set_code(::openmldb::base::kSQLRunError); - DLOG(WARNING) << "fail to run sql: " << request->sql(); - return; - } - uint32_t byte_size = 0; - uint32_t count = 0; - for (auto& output_row : output_rows) { - if (FLAGS_scan_max_bytes_size > 0 && byte_size > FLAGS_scan_max_bytes_size) { - LOG(WARNING) << "reach the max byte size " << FLAGS_scan_max_bytes_size << " truncate result"; - response->set_schema(session.GetEncodedSchema()); - response->set_byte_size(byte_size); - response->set_count(count); - response->set_code(::openmldb::base::kOk); + ::hybridse::codec::Row parameter_row; + auto& request_buf = static_cast(ctrl)->request_attachment(); + if (request->parameter_row_size() > 0 && + !codec::DecodeRpcRow(request_buf, 0, request->parameter_row_size(), request->parameter_row_slices(), + ¶meter_row)) { + response->set_code(::openmldb::base::kSQLRunError); + response->set_msg("fail to decode parameter row"); + return; + } + std::vector<::hybridse::codec::Row> output_rows; + int32_t run_ret = session.Run(parameter_row, output_rows); + if (run_ret != 0) { + response->set_msg(status.msg); + response->set_code(::openmldb::base::kSQLRunError); + DLOG(WARNING) << "fail to run sql: " << request->sql(); return; } - byte_size += output_row.size(); - buf->append(reinterpret_cast(output_row.buf()), output_row.size()); - count += 1; + uint32_t byte_size = 0; + uint32_t count = 0; + for (auto& output_row : output_rows) { + if (FLAGS_scan_max_bytes_size > 0 && byte_size > FLAGS_scan_max_bytes_size) { + LOG(WARNING) << "reach the max byte size " << FLAGS_scan_max_bytes_size << " truncate result"; + response->set_schema(session.GetEncodedSchema()); + response->set_byte_size(byte_size); + response->set_count(count); + response->set_code(::openmldb::base::kOk); + return; + } + byte_size += output_row.size(); + buf->append(reinterpret_cast(output_row.buf()), output_row.size()); + count += 1; + } + response->set_schema(session.GetEncodedSchema()); + response->set_byte_size(byte_size); + response->set_count(count); + response->set_code(::openmldb::base::kOk); + DLOG(INFO) << "handle batch sql " << request->sql() << " with record cnt " << count << " byte size " + << byte_size; + break; } - response->set_schema(session.GetEncodedSchema()); - response->set_byte_size(byte_size); - response->set_count(count); - response->set_code(::openmldb::base::kOk); - DLOG(INFO) << "handle batch sql " << request->sql() << " with record cnt " << count << " byte size " - << byte_size; - } else { - ::hybridse::vm::RequestRunSession session; - if (request->is_debug()) { - session.EnableDebug(); - } - if (request->is_procedure()) { - const std::string& db_name = request->db(); - const std::string& sp_name = request->sp_name(); - std::shared_ptr request_compile_info; - { - hybridse::base::Status status; - request_compile_info = sp_cache_->GetRequestInfo(db_name, sp_name, status); - if (!status.isOK()) { - response->set_code(::openmldb::base::ReturnCode::kProcedureNotFound); + case hybridse::vm::kRequestMode: { + ::hybridse::vm::RequestRunSession session; + if (request->is_debug()) { + session.EnableDebug(); + } + if (request->is_procedure()) { + const std::string& db_name = request->db(); + const std::string& sp_name = request->sp_name(); + std::shared_ptr request_compile_info; + { + hybridse::base::Status status; + request_compile_info = sp_cache_->GetRequestInfo(db_name, sp_name, status); + if (!status.isOK()) { + response->set_code(::openmldb::base::ReturnCode::kProcedureNotFound); + response->set_msg(status.msg); + PDLOG(WARNING, status.msg.c_str()); + return; + } + } + session.SetCompileInfo(request_compile_info); + session.SetSpName(sp_name); + RunRequestQuery(ctrl, *request, session, *response, *buf); + } else { + bool ok = engine_->Get(request->sql(), request->db(), session, status); + if (!ok || session.GetCompileInfo() == nullptr) { response->set_msg(status.msg); - PDLOG(WARNING, status.msg.c_str()); + response->set_code(::openmldb::base::kSQLCompileError); + DLOG(WARNING) << "fail to compile sql in request mode:\n" << request->sql(); return; } + RunRequestQuery(ctrl, *request, session, *response, *buf); + } + const std::string& sql = session.GetCompileInfo()->GetSql(); + if (response->code() != ::openmldb::base::kOk) { + DLOG(WARNING) << "fail to run sql " << sql << " error msg: " << response->msg(); + } else { + DLOG(INFO) << "handle request sql " << sql; + } + break; + } + case hybridse::vm::kBatchRequestMode: { + // we support a simplified batch request query here + // not procedure + // no parameter input or bachrequst row + // batchrequest row must specified in CONFIG (values = ...) + ::hybridse::base::Status status; + ::hybridse::vm::BatchRequestRunSession session; + if (request->is_debug()) { + session.EnableDebug(); } - session.SetCompileInfo(request_compile_info); - session.SetSpName(sp_name); - RunRequestQuery(ctrl, *request, session, *response, *buf); - } else { bool ok = engine_->Get(request->sql(), request->db(), session, status); if (!ok || session.GetCompileInfo() == nullptr) { response->set_msg(status.msg); @@ -1790,13 +1820,46 @@ void TabletImpl::ProcessQuery(bool is_sub, RpcController* ctrl, const openmldb:: DLOG(WARNING) << "fail to compile sql in request mode:\n" << request->sql(); return; } - RunRequestQuery(ctrl, *request, session, *response, *buf); + auto info = std::dynamic_pointer_cast(session.GetCompileInfo()); + if (info && info->get_sql_context().request_rows.empty()) { + response->set_msg("batch request values must specified in SQL CONFIG (values = [...])"); + response->set_code(::openmldb::base::kSQLCompileError); + return; + } + std::vector<::hybridse::codec::Row> output_rows; + std::vector<::hybridse::codec::Row> empty_inputs; + int32_t run_ret = session.Run(empty_inputs, output_rows); + if (run_ret != 0) { + response->set_msg(status.msg); + response->set_code(::openmldb::base::kSQLRunError); + DLOG(WARNING) << "fail to run batchrequest sql: " << request->sql(); + return; + } + uint32_t byte_size = 0; + uint32_t count = 0; + for (auto& output_row : output_rows) { + if (FLAGS_scan_max_bytes_size > 0 && byte_size > FLAGS_scan_max_bytes_size) { + LOG(WARNING) << "reach the max byte size " << FLAGS_scan_max_bytes_size << " truncate result"; + response->set_schema(session.GetEncodedSchema()); + response->set_byte_size(byte_size); + response->set_count(count); + response->set_code(::openmldb::base::kOk); + return; + } + byte_size += output_row.size(); + buf->append(reinterpret_cast(output_row.buf()), output_row.size()); + count += 1; + } + response->set_schema(session.GetEncodedSchema()); + response->set_byte_size(byte_size); + response->set_count(count); + response->set_code(::openmldb::base::kOk); + break; } - const std::string& sql = session.GetCompileInfo()->GetSql(); - if (response->code() != ::openmldb::base::kOk) { - DLOG(WARNING) << "fail to run sql " << sql << " error msg: " << response->msg(); - } else { - DLOG(INFO) << "handle request sql " << sql; + default: { + response->set_msg("un-implemented execute_mode: " + hybridse::vm::EngineModeName(mode)); + response->set_code(::openmldb::base::kSQLCompileError); + break; } } }