Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 88 additions & 3 deletions crates/zeph-a2a/src/server/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,29 @@ struct AuthConfig {
token: Option<String>,
}

const MAX_RATE_LIMIT_ENTRIES: usize = 10_000;
const EVICTION_INTERVAL: Duration = Duration::from_secs(60);
const RATE_WINDOW: Duration = Duration::from_secs(60);

#[derive(Clone)]
struct RateLimitState {
limit: u32,
counters: Arc<Mutex<HashMap<IpAddr, (u32, Instant)>>>,
}

fn spawn_eviction_task(counters: Arc<Mutex<HashMap<IpAddr, (u32, Instant)>>>) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(EVICTION_INTERVAL);
interval.tick().await; // skip immediate first tick
loop {
interval.tick().await;
let now = Instant::now();
let mut map = counters.lock().await;
map.retain(|_, (_, ts)| now.duration_since(*ts) < RATE_WINDOW);
}
});
}

#[cfg(test)]
pub fn build_router_with_config(
state: AppState,
Expand All @@ -47,9 +64,13 @@ pub fn build_router_with_full_config(
max_body_size: usize,
) -> Router {
let auth_cfg = AuthConfig { token: auth_token };
let counters = Arc::new(Mutex::new(HashMap::new()));
if rate_limit > 0 {
spawn_eviction_task(Arc::clone(&counters));
}
let rate_state = RateLimitState {
limit: rate_limit,
counters: Arc::new(Mutex::new(HashMap::new())),
counters,
};

let protected = Router::new()
Expand Down Expand Up @@ -108,12 +129,16 @@ async fn rate_limit_middleware(
.map_or(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), |ci| ci.0.ip());

let now = Instant::now();
let window = Duration::from_secs(60);

let mut counters = state.counters.lock().await;

if counters.len() >= MAX_RATE_LIMIT_ENTRIES && !counters.contains_key(&ip) {
counters.clear();
}

let entry = counters.entry(ip).or_insert((0, now));

if now.duration_since(entry.1) >= window {
if now.duration_since(entry.1) >= RATE_WINDOW {
*entry = (1, now);
} else {
entry.0 += 1;
Expand Down Expand Up @@ -325,4 +350,64 @@ mod tests {
let resp = app.call(make_req()).await.unwrap();
assert_eq!(resp.status(), 429, "request 3 should be rate-limited");
}

#[tokio::test]
async fn max_entries_cap_clears_map() {
let counters = Arc::new(Mutex::new(HashMap::new()));
{
let mut map = counters.lock().await;
for i in 0..MAX_RATE_LIMIT_ENTRIES {
let ip = IpAddr::V4(std::net::Ipv4Addr::new(
((i >> 16) & 0xFF) as u8,
((i >> 8) & 0xFF) as u8,
(i & 0xFF) as u8,
1,
));
map.insert(ip, (1, Instant::now()));
}
assert_eq!(map.len(), MAX_RATE_LIMIT_ENTRIES);
}

let state = RateLimitState {
limit: 10,
counters,
};

let new_ip = IpAddr::V4(std::net::Ipv4Addr::new(255, 255, 255, 255));
assert!(!state.counters.lock().await.contains_key(&new_ip));

// Simulate what the middleware does when cap is exceeded
let mut map = state.counters.lock().await;
if map.len() >= MAX_RATE_LIMIT_ENTRIES && !map.contains_key(&new_ip) {
map.clear();
}
assert_eq!(map.len(), 0);
}

#[tokio::test]
async fn eviction_removes_stale_entries() {
let counters = Arc::new(Mutex::new(HashMap::new()));
let stale_time = Instant::now() - Duration::from_secs(120);
let fresh_time = Instant::now();

let stale_ip = IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1));
let fresh_ip = IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 2));

{
let mut map = counters.lock().await;
map.insert(stale_ip, (5, stale_time));
map.insert(fresh_ip, (3, fresh_time));
}

// Simulate eviction logic
let now = Instant::now();
let mut map = counters.lock().await;
map.retain(|_, (_, ts)| now.duration_since(*ts) < RATE_WINDOW);

assert!(
!map.contains_key(&stale_ip),
"stale entry should be evicted"
);
assert!(map.contains_key(&fresh_ip), "fresh entry should remain");
}
}
Loading