Skip to content

Commit

Permalink
refactor(torii-server): server proxy handlers (#2708)
Browse files Browse the repository at this point in the history
* refactor(torii-server): cleanup & handlers

* refactor: handlers

* better sql error handling
  • Loading branch information
Larkooo authored Nov 22, 2024
1 parent 374461d commit b55f1fc
Show file tree
Hide file tree
Showing 9 changed files with 319 additions and 154 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/torii/server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ tower.workspace = true
tracing.workspace = true
warp.workspace = true
form_urlencoded = "1.2.1"
async-trait = "0.1.83"
46 changes: 46 additions & 0 deletions crates/torii/server/src/handlers/graphql.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use std::net::{IpAddr, SocketAddr};

use http::StatusCode;
use hyper::{Body, Request, Response};
use tracing::error;

use super::Handler;

pub(crate) const LOG_TARGET: &str = "torii::server::handlers::graphql";

pub struct GraphQLHandler {
client_ip: IpAddr,
graphql_addr: Option<SocketAddr>,
}

impl GraphQLHandler {
pub fn new(client_ip: IpAddr, graphql_addr: Option<SocketAddr>) -> Self {
Self { client_ip, graphql_addr }
}
}

#[async_trait::async_trait]
impl Handler for GraphQLHandler {
fn should_handle(&self, req: &Request<Body>) -> bool {
req.uri().path().starts_with("/graphql")
}

async fn handle(&self, req: Request<Body>) -> Response<Body> {
if let Some(addr) = self.graphql_addr {
let graphql_addr = format!("http://{}", addr);
match crate::proxy::GRAPHQL_PROXY_CLIENT.call(self.client_ip, &graphql_addr, req).await
{
Ok(response) => response,
Err(_error) => {
error!(target: LOG_TARGET, "GraphQL proxy error: {:?}", _error);
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
}
}
} else {
Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap()
}
}
}
49 changes: 49 additions & 0 deletions crates/torii/server/src/handlers/grpc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use std::net::{IpAddr, SocketAddr};

use http::header::CONTENT_TYPE;
use hyper::{Body, Request, Response, StatusCode};
use tracing::error;

use super::Handler;

pub(crate) const LOG_TARGET: &str = "torii::server::handlers::grpc";

pub struct GrpcHandler {
client_ip: IpAddr,
grpc_addr: Option<SocketAddr>,
}

impl GrpcHandler {
pub fn new(client_ip: IpAddr, grpc_addr: Option<SocketAddr>) -> Self {
Self { client_ip, grpc_addr }
}
}

#[async_trait::async_trait]
impl Handler for GrpcHandler {
fn should_handle(&self, req: &Request<Body>) -> bool {
req.headers()
.get(CONTENT_TYPE)
.and_then(|ct| ct.to_str().ok())
.map(|ct| ct.starts_with("application/grpc"))
.unwrap_or(false)
}

async fn handle(&self, req: Request<Body>) -> Response<Body> {
if let Some(grpc_addr) = self.grpc_addr {
let grpc_addr = format!("http://{}", grpc_addr);
match crate::proxy::GRPC_PROXY_CLIENT.call(self.client_ip, &grpc_addr, req).await {
Ok(response) => response,
Err(_error) => {
error!(target: LOG_TARGET, "{:?}", _error);
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
}
}
} else {
Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap()
}
}
}
15 changes: 15 additions & 0 deletions crates/torii/server/src/handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pub mod graphql;
pub mod grpc;
pub mod sql;
pub mod static_files;

use hyper::{Body, Request, Response};

#[async_trait::async_trait]
pub trait Handler: Send + Sync {
// Check if this handler should handle the given request
fn should_handle(&self, req: &Request<Body>) -> bool;

// Handle the request
async fn handle(&self, req: Request<Body>) -> Response<Body>;
}
134 changes: 134 additions & 0 deletions crates/torii/server/src/handlers/sql.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use std::sync::Arc;

use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use http::header::CONTENT_TYPE;
use hyper::{Body, Method, Request, Response, StatusCode};
use sqlx::{Column, Row, SqlitePool, TypeInfo};

use super::Handler;

pub struct SqlHandler {
pool: Arc<SqlitePool>,
}

impl SqlHandler {
pub fn new(pool: Arc<SqlitePool>) -> Self {
Self { pool }
}

pub async fn execute_query(&self, query: String) -> Response<Body> {
match sqlx::query(&query).fetch_all(&*self.pool).await {
Ok(rows) => {
let result: Vec<_> = rows
.iter()
.map(|row| {
let mut obj = serde_json::Map::new();
for (i, column) in row.columns().iter().enumerate() {
let value: serde_json::Value = match column.type_info().name() {
"TEXT" => row
.get::<Option<String>, _>(i)
.map_or(serde_json::Value::Null, serde_json::Value::String),
"INTEGER" | "NULL" => row
.get::<Option<i64>, _>(i)
.map_or(serde_json::Value::Null, |n| {
serde_json::Value::Number(n.into())
}),
"REAL" => row.get::<Option<f64>, _>(i).map_or(
serde_json::Value::Null,
|f| {
serde_json::Number::from_f64(f).map_or(
serde_json::Value::Null,
serde_json::Value::Number,
)
},
),
"BLOB" => row
.get::<Option<Vec<u8>>, _>(i)
.map_or(serde_json::Value::Null, |bytes| {
serde_json::Value::String(STANDARD.encode(bytes))
}),
_ => row
.get::<Option<String>, _>(i)
.map_or(serde_json::Value::Null, serde_json::Value::String),
};
obj.insert(column.name().to_string(), value);
}
serde_json::Value::Object(obj)
})
.collect();

let json = match serde_json::to_string(&result) {
Ok(json) => json,
Err(e) => {
return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from(format!("Failed to serialize result: {:?}", e)))
.unwrap();
}
};

Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, "application/json")
.body(Body::from(json))
.unwrap()
}
Err(e) => Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from(format!("Query error: {:?}", e)))
.unwrap(),
}
}

