diff --git a/hybridse/include/node/node_enum.h b/hybridse/include/node/node_enum.h index 7b189aa6aac..eea2bd9a953 100644 --- a/hybridse/include/node/node_enum.h +++ b/hybridse/include/node/node_enum.h @@ -98,6 +98,8 @@ enum SqlNodeType { kColumnSchema, kCreateUserStmt, kAlterUserStmt, + kGrantStmt, + kRevokeStmt, kCallStmt, kSqlNodeTypeLast, // debug type kVariadicUdfDef, @@ -347,6 +349,8 @@ enum PlanType { kPlanTypeShow, kPlanTypeCreateUser, kPlanTypeAlterUser, + kPlanTypeGrant, + kPlanTypeRevoke, kPlanTypeCallStmt, kUnknowPlan = -1, }; diff --git a/hybridse/include/node/plan_node.h b/hybridse/include/node/plan_node.h index ec82b6a586f..0e5683c4702 100644 --- a/hybridse/include/node/plan_node.h +++ b/hybridse/include/node/plan_node.h @@ -739,6 +739,67 @@ class CreateUserPlanNode : public LeafPlanNode { const std::shared_ptr options_; }; +class GrantPlanNode : public LeafPlanNode { + public: + explicit GrantPlanNode(std::optional target_type, std::string database, std::string target, + std::vector privileges, bool is_all_privileges, + std::vector grantees, bool with_grant_option) + : LeafPlanNode(kPlanTypeGrant), + target_type_(target_type), + database_(database), + target_(target), + privileges_(privileges), + is_all_privileges_(is_all_privileges), + grantees_(grantees), + with_grant_option_(with_grant_option) {} + ~GrantPlanNode() = default; + const std::vector Privileges() const { return privileges_; } + const std::vector Grantees() const { return grantees_; } + const std::string Database() const { return database_; } + const std::string Target() const { return target_; } + const std::optional TargetType() const { return target_type_; } + const bool IsAllPrivileges() const { return is_all_privileges_; } + const bool WithGrantOption() const { return with_grant_option_; } + + private: + std::optional target_type_; + std::string database_; + std::string target_; + std::vector privileges_; + bool is_all_privileges_; + std::vector grantees_; + bool with_grant_option_; +}; + +class RevokePlanNode : public LeafPlanNode { + public: + explicit RevokePlanNode(std::optional target_type, std::string database, std::string target, + std::vector privileges, bool is_all_privileges, + std::vector grantees) + : LeafPlanNode(kPlanTypeRevoke), + target_type_(target_type), + database_(database), + target_(target), + privileges_(privileges), + is_all_privileges_(is_all_privileges), + grantees_(grantees) {} + ~RevokePlanNode() = default; + const std::vector Privileges() const { return privileges_; } + const std::vector Grantees() const { return grantees_; } + const std::string Database() const { return database_; } + const std::string Target() const { return target_; } + const std::optional TargetType() const { return target_type_; } + const bool IsAllPrivileges() const { return is_all_privileges_; } + + private: + std::optional target_type_; + std::string database_; + std::string target_; + std::vector privileges_; + bool is_all_privileges_; + std::vector grantees_; +}; + class AlterUserPlanNode : public LeafPlanNode { public: explicit AlterUserPlanNode(const std::string& name, bool if_exists, std::shared_ptr options) diff --git a/hybridse/include/node/sql_node.h b/hybridse/include/node/sql_node.h index 96ea7a94163..52542426c2a 100644 --- a/hybridse/include/node/sql_node.h +++ b/hybridse/include/node/sql_node.h @@ -2421,6 +2421,64 @@ class AlterUserNode : public SqlNode { const std::shared_ptr options_; }; +class GrantNode : public SqlNode { + public: + explicit GrantNode(std::optional target_type, std::string database, std::string target, + std::vector privileges, bool is_all_privileges, std::vector grantees, + bool with_grant_option) + : SqlNode(kGrantStmt, 0, 0), + target_type_(target_type), + database_(database), + target_(target), + privileges_(privileges), + is_all_privileges_(is_all_privileges), + grantees_(grantees), + with_grant_option_(with_grant_option) {} + const std::vector Privileges() const { return privileges_; } + const std::vector Grantees() const { return grantees_; } + const std::string Database() const { return database_; } + const std::string Target() const { return target_; } + const std::optional TargetType() const { return target_type_; } + const bool IsAllPrivileges() const { return is_all_privileges_; } + const bool WithGrantOption() const { return with_grant_option_; } + + private: + std::optional target_type_; + std::string database_; + std::string target_; + std::vector privileges_; + bool is_all_privileges_; + std::vector grantees_; + bool with_grant_option_; +}; + +class RevokeNode : public SqlNode { + public: + explicit RevokeNode(std::optional target_type, std::string database, std::string target, + std::vector privileges, bool is_all_privileges, std::vector grantees) + : SqlNode(kRevokeStmt, 0, 0), + target_type_(target_type), + database_(database), + target_(target), + privileges_(privileges), + is_all_privileges_(is_all_privileges), + grantees_(grantees) {} + const std::vector Privileges() const { return privileges_; } + const std::vector Grantees() const { return grantees_; } + const std::string Database() const { return database_; } + const std::string Target() const { return target_; } + const std::optional TargetType() const { return target_type_; } + const bool IsAllPrivileges() const { return is_all_privileges_; } + + private: + std::optional target_type_; + std::string database_; + std::string target_; + std::vector privileges_; + bool is_all_privileges_; + std::vector grantees_; +}; + class ExplainNode : public SqlNode { public: explicit ExplainNode(const QueryNode *query, node::ExplainType explain_type) diff --git a/hybridse/src/plan/planner.cc b/hybridse/src/plan/planner.cc index b2a57b4128c..3a3984c9b16 100644 --- a/hybridse/src/plan/planner.cc +++ b/hybridse/src/plan/planner.cc @@ -768,6 +768,22 @@ base::Status SimplePlanner::CreatePlanTree(const NodePointVector &parser_trees, plan_trees.push_back(create_user_plan_node); break; } + case ::hybridse::node::kGrantStmt: { + auto node = dynamic_cast(parser_tree); + auto grant_plan_node = node_manager_->MakeNode( + node->TargetType(), node->Database(), node->Target(), node->Privileges(), node->IsAllPrivileges(), + node->Grantees(), node->WithGrantOption()); + plan_trees.push_back(grant_plan_node); + break; + } + case ::hybridse::node::kRevokeStmt: { + auto node = dynamic_cast(parser_tree); + auto revoke_plan_node = node_manager_->MakeNode( + node->TargetType(), node->Database(), node->Target(), node->Privileges(), node->IsAllPrivileges(), + node->Grantees()); + plan_trees.push_back(revoke_plan_node); + break; + } case ::hybridse::node::kAlterUserStmt: { auto node = dynamic_cast(parser_tree); auto alter_user_plan_node = node_manager_->MakeNode(node->Name(), diff --git a/hybridse/src/planv2/ast_node_converter.cc b/hybridse/src/planv2/ast_node_converter.cc index a8453e1221c..23e56924ae2 100644 --- a/hybridse/src/planv2/ast_node_converter.cc +++ b/hybridse/src/planv2/ast_node_converter.cc @@ -24,6 +24,7 @@ #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/types/span.h" +#include "ast_node_converter.h" #include "base/fe_status.h" #include "node/sql_node.h" #include "udf/udf.h" @@ -725,6 +726,20 @@ base::Status ConvertStatement(const zetasql::ASTStatement* statement, node::Node *output = create_user_node; break; } + case zetasql::AST_GRANT_STATEMENT: { + const zetasql::ASTGrantStatement* grant_stmt = statement->GetAsOrNull(); + node::GrantNode* grant_node = nullptr; + CHECK_STATUS(ConvertGrantStatement(grant_stmt, node_manager, &grant_node)) + *output = grant_node; + break; + } + case zetasql::AST_REVOKE_STATEMENT: { + const zetasql::ASTRevokeStatement* revoke_stmt = statement->GetAsOrNull(); + node::RevokeNode* revoke_node = nullptr; + CHECK_STATUS(ConvertRevokeStatement(revoke_stmt, node_manager, &revoke_node)) + *output = revoke_node; + break; + } case zetasql::AST_ALTER_USER_STATEMENT: { const zetasql::ASTAlterUserStatement* alter_user_stmt = statement->GetAsOrNull(); @@ -2133,6 +2148,81 @@ base::Status ConvertAlterUserStatement(const zetasql::ASTAlterUserStatement* roo return base::Status::OK(); } +base::Status ConvertGrantStatement(const zetasql::ASTGrantStatement* root, node::NodeManager* node_manager, + node::GrantNode** output) { + CHECK_TRUE(root != nullptr, common::kSqlAstError, "not an ASTGrantStatement"); + std::vector target_path; + CHECK_STATUS(AstPathExpressionToStringList(root->target_path(), target_path)); + std::optional target_type = std::nullopt; + if (root->target_type() != nullptr) { + target_type = root->target_type()->GetAsString(); + } + + std::vector privileges; + std::vector grantees; + for (auto privilege : root->privileges()->privileges()) { + if (privilege == nullptr) { + continue; + } + + auto privilege_action = privilege->privilege_action(); + if (privilege_action != nullptr) { + privileges.push_back(privilege_action->GetAsString()); + } + } + + for (auto grantee : root->grantee_list()->grantee_list()) { + if (grantee == nullptr) { + continue; + } + + std::string grantee_str; + CHECK_STATUS(AstStringLiteralToString(grantee, &grantee_str)); + grantees.push_back(grantee_str); + } + *output = node_manager->MakeNode(target_type, target_path.at(0), target_path.at(1), privileges, + root->privileges()->is_all_privileges(), grantees, + root->with_grant_option()); + return base::Status::OK(); +} + +base::Status ConvertRevokeStatement(const zetasql::ASTRevokeStatement* root, node::NodeManager* node_manager, + node::RevokeNode** output) { + CHECK_TRUE(root != nullptr, common::kSqlAstError, "not an ASTRevokeStatement"); + std::vector target_path; + CHECK_STATUS(AstPathExpressionToStringList(root->target_path(), target_path)); + std::optional target_type = std::nullopt; + if (root->target_type() != nullptr) { + target_type = root->target_type()->GetAsString(); + } + + std::vector privileges; + std::vector grantees; + for (auto privilege : root->privileges()->privileges()) { + if (privilege == nullptr) { + continue; + } + + auto privilege_action = privilege->privilege_action(); + if (privilege_action != nullptr) { + privileges.push_back(privilege_action->GetAsString()); + } + } + + for (auto grantee : root->grantee_list()->grantee_list()) { + if (grantee == nullptr) { + continue; + } + + std::string grantee_str; + CHECK_STATUS(AstStringLiteralToString(grantee, &grantee_str)); + grantees.push_back(grantee_str); + } + *output = node_manager->MakeNode(target_type, target_path.at(0), target_path.at(1), privileges, + root->privileges()->is_all_privileges(), grantees); + return base::Status::OK(); +} + base::Status ConvertCreateIndexStatement(const zetasql::ASTCreateIndexStatement* root, node::NodeManager* node_manager, node::CreateIndexNode** output) { CHECK_TRUE(nullptr != root, common::kSqlAstError, "not an ASTCreateIndexStatement") diff --git a/hybridse/src/planv2/ast_node_converter.h b/hybridse/src/planv2/ast_node_converter.h index 631569156d2..edc0fb60c50 100644 --- a/hybridse/src/planv2/ast_node_converter.h +++ b/hybridse/src/planv2/ast_node_converter.h @@ -72,6 +72,12 @@ base::Status ConvertCreateUserStatement(const zetasql::ASTCreateUserStatement* r base::Status ConvertAlterUserStatement(const zetasql::ASTAlterUserStatement* root, node::NodeManager* node_manager, node::AlterUserNode** output); +base::Status ConvertGrantStatement(const zetasql::ASTGrantStatement* root, node::NodeManager* node_manager, + node::GrantNode** output); + +base::Status ConvertRevokeStatement(const zetasql::ASTRevokeStatement* root, node::NodeManager* node_manager, + node::RevokeNode** output); + base::Status ConvertQueryNode(const zetasql::ASTQuery* root, node::NodeManager* node_manager, node::QueryNode** output); base::Status ConvertQueryExpr(const zetasql::ASTQueryExpression* query_expr, node::NodeManager* node_manager, diff --git a/src/auth/user_access_manager.cc b/src/auth/user_access_manager.cc index d668a7dc497..1f354998ef3 100644 --- a/src/auth/user_access_manager.cc +++ b/src/auth/user_access_manager.cc @@ -47,7 +47,7 @@ void UserAccessManager::StopSyncTask() { void UserAccessManager::SyncWithDB() { if (auto it_pair = user_table_iterator_factory_(::openmldb::nameserver::USER_INFO_NAME); it_pair) { - auto new_user_map = std::make_unique>(); + auto new_user_map = std::make_unique>(); auto it = it_pair->first.get(); it->SeekToFirst(); while (it->Valid()) { @@ -56,13 +56,18 @@ void UserAccessManager::SyncWithDB() { auto size = it->GetValue().size(); codec::RowView row_view(*it_pair->second.get(), buf, size); std::string host, user, password; + std::string privilege_level_str; row_view.GetStrValue(0, &host); row_view.GetStrValue(1, &user); row_view.GetStrValue(2, &password); + row_view.GetStrValue(5, &privilege_level_str); + openmldb::nameserver::PrivilegeLevel privilege_level; + ::openmldb::nameserver::PrivilegeLevel_Parse(privilege_level_str, &privilege_level); + UserRecord user_record = {password, privilege_level}; if (host == "%") { - new_user_map->emplace(user, password); + new_user_map->emplace(user, user_record); } else { - new_user_map->emplace(FormUserHost(user, host), password); + new_user_map->emplace(FormUserHost(user, host), user_record); } it->Next(); } @@ -70,12 +75,36 @@ void UserAccessManager::SyncWithDB() { } } +std::optional UserAccessManager::GetUserPassword(const std::string& host, const std::string& user) { + if (auto user_record = user_map_.Get(FormUserHost(user, host)); user_record.has_value()) { + return user_record.value().password; + } else if (auto stored_password = user_map_.Get(user); stored_password.has_value()) { + return stored_password.value().password; + } else { + return std::nullopt; + } +} + bool UserAccessManager::IsAuthenticated(const std::string& host, const std::string& user, const std::string& password) { - if (auto stored_password = user_map_.Get(FormUserHost(user, host)); stored_password.has_value()) { - return stored_password.value() == password; + if (auto user_record = user_map_.Get(FormUserHost(user, host)); user_record.has_value()) { + return user_record.value().password == password; } else if (auto stored_password = user_map_.Get(user); stored_password.has_value()) { - return stored_password.value() == password; + return stored_password.value().password == password; } return false; } + +::openmldb::nameserver::PrivilegeLevel UserAccessManager::GetPrivilegeLevel(const std::string& user_at_host) { + std::size_t at_pos = user_at_host.find('@'); + if (at_pos != std::string::npos) { + std::string user = user_at_host.substr(0, at_pos); + std::string host = user_at_host.substr(at_pos + 1); + if (auto user_record = user_map_.Get(FormUserHost(user, host)); user_record.has_value()) { + return user_record.value().privilege_level; + } else if (auto stored_password = user_map_.Get(user); stored_password.has_value()) { + return stored_password.value().privilege_level; + } + } + return ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE; +} } // namespace openmldb::auth diff --git a/src/auth/user_access_manager.h b/src/auth/user_access_manager.h index 996efc326c4..9de6890f93a 100644 --- a/src/auth/user_access_manager.h +++ b/src/auth/user_access_manager.h @@ -26,9 +26,15 @@ #include #include "catalog/distribute_iterator.h" +#include "proto/name_server.pb.h" #include "refreshable_map.h" namespace openmldb::auth { +struct UserRecord { + std::string password; + ::openmldb::nameserver::PrivilegeLevel privilege_level; +}; + class UserAccessManager { public: using IteratorFactory = std::function GetUserPassword(const std::string& host, const std::string& user); private: IteratorFactory user_table_iterator_factory_; - RefreshableMap user_map_; + RefreshableMap user_map_; std::thread sync_task_thread_; std::promise stop_promise_; void StartSyncTask(); diff --git a/src/base/status.h b/src/base/status.h index c7e5ec75198..a2da254e78e 100644 --- a/src/base/status.h +++ b/src/base/status.h @@ -186,7 +186,8 @@ enum ReturnCode { kRPCError = 1004, // brpc controller error // auth - kFlushPrivilegesFailed = 1100 // brpc controller error + kFlushPrivilegesFailed = 1100, // brpc controller error + kNotAuthorized = 1101 // brpc controller error }; struct Status { diff --git a/src/client/ns_client.cc b/src/client/ns_client.cc index cdeef07e521..9a4c6f4df6d 100644 --- a/src/client/ns_client.cc +++ b/src/client/ns_client.cc @@ -317,6 +317,35 @@ bool NsClient::PutUser(const std::string& host, const std::string& name, const s return false; } +bool NsClient::PutPrivilege(const std::optional target_type, const std::string database, + const std::string target, const std::vector privileges, + const bool is_all_privileges, const std::vector grantees, + const ::openmldb::nameserver::PrivilegeLevel privilege_level) { + ::openmldb::nameserver::PutPrivilegeRequest request; + if (target_type.has_value()) { + request.set_target_type(target_type.value()); + } + request.set_database(database); + request.set_target(target); + for (const auto& privilege : privileges) { + request.add_privilege(privilege); + } + request.set_is_all_privileges(is_all_privileges); + for (const auto& grantee : grantees) { + request.add_grantee(grantee); + } + + request.set_privilege_level(privilege_level); + + ::openmldb::nameserver::GeneralResponse response; + bool ok = client_.SendRequest(&::openmldb::nameserver::NameServer_Stub::PutPrivilege, &request, &response, + FLAGS_request_timeout_ms, 1); + if (ok && response.code() == 0) { + return true; + } + return false; +} + bool NsClient::DeleteUser(const std::string& host, const std::string& name) { ::openmldb::nameserver::DeleteUserRequest request; request.set_host(host); diff --git a/src/client/ns_client.h b/src/client/ns_client.h index 73a52854765..1ddd50963bf 100644 --- a/src/client/ns_client.h +++ b/src/client/ns_client.h @@ -112,6 +112,11 @@ class NsClient : public Client { bool PutUser(const std::string& host, const std::string& name, const std::string& password); // NOLINT + bool PutPrivilege(const std::optional target_type, const std::string database, + const std::string target, const std::vector privileges, const bool is_all_privileges, + const std::vector grantees, + const ::openmldb::nameserver::PrivilegeLevel privilege_level); // NOLINT + bool DeleteUser(const std::string& host, const std::string& name); // NOLINT bool DropTable(const std::string& db, const std::string& name, diff --git a/src/cmd/sql_cmd_test.cc b/src/cmd/sql_cmd_test.cc index fe8faa21504..79225eb52dd 100644 --- a/src/cmd/sql_cmd_test.cc +++ b/src/cmd/sql_cmd_test.cc @@ -243,7 +243,6 @@ TEST_P(DBSDKTest, TestUser) { ASSERT_FALSE(status.IsOK()); sr->ExecuteSQL(absl::StrCat("CREATE USER IF NOT EXISTS user1"), &status); ASSERT_TRUE(status.IsOK()); - ASSERT_TRUE(true); auto opt = sr->GetRouterOptions(); if (cs->IsClusterMode()) { auto real_opt = std::dynamic_pointer_cast(opt); @@ -280,6 +279,121 @@ TEST_P(DBSDKTest, TestUser) { ASSERT_TRUE(status.IsOK()); } +TEST_P(DBSDKTest, TestGrantCreateUser) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + hybridse::sdk::Status status; + sr->ExecuteSQL(absl::StrCat("CREATE USER user1 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + auto opt = sr->GetRouterOptions(); + if (cs->IsClusterMode()) { + auto real_opt = std::dynamic_pointer_cast(opt); + sdk::SQLRouterOptions opt1; + opt1.zk_cluster = real_opt->zk_cluster; + opt1.zk_path = real_opt->zk_path; + opt1.user = "user1"; + opt1.password = "123456"; + auto router = NewClusterSQLRouter(opt1); + ASSERT_TRUE(router != nullptr); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_FALSE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("DROP USER user2"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("REVOKE CREATE USER ON *.* FROM 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user3 OPTIONS(password='123456')"), &status); + ASSERT_FALSE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("DROP USER user2"), &status); + ASSERT_FALSE(status.IsOK()); + } else { + auto real_opt = std::dynamic_pointer_cast(opt); + sdk::StandaloneOptions opt1; + opt1.host = real_opt->host; + opt1.port = real_opt->port; + opt1.user = "user1"; + opt1.password = "123456"; + auto router = NewStandaloneSQLRouter(opt1); + ASSERT_TRUE(router != nullptr); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_FALSE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("DROP USER user2"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("REVOKE CREATE USER ON *.* FROM 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user3 OPTIONS(password='123456')"), &status); + ASSERT_FALSE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("DROP USER user2"), &status); + ASSERT_FALSE(status.IsOK()); + } + sr->ExecuteSQL(absl::StrCat("DROP USER IF EXISTS user1"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("DROP USER IF EXISTS user2"), &status); + ASSERT_TRUE(status.IsOK()); +} + +TEST_P(DBSDKTest, TestGrantCreateUserGrantOption) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + hybridse::sdk::Status status; + sr->ExecuteSQL(absl::StrCat("CREATE USER user1 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + + auto opt = sr->GetRouterOptions(); + if (cs->IsClusterMode()) { + auto real_opt = std::dynamic_pointer_cast(opt); + sdk::SQLRouterOptions opt1; + opt1.zk_cluster = real_opt->zk_cluster; + opt1.zk_path = real_opt->zk_path; + opt1.user = "user1"; + opt1.password = "123456"; + auto router = NewClusterSQLRouter(opt1); + ASSERT_TRUE(router != nullptr); + router->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user2@%'"), &status); + ASSERT_FALSE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%' WITH GRANT OPTION"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user2@%'"), &status); + ASSERT_TRUE(status.IsOK()); + } else { + auto real_opt = std::dynamic_pointer_cast(opt); + sdk::StandaloneOptions opt1; + opt1.host = real_opt->host; + opt1.port = real_opt->port; + opt1.user = "user1"; + opt1.password = "123456"; + auto router = NewStandaloneSQLRouter(opt1); + ASSERT_TRUE(router != nullptr); + router->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user2@%'"), &status); + ASSERT_FALSE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%' WITH GRANT OPTION"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user2@%'"), &status); + ASSERT_TRUE(status.IsOK()); + } + sr->ExecuteSQL(absl::StrCat("DROP USER IF EXISTS user1"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("DROP USER IF EXISTS user2"), &status); + ASSERT_TRUE(status.IsOK()); +} + TEST_P(DBSDKTest, CreateDatabase) { auto cli = GetParam(); cs = cli->cs; diff --git a/src/nameserver/name_server_impl.cc b/src/nameserver/name_server_impl.cc index 9c565272fb3..5e65a7d2d94 100644 --- a/src/nameserver/name_server_impl.cc +++ b/src/nameserver/name_server_impl.cc @@ -1380,7 +1380,8 @@ void NameServerImpl::ShowTablet(RpcController* controller, const ShowTabletReque } base::Status NameServerImpl::PutUserRecord(const std::string& host, const std::string& user, - const std::string& password) { + const std::string& password, + const ::openmldb::nameserver::PrivilegeLevel privilege_level) { std::shared_ptr table_info; if (!GetTableInfo(USER_INFO_NAME, INTERNAL_DB, &table_info)) { return {ReturnCode::kTableIsNotExist, "user table does not exist"}; @@ -1391,12 +1392,8 @@ base::Status NameServerImpl::PutUserRecord(const std::string& host, const std::s row_values.push_back(user); row_values.push_back(password); row_values.push_back("0"); // password_last_changed - row_values.push_back("0"); // password_expired_time - row_values.push_back("0"); // create_time - row_values.push_back("0"); // update_time - row_values.push_back("1"); // account_type - row_values.push_back("0"); // privileges - row_values.push_back("null"); // extra_info + row_values.push_back("0"); // password_expired + row_values.push_back(PrivilegeLevel_Name(privilege_level)); // Create_user_priv std::string encoded_row; codec::RowCodec::EncodeRow(row_values, table_info->column_desc(), 1, encoded_row); @@ -1431,7 +1428,6 @@ base::Status NameServerImpl::DeleteUserRecord(const std::string& host, const std for (int meta_idx = 0; meta_idx < table_partition.partition_meta_size(); meta_idx++) { if (table_partition.partition_meta(meta_idx).is_leader() && table_partition.partition_meta(meta_idx).is_alive()) { - uint64_t cur_ts = ::baidu::common::timer::get_micros() / 1000; std::string endpoint = table_partition.partition_meta(meta_idx).endpoint(); auto table_ptr = GetTablet(endpoint); if (!table_ptr->client_->Delete(tid, 0, host + "|" + user, "index", msg)) { @@ -5640,7 +5636,8 @@ void NameServerImpl::OnLocked() { CreateDatabaseOrExit(INTERNAL_DB); if (db_table_info_[INTERNAL_DB].count(USER_INFO_NAME) == 0) { CreateSystemTableOrExit(SystemTableType::kUser); - PutUserRecord("%", "root", "1e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"); + PutUserRecord("%", "root", "1e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + ::openmldb::nameserver::PrivilegeLevel::PRIVILEGE_WITH_GRANT_OPTION); } if (IsClusterMode()) { if (tablets_.size() < FLAGS_system_table_replica_num) { @@ -9663,15 +9660,64 @@ NameServerImpl::GetSystemTableIterator() { void NameServerImpl::PutUser(RpcController* controller, const PutUserRequest* request, GeneralResponse* response, Closure* done) { brpc::ClosureGuard done_guard(done); - auto status = PutUserRecord(request->host(), request->name(), request->password()); - base::SetResponseStatus(status, response); + brpc::Controller* brpc_controller = static_cast(controller); + + if (brpc_controller->auth_context()->is_service() || + user_access_manager_.GetPrivilegeLevel(brpc_controller->auth_context()->user()) > + ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE) { + auto status = PutUserRecord(request->host(), request->name(), request->password(), + ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE); + base::SetResponseStatus(status, response); + } else { + base::SetResponseStatus(base::ReturnCode::kNotAuthorized, "not authorized to create user", response); + } +} + +void NameServerImpl::PutPrivilege(RpcController* controller, const PutPrivilegeRequest* request, + GeneralResponse* response, Closure* done) { + brpc::ClosureGuard done_guard(done); + + for (int i = 0; i < request->privilege_size(); ++i) { + auto privilege = request->privilege(i); + if (privilege == "CREATE USER") { + brpc::Controller* brpc_controller = static_cast(controller); + if (brpc_controller->auth_context()->is_service() || + user_access_manager_.GetPrivilegeLevel(brpc_controller->auth_context()->user()) >= + ::openmldb::nameserver::PrivilegeLevel::PRIVILEGE_WITH_GRANT_OPTION) { + for (int i = 0; i < request->grantee_size(); ++i) { + auto grantee = request->grantee(i); + std::size_t at_pos = grantee.find('@'); + if (at_pos != std::string::npos) { + std::string user = grantee.substr(0, at_pos); + std::string host = grantee.substr(at_pos + 1); + auto password = user_access_manager_.GetUserPassword(host, user); + if (password.has_value()) { + auto status = PutUserRecord(host, user, password.value(), request->privilege_level()); + base::SetResponseStatus(status, response); + } + } + } + } else { + base::SetResponseStatus(base::ReturnCode::kNotAuthorized, + "not authorized to grant create user privilege", response); + } + } + } } void NameServerImpl::DeleteUser(RpcController* controller, const DeleteUserRequest* request, GeneralResponse* response, Closure* done) { brpc::ClosureGuard done_guard(done); - auto status = DeleteUserRecord(request->host(), request->name()); - base::SetResponseStatus(status, response); + brpc::Controller* brpc_controller = static_cast(controller); + + if (brpc_controller->auth_context()->is_service() || + user_access_manager_.GetPrivilegeLevel(brpc_controller->auth_context()->user()) > + ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE) { + auto status = DeleteUserRecord(request->host(), request->name()); + base::SetResponseStatus(status, response); + } else { + base::SetResponseStatus(base::ReturnCode::kNotAuthorized, "not authorized to create user", response); + } } bool NameServerImpl::IsAuthenticated(const std::string& host, const std::string& username, diff --git a/src/nameserver/name_server_impl.h b/src/nameserver/name_server_impl.h index dadc335c7a3..53f44cde278 100644 --- a/src/nameserver/name_server_impl.h +++ b/src/nameserver/name_server_impl.h @@ -360,6 +360,8 @@ class NameServerImpl : public NameServer { Closure* done); void PutUser(RpcController* controller, const PutUserRequest* request, GeneralResponse* response, Closure* done); + void PutPrivilege(RpcController* controller, const PutPrivilegeRequest* request, GeneralResponse* response, + Closure* done); void DeleteUser(RpcController* controller, const DeleteUserRequest* request, GeneralResponse* response, Closure* done); bool IsAuthenticated(const std::string& host, const std::string& username, const std::string& password); @@ -373,7 +375,8 @@ class NameServerImpl : public NameServer { bool GetTableInfo(const std::string& table_name, const std::string& db_name, std::shared_ptr* table_info); - base::Status PutUserRecord(const std::string& host, const std::string& user, const std::string& password); + base::Status PutUserRecord(const std::string& host, const std::string& user, const std::string& password, + const ::openmldb::nameserver::PrivilegeLevel privilege_level); base::Status DeleteUserRecord(const std::string& host, const std::string& user); base::Status FlushPrivileges(); diff --git a/src/nameserver/system_table.h b/src/nameserver/system_table.h index cda34e1798e..03c8bc2364e 100644 --- a/src/nameserver/system_table.h +++ b/src/nameserver/system_table.h @@ -163,20 +163,16 @@ class SystemTable { break; } case SystemTableType::kUser: { - SetColumnDesc("host", type::DataType::kString, table_info->add_column_desc()); - SetColumnDesc("user", type::DataType::kString, table_info->add_column_desc()); - SetColumnDesc("password", type::DataType::kString, table_info->add_column_desc()); + SetColumnDesc("Host", type::DataType::kString, table_info->add_column_desc()); + SetColumnDesc("User", type::DataType::kString, table_info->add_column_desc()); + SetColumnDesc("authentication_string", type::DataType::kString, table_info->add_column_desc()); SetColumnDesc("password_last_changed", type::DataType::kTimestamp, table_info->add_column_desc()); - SetColumnDesc("password_expired_time", type::DataType::kBigInt, table_info->add_column_desc()); - SetColumnDesc("create_time", type::DataType::kTimestamp, table_info->add_column_desc()); - SetColumnDesc("update_time", type::DataType::kTimestamp, table_info->add_column_desc()); - SetColumnDesc("account_type", type::DataType::kInt, table_info->add_column_desc()); - SetColumnDesc("privileges", type::DataType::kString, table_info->add_column_desc()); - SetColumnDesc("extra_info", type::DataType::kString, table_info->add_column_desc()); + SetColumnDesc("password_expired", type::DataType::kTimestamp, table_info->add_column_desc()); + SetColumnDesc("Create_user_priv", type::DataType::kString, table_info->add_column_desc()); auto index = table_info->add_column_key(); index->set_index_name("index"); - index->add_col_name("host"); - index->add_col_name("user"); + index->add_col_name("Host"); + index->add_col_name("User"); auto ttl = index->mutable_ttl(); ttl->set_ttl_type(::openmldb::type::kLatestTime); ttl->set_lat_ttl(1); diff --git a/src/proto/name_server.proto b/src/proto/name_server.proto index f7c8fd5c830..14cd00d6ddd 100755 --- a/src/proto/name_server.proto +++ b/src/proto/name_server.proto @@ -544,6 +544,22 @@ message DeleteUserRequest { required string name = 2; } +enum PrivilegeLevel { + NO_PRIVILEGE = 0; + PRIVILEGE = 1; + PRIVILEGE_WITH_GRANT_OPTION = 2; +} + +message PutPrivilegeRequest { + repeated string grantee = 1; + repeated string privilege = 2; + optional string target_type = 3; + required string database = 4; + required string target = 5; + required bool is_all_privileges = 6; + required PrivilegeLevel privilege_level = 7; +} + message DeploySQLRequest { optional openmldb.api.ProcedureInfo sp_info = 3; repeated TableIndex index = 4; @@ -617,4 +633,7 @@ service NameServer { // user related interfaces rpc PutUser(PutUserRequest) returns (GeneralResponse); rpc DeleteUser(DeleteUserRequest) returns (GeneralResponse); + + // authz related interfaces + rpc PutPrivilege(PutPrivilegeRequest) returns (GeneralResponse); } diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index dbdd7dede9d..3d09156fdcc 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -2786,6 +2786,30 @@ std::shared_ptr SQLClusterRouter::ExecuteSQL( } return {}; } + case hybridse::node::kPlanTypeGrant: { + auto grant_node = dynamic_cast(node); + auto ns = cluster_sdk_->GetNsClient(); + auto ok = ns->PutPrivilege(grant_node->TargetType(), grant_node->Database(), grant_node->Target(), + grant_node->Privileges(), grant_node->IsAllPrivileges(), grant_node->Grantees(), + grant_node->WithGrantOption() + ? ::openmldb::nameserver::PrivilegeLevel::PRIVILEGE_WITH_GRANT_OPTION + : ::openmldb::nameserver::PrivilegeLevel::PRIVILEGE); + if (!ok) { + *status = {StatusCode::kCmdError, "Grant API call failed"}; + } + return {}; + } + case hybridse::node::kPlanTypeRevoke: { + auto revoke_node = dynamic_cast(node); + auto ns = cluster_sdk_->GetNsClient(); + auto ok = ns->PutPrivilege(revoke_node->TargetType(), revoke_node->Database(), revoke_node->Target(), + revoke_node->Privileges(), revoke_node->IsAllPrivileges(), + revoke_node->Grantees(), ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE); + if (!ok) { + *status = {StatusCode::kCmdError, "Revoke API call failed"}; + } + return {}; + } case hybridse::node::kPlanTypeAlterUser: { auto alter_node = dynamic_cast(node); UserInfo user_info;