diff --git a/engine/common/api_server_configuration.h b/engine/common/api_server_configuration.h index 03b3022a4..63383301b 100644 --- a/engine/common/api_server_configuration.h +++ b/engine/common/api_server_configuration.h @@ -107,7 +107,7 @@ class ApiServerConfiguration { const std::string& proxy_url = "", const std::string& proxy_username = "", const std::string& proxy_password = "", const std::string& no_proxy = "", bool verify_peer_ssl = true, bool verify_host_ssl = true, - const std::string& hf_token = "") + const std::string& hf_token = "", std::vector api_keys = {}) : cors{cors}, allowed_origins{allowed_origins}, verify_proxy_ssl{verify_proxy_ssl}, @@ -118,7 +118,8 @@ class ApiServerConfiguration { no_proxy{no_proxy}, verify_peer_ssl{verify_peer_ssl}, verify_host_ssl{verify_host_ssl}, - hf_token{hf_token} {} + hf_token{hf_token}, + api_keys{api_keys} {} // cors bool cors{true}; @@ -139,6 +140,9 @@ class ApiServerConfiguration { // token std::string hf_token{""}; + // authentication + std::vector api_keys; + Json::Value ToJson() const { Json::Value root; root["cors"] = cors; @@ -155,6 +159,10 @@ class ApiServerConfiguration { root["verify_peer_ssl"] = verify_peer_ssl; root["verify_host_ssl"] = verify_host_ssl; root["huggingface_token"] = hf_token; + root["api_keys"] = Json::Value(Json::arrayValue); + for (const auto& api_key : api_keys) { + root["api_keys"].append(api_key); + } return root; } @@ -256,7 +264,8 @@ class ApiServerConfiguration { return true; }}, - {"allowed_origins", [this](const Json::Value& value) -> bool { + {"allowed_origins", + [this](const Json::Value& value) -> bool { if (!value.isArray()) { return false; } @@ -271,7 +280,26 @@ class ApiServerConfiguration { this->allowed_origins.push_back(origin.asString()); } return true; - }}}; + }}, + + {"api_keys", + [this](const Json::Value& value) -> bool { + if (!value.isArray()) { + return false; + } + for (const auto& key : value) { + if (!key.isString()) { + return false; + } + } + + this->api_keys.clear(); + for (const auto& key : value) { + this->api_keys.push_back(key.asString()); + } + return true; + }}, + }; for (const auto& key : json.getMemberNames()) { auto updater = field_updater.find(key); diff --git a/engine/main.cc b/engine/main.cc index 2f60916a6..51ace2d9b 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -249,6 +249,55 @@ void RunServer(std::optional host, std::optional port, .setClientMaxBodySize(256 * 1024 * 1024) // Max 256MiB body size .setClientMaxMemoryBodySize(1024 * 1024); // 1MiB before writing to disk + auto validate_api_key = [config_service](const drogon::HttpRequestPtr& req) { + auto const& api_keys = + config_service->GetApiServerConfiguration()->api_keys; + static const std::unordered_set public_endpoints = { + "/healthz", "/processManager/destroy"}; + + // If API key is not set, skip validation + if (api_keys.empty()) { + return true; + } + + // If path is public or is static file, skip validation + if (public_endpoints.find(req->path()) != public_endpoints.end() || + req->path() == "/") { + return true; + } + + // Check for API key in the header + auto auth_header = req->getHeader("Authorization"); + + std::string prefix = "Bearer "; + if (auth_header.substr(0, prefix.size()) == prefix) { + std::string received_api_key = auth_header.substr(prefix.size()); + if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != + api_keys.end()) { + return true; // API key is valid + } + } + + CTL_WRN("Unauthorized: Invalid API Key\n"); + return false; + }; + + drogon::app().registerPreRoutingAdvice( + [&validate_api_key]( + const drogon::HttpRequestPtr& req, + std::function&& cb, + drogon::AdviceChainCallback&& ccb) { + if (!validate_api_key(req)) { + Json::Value ret; + ret["message"] = "Invalid API Key"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k401Unauthorized); + cb(resp); + return; + } + ccb(); + }); + // CORS drogon::app().registerPostHandlingAdvice( [config_service](const drogon::HttpRequestPtr& req, diff --git a/engine/services/config_service.cc b/engine/services/config_service.cc index ce5526090..ae90e93fb 100644 --- a/engine/services/config_service.cc +++ b/engine/services/config_service.cc @@ -6,10 +6,10 @@ cpp::result ConfigService::UpdateApiServerConfiguration(const Json::Value& json) { auto config = file_manager_utils::GetCortexConfig(); ApiServerConfiguration api_server_config{ - config.enableCors, config.allowedOrigins, config.verifyProxySsl, - config.verifyProxyHostSsl, config.proxyUrl, config.proxyUsername, - config.proxyPassword, config.noProxy, config.verifyPeerSsl, - config.verifyHostSsl, config.huggingFaceToken}; + config.enableCors, config.allowedOrigins, config.verifyProxySsl, + config.verifyProxyHostSsl, config.proxyUrl, config.proxyUsername, + config.proxyPassword, config.noProxy, config.verifyPeerSsl, + config.verifyHostSsl, config.huggingFaceToken, config.apiKeys}; std::vector updated_fields; std::vector invalid_fields; @@ -36,6 +36,7 @@ ConfigService::UpdateApiServerConfiguration(const Json::Value& json) { config.verifyHostSsl = api_server_config.verify_host_ssl; config.huggingFaceToken = api_server_config.hf_token; + config.apiKeys = api_server_config.api_keys; auto result = file_manager_utils::UpdateCortexConfig(config); return api_server_config; @@ -45,8 +46,8 @@ cpp::result ConfigService::GetApiServerConfiguration() { auto config = file_manager_utils::GetCortexConfig(); return ApiServerConfiguration{ - config.enableCors, config.allowedOrigins, config.verifyProxySsl, - config.verifyProxyHostSsl, config.proxyUrl, config.proxyUsername, - config.proxyPassword, config.noProxy, config.verifyPeerSsl, - config.verifyHostSsl, config.huggingFaceToken}; + config.enableCors, config.allowedOrigins, config.verifyProxySsl, + config.verifyProxyHostSsl, config.proxyUrl, config.proxyUsername, + config.proxyPassword, config.noProxy, config.verifyPeerSsl, + config.verifyHostSsl, config.huggingFaceToken, config.apiKeys}; } diff --git a/engine/utils/config_yaml_utils.cc b/engine/utils/config_yaml_utils.cc index b26d690c6..49b31acd0 100644 --- a/engine/utils/config_yaml_utils.cc +++ b/engine/utils/config_yaml_utils.cc @@ -51,6 +51,7 @@ cpp::result CortexConfigMgr::DumpYamlConfig( node["sslKeyPath"] = config.sslKeyPath; node["supportedEngines"] = config.supportedEngines; node["checkedForSyncHubAt"] = config.checkedForSyncHubAt; + node["apiKeys"] = config.apiKeys; out_file << node; out_file.close(); @@ -87,7 +88,7 @@ CortexConfig CortexConfigMgr::FromYaml(const std::string& path, !node["verifyProxySsl"] || !node["verifyProxyHostSsl"] || !node["supportedEngines"] || !node["sslCertPath"] || !node["sslKeyPath"] || !node["noProxy"] || - !node["checkedForSyncHubAt"]); + !node["checkedForSyncHubAt"] || !node["apiKeys"]); CortexConfig config = { .logFolderPath = node["logFolderPath"] @@ -182,6 +183,11 @@ CortexConfig CortexConfigMgr::FromYaml(const std::string& path, .checkedForSyncHubAt = node["checkedForSyncHubAt"] ? node["checkedForSyncHubAt"].as() : default_cfg.checkedForSyncHubAt, + .apiKeys = + node["apiKeys"] + ? node["apiKeys"].as>() + : default_cfg.apiKeys, + }; if (should_update_config) { l.unlock(); diff --git a/engine/utils/config_yaml_utils.h b/engine/utils/config_yaml_utils.h index 1749cd2d0..c94b8fe5f 100644 --- a/engine/utils/config_yaml_utils.h +++ b/engine/utils/config_yaml_utils.h @@ -68,6 +68,7 @@ struct CortexConfig { std::string sslKeyPath; std::vector supportedEngines; uint64_t checkedForSyncHubAt; + std::vector apiKeys; }; class CortexConfigMgr { diff --git a/engine/utils/file_manager_utils.cc b/engine/utils/file_manager_utils.cc index b5713456a..575a3cb9b 100644 --- a/engine/utils/file_manager_utils.cc +++ b/engine/utils/file_manager_utils.cc @@ -219,6 +219,7 @@ config_yaml_utils::CortexConfig GetDefaultConfig() { .sslKeyPath = "", .supportedEngines = config_yaml_utils::kDefaultSupportedEngines, .checkedForSyncHubAt = 0u, + .apiKeys = {}, }; }