Skip to content

Commit

Permalink
enable full security logic for websocket transports
Browse files Browse the repository at this point in the history
Signed-off-by: onur-ozkan <work@onurozkan.dev>
  • Loading branch information
onur-ozkan committed Feb 27, 2024
1 parent a83a3f3 commit d56e0ee
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 159 deletions.
4 changes: 2 additions & 2 deletions src/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub(crate) fn get_app_config() -> &'static AppConfig {
})
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub(crate) struct AppConfig {
pub(crate) port: Option<u16>,
pub(crate) redis_connection_string: String,
Expand All @@ -35,7 +35,7 @@ pub(crate) struct ProxyRoute {
pub(crate) allowed_methods: Vec<String>,
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub(crate) struct RateLimiter {
pub(crate) rp_1_min: u16,
pub(crate) rp_5_min: u16,
Expand Down
18 changes: 16 additions & 2 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,28 @@ impl Db {
connection: get_redis_connection(cfg).await,
}
}
}

impl Db {
#[allow(dead_code)]
pub(crate) async fn key_exists(&mut self, key: &str) -> GenericResult<bool> {
Ok(redis::cmd("EXISTS")
.arg(key)
.query_async(&mut self.connection)
.await?)
}

pub(crate) async fn insert_cache(
&mut self,
key: &str,
value: &str,
seconds: usize,
) -> GenericResult<()> {
redis::cmd("SETEX")
.arg(key)
.arg(seconds)
.arg(value)
.query_async(&mut self.connection)
.await?;

Ok(())
}
}
147 changes: 15 additions & 132 deletions src/net/http.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
use std::net::SocketAddr;

use address_status::{
get_address_status_list, post_address_status, AddressStatus, AddressStatusOperations,
};
use address_status::{get_address_status_list, post_address_status};
use ctx::{AppConfig, ProxyRoute};
use db::*;
use hyper::header::HeaderName;
use hyper::{
header::{self, HeaderValue},
Body, HeaderMap, Method, Request, Response, StatusCode,
};
use hyper_tls::HttpsConnector;
use jwt::{get_cached_token_or_generate_one, JwtClaims};
use proof_of_funding::{verify_message_and_balance, ProofOfFundingError};
use rate_limiter::RateLimitOperations;
use serde::{Deserialize, Serialize};
use serde_json::json;
use sign::SignedMessage;

use super::*;
use crate::server::is_private_ip;
use crate::server::{is_private_ip, validation_middleware};

async fn get_healthcheck() -> GenericResult<Response<Body>> {
let json = json!({
Expand Down Expand Up @@ -292,133 +287,21 @@ pub(crate) async fn http_handler(
.await;
}

let mut db = Db::create_instance(cfg).await;

match db
.read_address_status(payload.signed_message.address.clone())
.await
if let Err(status_code) =
validation_middleware(cfg, &payload, proxy_route, req.uri(), &remote_addr).await
{
AddressStatus::Trusted => {
proxy(
cfg,
req,
&remote_addr,
payload,
x_forwarded_for,
proxy_route,
)
.await
}
AddressStatus::Blocked => {
log::warn!(
"{}",
log_format!(
remote_addr.ip(),
payload.signed_message.address,
req_uri,
"Request blocked."
)
);
response_by_status(StatusCode::FORBIDDEN)
}
_ => {
let signed_message_status =
verify_message_and_balance(cfg, &payload, proxy_route).await;

if let Err(ProofOfFundingError::InvalidSignedMessage) = signed_message_status {
log::warn!(
"{}",
log_format!(
remote_addr.ip(),
payload.signed_message.address,
req_uri,
"Request has invalid signed message, returning 401."
)
);

return response_by_status(StatusCode::UNAUTHORIZED);
};

let rate_limiter_key = format!(
"{}:{}",
payload.signed_message.coin_ticker, payload.signed_message.address
);

match db
.rate_exceeded(rate_limiter_key.clone(), &cfg.rate_limiter)
.await
{
Ok(false) => {}
_ => {
log::warn!(
"{}",
log_format!(
remote_addr.ip(),
payload.signed_message.address,
req_uri,
"Rate exceed for {}, checking balance for {} address.",
rate_limiter_key,
payload.signed_message.address
)
);

match verify_message_and_balance(cfg, &payload, proxy_route).await {
Ok(_) => {}
Err(ProofOfFundingError::InsufficientBalance) => {
log::warn!(
"{}",
log_format!(
remote_addr.ip(),
payload.signed_message.address,
req_uri,
"Wallet {} has insufficient balance for coin {}, returning 406.",
payload.signed_message.coin_ticker,
payload.signed_message.address
)
);
return response_by_status(StatusCode::NOT_ACCEPTABLE);
}
e => {
log::error!(
"{}",
log_format!(
remote_addr.ip(),
payload.signed_message.address,
req_uri,
"verify_message_and_balance failed in coin {}: {:?}",
payload.signed_message.coin_ticker,
e
)
);
return response_by_status(StatusCode::INTERNAL_SERVER_ERROR);
}
}
}
}

if db.rate_address(rate_limiter_key).await.is_err() {
log::error!(
"{}",
log_format!(
remote_addr.ip(),
payload.signed_message.address,
req_uri,
"Rate incrementing failed."
)
);
};

proxy(
cfg,
req,
&remote_addr,
payload,
x_forwarded_for,
proxy_route,
)
.await
}
return response_by_status(status_code);
}

proxy(
cfg,
req,
&remote_addr,
payload,
x_forwarded_for,
proxy_route,
)
.await
}

#[test]
Expand Down
115 changes: 112 additions & 3 deletions src/net/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@ use std::str::FromStr;

use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server, StatusCode};
use hyper::{Body, Request, Response, Server, StatusCode, Uri};

use crate::http::{http_handler, response_by_status};
use crate::address_status::AddressStatusOperations;
use crate::ctx::ProxyRoute;
use crate::db::Db;
use crate::http::{http_handler, response_by_status, RpcPayload};
use crate::log_format;
use crate::proof_of_funding::{verify_message_and_balance, ProofOfFundingError};
use crate::rate_limiter::RateLimitOperations;
use crate::websocket::{should_upgrade_to_socket_conn, socket_handler};
use crate::{ctx::AppConfig, GenericError, GenericResult};

Expand Down Expand Up @@ -56,12 +61,116 @@ async fn connection_handler(
};

if should_upgrade_to_socket_conn(&req) {
socket_handler(cfg, req, remote_addr).await
socket_handler(cfg.clone(), req, remote_addr).await
} else {
http_handler(cfg, req, remote_addr).await
}
}

pub(crate) async fn validation_middleware(
cfg: &AppConfig,
payload: &RpcPayload,
proxy_route: &ProxyRoute,
req_uri: &Uri,
remote_addr: &SocketAddr,
) -> Result<(), StatusCode> {
let mut db = Db::create_instance(cfg).await;

match db
.read_address_status(&payload.signed_message.address)
.await
{
crate::address_status::AddressStatus::Trusted => Ok(()),
crate::address_status::AddressStatus::Blocked => Err(StatusCode::FORBIDDEN),
crate::address_status::AddressStatus::None => {
let signed_message_status = verify_message_and_balance(cfg, payload, proxy_route).await;

if let Err(ProofOfFundingError::InvalidSignedMessage) = signed_message_status {
log::warn!(
"{}",
log_format!(
remote_addr.ip(),
payload.signed_message.address,
req_uri,
"Request has invalid signed message, returning 401"
)
);

return Err(StatusCode::UNAUTHORIZED);
};

let rate_limiter_key = format!(
"{}:{}",
payload.signed_message.coin_ticker, payload.signed_message.address
);

match db.rate_exceeded(&rate_limiter_key, &cfg.rate_limiter).await {
Ok(false) => {}
_ => {
log::warn!(
"{}",
log_format!(
remote_addr.ip(),
payload.signed_message.address,
req_uri,
"Rate exceed for {}, checking balance for {} address.",
rate_limiter_key,
payload.signed_message.address
)
);

match verify_message_and_balance(cfg, payload, proxy_route).await {
Ok(_) => {}
Err(ProofOfFundingError::InsufficientBalance) => {
log::warn!(
"{}",
log_format!(
remote_addr.ip(),
payload.signed_message.address,
req_uri,
"Wallet {} has insufficient balance for coin {}, returning 406.",
payload.signed_message.coin_ticker,
payload.signed_message.address
)
);

return Err(StatusCode::NOT_ACCEPTABLE);
}
e => {
log::error!(
"{}",
log_format!(
remote_addr.ip(),
payload.signed_message.address,
req_uri,
"verify_message_and_balance failed in coin {}: {:?}",
payload.signed_message.coin_ticker,
e
)
);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
}
}
};

if db.rate_address(rate_limiter_key).await.is_err() {
log::error!(
"{}",
log_format!(
remote_addr.ip(),
payload.signed_message.address,
req_uri,
"Rate incrementing failed."
)
);
};

Ok(())
}
}
}

pub(crate) async fn serve(cfg: &'static AppConfig) -> GenericResult<()> {
let addr = format!("0.0.0.0:{}", cfg.port.unwrap_or(5000)).parse()?;

Expand Down
Loading

0 comments on commit d56e0ee

Please sign in to comment.