diff --git a/Cargo.lock b/Cargo.lock index bd6dc8a5..503e0d1d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,15 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "aho-corasick" -version = "0.7.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f" -dependencies = [ - "memchr", -] - [[package]] name = "ansi_term" version = "0.12.1" @@ -170,12 +161,6 @@ version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af51b1b4a7fdff033703db39de8802c673eb91855f2e0d47dcf3bf2c0ef01f99" -[[package]] -name = "futures-sink" -version = "0.3.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0f30aaa67363d119812743aa5f33c201a7a66329f97d1a887022971feea4b53" - [[package]] name = "futures-task" version = "0.3.16" @@ -195,25 +180,6 @@ dependencies = [ "pin-utils", ] -[[package]] -name = "h2" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "825343c4eef0b63f541f8903f395dc5beb362a979b5799a84062527ef1e37726" -dependencies = [ - "bytes", - "fnv", - "futures-core", - "futures-sink", - "futures-util", - "http", - "indexmap", - "slab", - "tokio", - "tokio-util", - "tracing", -] - [[package]] name = "hashbrown" version = "0.11.2" @@ -282,7 +248,6 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", "http", "http-body", "httparse", @@ -400,7 +365,6 @@ dependencies = [ "once_cell", "parking_lot", "prometheus", - "routerify", "tokio", "tracing", "tracing-subscriber", @@ -484,12 +448,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "percent-encoding" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" - [[package]] name = "pest" version = "2.1.3" @@ -589,8 +547,6 @@ version = "1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461" dependencies = [ - "aho-corasick", - "memchr", "regex-syntax", ] @@ -609,19 +565,6 @@ version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" -[[package]] -name = "routerify" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c6bb49594c791cadb5ccfa5f36d41b498d40482595c199d10cd318800280bd9" -dependencies = [ - "http", - "hyper", - "lazy_static", - "percent-encoding", - "regex", -] - [[package]] name = "rustc_version" version = "0.3.3" @@ -696,12 +639,6 @@ dependencies = [ "libc", ] -[[package]] -name = "slab" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c307a32c1c5c437f38c7fd45d753050587732ba8628319fbdf12a7e289ccc590" - [[package]] name = "smallvec" version = "1.6.1" @@ -710,9 +647,9 @@ checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" [[package]] name = "socket2" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "765f090f0e423d2b55843402a07915add955e7d60657db13707a159727326cad" +checksum = "5dc90fe6c7be1a323296982db1836d1ea9e47b6839496dde9a541bc496df3516" dependencies = [ "libc", "winapi", @@ -813,20 +750,6 @@ dependencies = [ "syn", ] -[[package]] -name = "tokio-util" -version = "0.6.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1caa0b0c8d94a049db56b5acf8cba99dc0623aab1b26d5b5f5e2d945846b3592" -dependencies = [ - "bytes", - "futures-core", - "futures-sink", - "log", - "pin-project-lite", - "tokio", -] - [[package]] name = "tower-service" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index 95e7b4e7..727a82e7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,8 +13,7 @@ exclude = ["target", "examples", "tests", "scripts"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -hyper = "0.14" -routerify = "2" +hyper = { version = "0.14", features = ["http1", "server", "runtime"] } bytes = "1" tracing = "0.1" tracing-subscriber = "0.2" diff --git a/mosec/_version.py b/mosec/_version.py index 6ff2a242..3dc1f76b 100644 --- a/mosec/_version.py +++ b/mosec/_version.py @@ -1 +1 @@ -__version__ = "0.1.0-alpha.2" +__version__ = "0.1.0" diff --git a/src/main.rs b/src/main.rs index bb4c55c3..08d91ce2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,9 +8,9 @@ mod tasks; use std::net::SocketAddr; use clap::Clap; -use hyper::{body::to_bytes, header::HeaderValue, Body, Request, Response, Server, StatusCode}; +use hyper::service::{make_service_fn, service_fn}; +use hyper::{body::to_bytes, header::HeaderValue, Body, Method, Request, Response, StatusCode}; use prometheus::{Encoder, TextEncoder}; -use routerify::{Middleware, RouteError, Router, RouterService}; use tokio::signal::unix::{signal, SignalKind}; use tracing::info; use tracing_subscriber::EnvFilter; @@ -22,6 +22,7 @@ use crate::metrics::Metrics; use crate::tasks::{TaskCode, TaskManager}; const SERVER_INFO: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); +const NOT_FOUND: &[u8] = b"Not Found"; async fn index(_: Request) -> Result, ServiceError> { let task_manager = TaskManager::global(); @@ -70,9 +71,8 @@ async fn inference(req: Request) -> Result, ServiceError> { } } -async fn error_handler(err: RouteError) -> Response { - let mosec_err = err.downcast::().unwrap(); - let status = match mosec_err.as_ref() { +fn error_handler(err: ServiceError) -> Response { + let status = match err { ServiceError::Timeout => StatusCode::REQUEST_TIMEOUT, ServiceError::BadRequestError => StatusCode::BAD_REQUEST, ServiceError::TooManyRequests => StatusCode::TOO_MANY_REQUESTS, @@ -91,14 +91,29 @@ async fn error_handler(err: RouteError) -> Response { Response::builder() .status(status) - .body(Body::from(mosec_err.to_string())) + .header("server", HeaderValue::from_static(SERVER_INFO)) + .body(Body::from(err.to_string())) .unwrap() } -async fn post_middleware_handler(mut resp: Response) -> Result, ServiceError> { - resp.headers_mut() - .insert("Server", HeaderValue::from_static(SERVER_INFO)); - Ok(resp) +async fn service_func(req: Request) -> Result, hyper::Error> { + let res = match (req.method(), req.uri().path()) { + (&Method::GET, "/") => index(req).await, + (&Method::GET, "/metrics") => metrics(req).await, + (&Method::POST, "/inference") => inference(req).await, + _ => Ok(Response::builder() + .status(StatusCode::NOT_FOUND) + .body(NOT_FOUND.into()) + .unwrap()), + }; + match res { + Ok(mut resp) => { + resp.headers_mut() + .insert("server", HeaderValue::from_static(SERVER_INFO)); + Ok(resp) + } + Err(err) => Ok(error_handler(err)), + } } async fn shutdown_signal() { @@ -141,18 +156,9 @@ async fn main() { coordinator.run().await; }); - let router = Router::builder() - .get("/", index) - .get("/metrics", metrics) - .post("/inference", inference) - .err_handler(error_handler) - .middleware(Middleware::post(post_middleware_handler)) - .build() - .unwrap(); - - let service = RouterService::new(router).unwrap(); + let service = make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(service_func)) }); let addr: SocketAddr = format!("{}:{}", opts.address, opts.port).parse().unwrap(); - let server = Server::bind(&addr).serve(service); + let server = hyper::Server::bind(&addr).serve(service); let graceful = server.with_graceful_shutdown(shutdown_signal()); if let Err(err) = graceful.await { tracing::error!(%err, "server error");