Skip to content

Commit

Permalink
remove routerify (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
kemingy authored Sep 28, 2021
1 parent d590527 commit ce8f23a
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 103 deletions.
81 changes: 2 additions & 79 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ exclude = ["target", "examples", "tests", "scripts"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
hyper = "0.14"
routerify = "2"
hyper = { version = "0.14", features = ["http1", "server", "runtime"] }
bytes = "1"
tracing = "0.1"
tracing-subscriber = "0.2"
Expand Down
2 changes: 1 addition & 1 deletion mosec/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.0-alpha.2"
__version__ = "0.1.0"
48 changes: 27 additions & 21 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ mod tasks;
use std::net::SocketAddr;

use clap::Clap;
use hyper::{body::to_bytes, header::HeaderValue, Body, Request, Response, Server, StatusCode};
use hyper::service::{make_service_fn, service_fn};
use hyper::{body::to_bytes, header::HeaderValue, Body, Method, Request, Response, StatusCode};
use prometheus::{Encoder, TextEncoder};
use routerify::{Middleware, RouteError, Router, RouterService};
use tokio::signal::unix::{signal, SignalKind};
use tracing::info;
use tracing_subscriber::EnvFilter;
Expand All @@ -22,6 +22,7 @@ use crate::metrics::Metrics;
use crate::tasks::{TaskCode, TaskManager};

const SERVER_INFO: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
const NOT_FOUND: &[u8] = b"Not Found";

async fn index(_: Request<Body>) -> Result<Response<Body>, ServiceError> {
let task_manager = TaskManager::global();
Expand Down Expand Up @@ -70,9 +71,8 @@ async fn inference(req: Request<Body>) -> Result<Response<Body>, ServiceError> {
}
}

async fn error_handler(err: RouteError) -> Response<Body> {
let mosec_err = err.downcast::<ServiceError>().unwrap();
let status = match mosec_err.as_ref() {
fn error_handler(err: ServiceError) -> Response<Body> {
let status = match err {
ServiceError::Timeout => StatusCode::REQUEST_TIMEOUT,
ServiceError::BadRequestError => StatusCode::BAD_REQUEST,
ServiceError::TooManyRequests => StatusCode::TOO_MANY_REQUESTS,
Expand All @@ -91,14 +91,29 @@ async fn error_handler(err: RouteError) -> Response<Body> {

Response::builder()
.status(status)
.body(Body::from(mosec_err.to_string()))
.header("server", HeaderValue::from_static(SERVER_INFO))
.body(Body::from(err.to_string()))
.unwrap()
}

async fn post_middleware_handler(mut resp: Response<Body>) -> Result<Response<Body>, ServiceError> {
resp.headers_mut()
.insert("Server", HeaderValue::from_static(SERVER_INFO));
Ok(resp)
async fn service_func(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
let res = match (req.method(), req.uri().path()) {
(&Method::GET, "/") => index(req).await,
(&Method::GET, "/metrics") => metrics(req).await,
(&Method::POST, "/inference") => inference(req).await,
_ => Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(NOT_FOUND.into())
.unwrap()),
};
match res {
Ok(mut resp) => {
resp.headers_mut()
.insert("server", HeaderValue::from_static(SERVER_INFO));
Ok(resp)
}
Err(err) => Ok(error_handler(err)),
}
}

async fn shutdown_signal() {
Expand Down Expand Up @@ -141,18 +156,9 @@ async fn main() {
coordinator.run().await;
});

let router = Router::builder()
.get("/", index)
.get("/metrics", metrics)
.post("/inference", inference)
.err_handler(error_handler)
.middleware(Middleware::post(post_middleware_handler))
.build()
.unwrap();

let service = RouterService::new(router).unwrap();
let service = make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(service_func)) });
let addr: SocketAddr = format!("{}:{}", opts.address, opts.port).parse().unwrap();
let server = Server::bind(&addr).serve(service);
let server = hyper::Server::bind(&addr).serve(service);
let graceful = server.with_graceful_shutdown(shutdown_signal());
if let Err(err) = graceful.await {
tracing::error!(%err, "server error");
Expand Down

0 comments on commit ce8f23a

Please sign in to comment.