From 397eef509b85d7ec4ed68a368cf9d22352b2c464 Mon Sep 17 00:00:00 2001 From: Laura Date: Sun, 17 Dec 2023 21:11:46 +0100 Subject: [PATCH 1/3] Implement credentialed CORS according to MDN --- examples/server/server.cpp | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 04038530f94da..5b9499c004cd4 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2710,9 +2710,15 @@ int main(int argc, char **argv) return false; }; - svr.set_default_headers({{"Server", "llama.cpp"}, - {"Access-Control-Allow-Origin", "*"}, - {"Access-Control-Allow-Headers", "content-type"}}); + svr.set_default_headers({{"Server", "llama.cpp"}}); + + // CORS preflight + svr.Options(R"(.*)", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + res.set_header("Access-Control-Allow-Credentials", "true"); + res.set_header("Access-Control-Allow-Methods", "POST"); + res.set_header("Access-Control-Allow-Headers", "*"); + }); // this is only called if no index.html is found in the public --path svr.Get("/", [](const httplib::Request &, httplib::Response &res) @@ -2744,7 +2750,7 @@ int main(int argc, char **argv) svr.Get("/props", [&llama](const httplib::Request & /*req*/, httplib::Response &res) { - res.set_header("Access-Control-Allow-Origin", "*"); + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { { "user_name", llama.name_user.c_str() }, { "assistant_name", llama.name_assistant.c_str() } @@ -2754,6 +2760,7 @@ int main(int argc, char **argv) svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } @@ -2821,10 +2828,9 @@ int main(int argc, char **argv) } }); - - svr.Get("/v1/models", [¶ms](const httplib::Request&, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); std::time_t t = std::time(0); json models = { @@ -2842,9 +2848,11 @@ int main(int argc, char **argv) res.set_content(models.dump(), "application/json; charset=utf-8"); }); + // TODO: add mount point without "/v1" prefix -- how? svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } @@ -2918,6 +2926,7 @@ int main(int argc, char **argv) svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } @@ -2990,6 +2999,7 @@ int main(int argc, char **argv) svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::vector tokens; if (body.count("content") != 0) @@ -3002,6 +3012,7 @@ int main(int argc, char **argv) svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::string content; if (body.count("tokens") != 0) @@ -3016,6 +3027,7 @@ int main(int argc, char **argv) svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); json prompt; if (body.count("content") != 0) From a6e9700a836463f4153f65c9c16a5c297be7f4f3 Mon Sep 17 00:00:00 2001 From: Laura Date: Wed, 20 Dec 2023 22:54:38 +0100 Subject: [PATCH 2/3] Fix syntax error --- examples/server/server.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5b9499c004cd4..fbc84e79dd8cb 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2748,7 +2748,7 @@ int main(int argc, char **argv) return false; }); - svr.Get("/props", [&llama](const httplib::Request & /*req*/, httplib::Response &res) + svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { @@ -2828,7 +2828,7 @@ int main(int argc, char **argv) } }); - svr.Get("/v1/models", [¶ms](const httplib::Request&, httplib::Response& res) + svr.Get("/v1/models", [¶ms](const httplib::Request& req, httplib::Response& res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); std::time_t t = std::time(0); From af9bf475e29dd834017f86087e7448cdf7c16a00 Mon Sep 17 00:00:00 2001 From: Laura Date: Thu, 11 Jan 2024 18:54:19 +0100 Subject: [PATCH 3/3] Move validate_api_key up so it is defined before its first usage --- examples/server/server.cpp | 55 +++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index aef1d207b25aa..ac0ee6faa33a0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2802,8 +2802,35 @@ int main(int argc, char **argv) svr.set_default_headers({{"Server", "llama.cpp"}}); + // Middleware for API key validation + auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { + // If API key is not set, skip validation + if (sparams.api_key.empty()) { + return true; + } + + // Check for API key in the header + auto auth_header = req.get_header_value("Authorization"); + std::string prefix = "Bearer "; + if (auth_header.substr(0, prefix.size()) == prefix) { + std::string received_api_key = auth_header.substr(prefix.size()); + if (received_api_key == sparams.api_key) { + return true; // API key is valid + } + } + + // API key is invalid or not provided + res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8"); + res.status = 401; // Unauthorized + + LOG_WARNING("Unauthorized: Invalid API Key", {}); + + return false; + }; + + // CORS preflight - svr.Options(R"(.*)", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { + svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Methods", "POST"); @@ -2915,32 +2942,6 @@ int main(int argc, char **argv) LOG_INFO("model loaded", {}); } - // Middleware for API key validation - auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { - // If API key is not set, skip validation - if (sparams.api_key.empty()) { - return true; - } - - // Check for API key in the header - auto auth_header = req.get_header_value("Authorization"); - std::string prefix = "Bearer "; - if (auth_header.substr(0, prefix.size()) == prefix) { - std::string received_api_key = auth_header.substr(prefix.size()); - if (received_api_key == sparams.api_key) { - return true; // API key is valid - } - } - - // API key is invalid or not provided - res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8"); - res.status = 401; // Unauthorized - - LOG_WARNING("Unauthorized: Invalid API Key", {}); - - return false; - }; - // this is only called if no index.html is found in the public --path svr.Get("/", [](const httplib::Request &, httplib::Response &res) {