async fn extract_query(&self, req: Request<Body>) -> Result<String, Response<Body>> {
match *req.method() {
Method::GET => {
// Get the query from the query params
let params = req.uri().query().unwrap_or_default();
form_urlencoded::parse(params.as_bytes())
.find(|(key, _)| key == "q" || key == "query")
.map(|(_, value)| value.to_string())
.ok_or(
Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Missing 'q' or 'query' parameter."))
.unwrap(),
)
}
Method::POST => {
// Get the query from request body
let body_bytes = hyper::body::to_bytes(req.into_body()).await.map_err(|_| {
Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Failed to read query from request body"))
.unwrap()
})?;
String::from_utf8(body_bytes.to_vec()).map_err(|_| {
Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid query"))
.unwrap()
})
}
_ => Err(Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(Body::from("Only GET and POST methods are allowed"))
.unwrap()),
}
}
}

#[async_trait::async_trait]
impl Handler for SqlHandler {
fn should_handle(&self, req: &Request<Body>) -> bool {
req.uri().path().starts_with("/sql")
}

async fn handle(&self, req: Request<Body>) -> Response<Body> {
match self.extract_query(req).await {
Ok(query) => self.execute_query(query).await,
Err(response) => response,
}
}
}
47 changes: 47 additions & 0 deletions crates/torii/server/src/handlers/static_files.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use std::net::{IpAddr, SocketAddr};

use hyper::{Body, Request, Response, StatusCode};
use tracing::error;

use super::Handler;

pub(crate) const LOG_TARGET: &str = "torii::server::handlers::static";

pub struct StaticHandler {
client_ip: IpAddr,
artifacts_addr: Option<SocketAddr>,
}

impl StaticHandler {
pub fn new(client_ip: IpAddr, artifacts_addr: Option<SocketAddr>) -> Self {
Self { client_ip, artifacts_addr }
}
}

#[async_trait::async_trait]
impl Handler for StaticHandler {
fn should_handle(&self, req: &Request<Body>) -> bool {
req.uri().path().starts_with("/static")
}

async fn handle(&self, req: Request<Body>) -> Response<Body> {
if let Some(artifacts_addr) = self.artifacts_addr {
let artifacts_addr = format!("http://{}", artifacts_addr);
match crate::proxy::GRAPHQL_PROXY_CLIENT
.call(self.client_ip, &artifacts_addr, req)
.await
{
Ok(response) => response,
Err(_error) => {
error!(target: LOG_TARGET, "{:?}", _error);
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
}
}
} else {
Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap()
}
}
}
1 change: 1 addition & 0 deletions crates/torii/server/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod artifacts;
pub(crate) mod handlers;
pub mod proxy;
Loading

0 comments on commit b55f1fc

Please sign in to comment.