diff --git a/CHANGELOG.md b/CHANGELOG.md index 4807d187..57379ace 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). - Replace double `AnyProvider::clone()` in `embed_fn()` with single `Arc` clone (#636) - Add `with_client()` builder to ClaudeProvider and OpenAiProvider for shared `reqwest::Client` (#637) - Cache `JsonSchema` per `TypeId` in `chat_typed` to avoid per-call schema generation (#638) +- Scrape executor performs post-DNS resolution validation against private/loopback IPs with pinned address client to prevent SSRF via DNS rebinding +- Private host detection expanded to block `*.localhost`, `*.internal`, `*.local` domains +- A2A error responses sanitized: serde details and method names no longer exposed to clients +- Rate limiter rejects new clients with 429 when entry map is at capacity after stale eviction +- Secret redaction regex-based pattern matching replaces whitespace tokenizer, detecting secrets in URLs, JSON, and quoted strings +- Added `hf_`, `npm_`, `dckr_pat_` to secret redaction prefixes +- A2A client stream errors truncate upstream body to 256 bytes ### Fixed - False positive: "sudoku" no longer matched by "sudo" blocked pattern (word-boundary matching) diff --git a/Cargo.lock b/Cargo.lock index 8abb5987..dbd7ee67 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8935,6 +8935,7 @@ dependencies = [ "notify", "notify-debouncer-mini", "proptest", + "regex", "schemars 1.2.1", "serde", "serde_json", diff --git a/crates/zeph-a2a/src/client.rs b/crates/zeph-a2a/src/client.rs index 31086a62..2e802518 100644 --- a/crates/zeph-a2a/src/client.rs +++ b/crates/zeph-a2a/src/client.rs @@ -76,7 +76,13 @@ impl A2aClient { if !resp.status().is_success() { let status = resp.status(); let body = resp.text().await.unwrap_or_default(); - return Err(A2aError::Stream(format!("HTTP {status}: {body}"))); + // Truncate body to avoid leaking large upstream error responses. 
+ let truncated = if body.len() > 256 { + { let mut end = 256; while !body.is_char_boundary(end) { end -= 1; } format!("{}…", &body[..end]) } + } else { + body + }; + return Err(A2aError::Stream(format!("HTTP {status}: {truncated}"))); } let event_stream = resp.bytes_stream().eventsource(); @@ -182,9 +188,38 @@ impl A2aClient { fn is_private_ip(ip: IpAddr) -> bool { match ip { IpAddr::V4(v4) => { - v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified() + v4.is_loopback() + || v4.is_private() + || v4.is_link_local() + || v4.is_unspecified() + || v4.is_broadcast() + } + IpAddr::V6(v6) => { + if v6.is_loopback() || v6.is_unspecified() { + return true; + } + let seg = v6.segments(); + // fe80::/10 — link-local + if seg[0] & 0xffc0 == 0xfe80 { + return true; + } + // fc00::/7 — unique local + if seg[0] & 0xfe00 == 0xfc00 { + return true; + } + // ::ffff:x.x.x.x — IPv4-mapped, check inner IPv4 + if seg[0..6] == [0, 0, 0, 0, 0, 0xffff] { + let v4 = v6 + .to_ipv4_mapped() + .unwrap_or(std::net::Ipv4Addr::UNSPECIFIED); + return v4.is_loopback() + || v4.is_private() + || v4.is_link_local() + || v4.is_unspecified() + || v4.is_broadcast(); + } + false } - IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() } } diff --git a/crates/zeph-a2a/src/server/handlers.rs b/crates/zeph-a2a/src/server/handlers.rs index 1154b50f..ecff1db4 100644 --- a/crates/zeph-a2a/src/server/handlers.rs +++ b/crates/zeph-a2a/src/server/handlers.rs @@ -75,11 +75,10 @@ pub async fn jsonrpc_handler( ), METHOD_GET_TASK => handle_get_task(state, id.clone(), raw.params).await, METHOD_CANCEL_TASK => handle_cancel_task(state, id.clone(), raw.params).await, - _ => error_response( - id.clone(), - ERR_METHOD_NOT_FOUND, - format!("unknown method: {}", raw.method), - ), + _ => { + tracing::warn!(method = %raw.method, "unknown JSON-RPC method"); + error_response(id.clone(), ERR_METHOD_NOT_FOUND, "method not found") + } }; Json(response) @@ -92,7 +91,10 @@ async fn handle_send_message( ) -> JsonRpcResponse { let params: SendMessageParams = match 
serde_json::from_value(params) { Ok(p) => p, - Err(e) => return error_response(id, ERR_INVALID_PARAMS, format!("invalid params: {e}")), + Err(e) => { + tracing::warn!("invalid params in send_message: {e}"); + return error_response(id, ERR_INVALID_PARAMS, "invalid parameters"); + } }; let task = state.task_manager.create_task(params.message.clone()).await; @@ -142,7 +144,10 @@ async fn handle_get_task( ) -> JsonRpcResponse { let params: TaskIdParams = match serde_json::from_value(params) { Ok(p) => p, - Err(e) => return error_response(id, ERR_INVALID_PARAMS, format!("invalid params: {e}")), + Err(e) => { + tracing::warn!("invalid params in get_task: {e}"); + return error_response(id, ERR_INVALID_PARAMS, "invalid parameters"); + } }; match state @@ -162,7 +167,10 @@ async fn handle_cancel_task( ) -> JsonRpcResponse { let params: TaskIdParams = match serde_json::from_value(params) { Ok(p) => p, - Err(e) => return error_response(id, ERR_INVALID_PARAMS, format!("invalid params: {e}")), + Err(e) => { + tracing::warn!("invalid params in cancel_task: {e}"); + return error_response(id, ERR_INVALID_PARAMS, "invalid parameters"); + } }; match state.task_manager.cancel_task(&params.id).await { @@ -314,3 +322,92 @@ async fn stream_task(state: AppState, message: crate::types::Message, tx: mpsc:: } } } + +#[cfg(test)] +mod tests { + use axum::body::Body; + use http_body_util::BodyExt; + use tower::ServiceExt; + + use super::super::router::build_router_with_config; + use super::super::testing::test_state; + + fn make_rpc_request(method: &str, params: serde_json::Value) -> axum::http::Request<Body> { + let body = serde_json::json!({ + "jsonrpc": "2.0", + "id": "1", + "method": method, + "params": params, + }); + axum::http::Request::builder() + .method("POST") + .uri("/a2a") + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&body).unwrap())) + .unwrap() + } + + async fn get_rpc_body(resp: axum::http::Response<Body>) -> serde_json::Value { + let bytes = 
resp.into_body().collect().await.unwrap().to_bytes(); + serde_json::from_slice(&bytes).unwrap() + } + + #[tokio::test] + async fn unknown_method_does_not_echo_method_name() { + let app = build_router_with_config(test_state(), None, 0); + let req = make_rpc_request("tasks/evil_probe", serde_json::json!({})); + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), 200); + let body = get_rpc_body(resp).await; + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert_eq!(msg, "method not found", "must not echo method name"); + assert!( + !msg.contains("evil_probe"), + "method name must not appear in error" + ); + assert!( + !msg.contains("unknown"), + "must not leak 'unknown method' phrasing" + ); + } + + #[tokio::test] + async fn invalid_params_send_message_no_serde_details() { + let app = build_router_with_config(test_state(), None, 0); + // Pass wrong type for message to trigger serde deserialization error + let req = make_rpc_request("message/send", serde_json::json!({"message": 42})); + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), 200); + let body = get_rpc_body(resp).await; + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert_eq!(msg, "invalid parameters"); + // Serde error text like "invalid type" or field names must not leak + assert!(!msg.contains("invalid type"), "serde details must not leak"); + assert!(!msg.contains("expected"), "serde details must not leak"); + } + + #[tokio::test] + async fn invalid_params_get_task_no_serde_details() { + let app = build_router_with_config(test_state(), None, 0); + // Pass wrong type for id field + let req = make_rpc_request("tasks/get", serde_json::json!({"id": [1, 2, 3]})); + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), 200); + let body = get_rpc_body(resp).await; + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert_eq!(msg, "invalid parameters"); + assert!(!msg.contains("invalid type"), "serde details 
must not leak"); + } + + #[tokio::test] + async fn invalid_params_cancel_task_no_serde_details() { + let app = build_router_with_config(test_state(), None, 0); + let req = make_rpc_request("tasks/cancel", serde_json::json!({"id": false})); + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), 200); + let body = get_rpc_body(resp).await; + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert_eq!(msg, "invalid parameters"); + assert!(!msg.contains("invalid type"), "serde details must not leak"); + } +} diff --git a/crates/zeph-a2a/src/server/router.rs b/crates/zeph-a2a/src/server/router.rs index b5826f56..8b9a6ecd 100644 --- a/crates/zeph-a2a/src/server/router.rs +++ b/crates/zeph-a2a/src/server/router.rs @@ -142,8 +142,9 @@ async fn rate_limit_middleware( before = before_eviction, after = after_eviction, limit = MAX_RATE_LIMIT_ENTRIES, - "rate limiter still at capacity after stale entry eviction" + "rate limiter at capacity after stale entry eviction, rejecting new IP" ); + return StatusCode::TOO_MANY_REQUESTS.into_response(); } } @@ -363,10 +364,13 @@ mod tests { } #[tokio::test] - async fn max_entries_cap_clears_map() { + async fn max_entries_cap_rejects_when_all_entries_fresh() { + // Fill map with fresh entries (within RATE_WINDOW) so retain() keeps them all. + // After retain() the map is still at capacity, so the middleware returns 429. 
let counters = Arc::new(Mutex::new(HashMap::new())); { let mut map = counters.lock().await; + let fresh = Instant::now(); for i in 0..MAX_RATE_LIMIT_ENTRIES { let ip = IpAddr::V4(std::net::Ipv4Addr::new( ((i >> 16) & 0xFF) as u8, @@ -374,25 +378,53 @@ (i & 0xFF) as u8, 1, )); - map.insert(ip, (1, Instant::now())); + map.insert(ip, (1, fresh)); } assert_eq!(map.len(), MAX_RATE_LIMIT_ENTRIES); } - let state = RateLimitState { - limit: 10, - counters, - }; - let new_ip = IpAddr::V4(std::net::Ipv4Addr::new(255, 255, 255, 255)); - assert!(!state.counters.lock().await.contains_key(&new_ip)); - // Simulate what the middleware does when cap is exceeded - let mut map = state.counters.lock().await; - if map.len() >= MAX_RATE_LIMIT_ENTRIES && !map.contains_key(&new_ip) { - map.clear(); + // Simulate middleware logic: cap exceeded, run retain(), still full → 429 + let now = Instant::now(); let new_ip = IpAddr::V4(std::net::Ipv4Addr::new(255, 255, 255, 255)); + let mut map = counters.lock().await; + let before = map.len(); + map.retain(|_, (_, ts)| now.duration_since(*ts) < RATE_WINDOW); + let after = map.len(); + + // All entries are fresh so retain() must not remove any + assert_eq!(after, before, "retain must preserve fresh entries"); + // Map still at capacity: a new IP would be rejected + assert!( + after >= MAX_RATE_LIMIT_ENTRIES && !map.contains_key(&new_ip), + "new IP should be rejected when map is still at capacity after eviction" + ); + } + + #[tokio::test] + async fn max_entries_cap_allows_after_stale_eviction() { + // Fill map with stale entries. After retain() the map is empty, new IP is accepted. 
+ let counters = Arc::new(Mutex::new(HashMap::new())); + { + let mut map = counters.lock().await; + let stale = Instant::now() - Duration::from_secs(120); + for i in 0..MAX_RATE_LIMIT_ENTRIES { + let ip = IpAddr::V4(std::net::Ipv4Addr::new( + ((i >> 16) & 0xFF) as u8, + ((i >> 8) & 0xFF) as u8, + (i & 0xFF) as u8, + 1, + )); + map.insert(ip, (1, stale)); + } } - assert_eq!(map.len(), 0); + + let now = Instant::now(); + let mut map = counters.lock().await; + map.retain(|_, (_, ts)| now.duration_since(*ts) < RATE_WINDOW); + + // All entries were stale; map should now be empty + assert_eq!(map.len(), 0, "stale entries must be evicted by retain"); } #[tokio::test] diff --git a/crates/zeph-core/Cargo.toml b/crates/zeph-core/Cargo.toml index a1524aba..01a0cf78 100644 --- a/crates/zeph-core/Cargo.toml +++ b/crates/zeph-core/Cargo.toml @@ -20,6 +20,7 @@ anyhow.workspace = true futures.workspace = true notify.workspace = true notify-debouncer-mini.workspace = true +regex.workspace = true serde = { workspace = true, features = ["derive"] } serde_json.workspace = true thiserror.workspace = true diff --git a/crates/zeph-core/src/redact.rs b/crates/zeph-core/src/redact.rs index 7ca495c8..b600d7b2 100644 --- a/crates/zeph-core/src/redact.rs +++ b/crates/zeph-core/src/redact.rs @@ -1,4 +1,7 @@ use std::borrow::Cow; +use std::sync::LazyLock; + +use regex::Regex; const SECRET_PREFIXES: &[&str] = &[ "sk-", @@ -11,44 +14,59 @@ const SECRET_PREFIXES: &[&str] = &[ "xoxb-", "xoxp-", "AIza", - "ya29.", + "ya29\\.", "glpat-", + "hf_", + "npm_", + "dckr_pat_", ]; +// Matches any secret prefix followed by non-whitespace characters. +// Using alternation so a single pass covers all prefixes. 
+static SECRET_REGEX: LazyLock<Regex> = LazyLock::new(|| { + let pattern = SECRET_PREFIXES.join("|"); + let full = format!("(?:{pattern})[^\\s\"'`,;{{}}\\[\\]]*"); + Regex::new(&full).expect("secret redaction regex is valid") +}); + +static PATH_REGEX: LazyLock<Regex> = LazyLock::new(|| { + Regex::new(r#"(?:/home/|/Users/|/root/|/tmp/|/var/)[^\s"'`,;{}\[\]]*"#) + .expect("path redaction regex is valid") +}); + /// Replace tokens containing known secret patterns with `[REDACTED]`. /// -/// Preserves all original whitespace (newlines, tabs, indentation). +/// Detects secrets embedded in URLs, JSON values, and quoted strings. /// Returns `Cow::Borrowed` when no secrets found (zero-allocation fast path). #[must_use] pub fn redact_secrets(text: &str) -> Cow<'_, str> { - if !SECRET_PREFIXES.iter().any(|p| text.contains(p)) { + // Fast path: check for any prefix substring before running regex. + let raw_prefixes = &[ + "sk-", + "sk_live_", + "sk_test_", + "AKIA", + "ghp_", + "gho_", + "-----BEGIN", + "xoxb-", + "xoxp-", + "AIza", + "ya29.", + "glpat-", + "hf_", + "npm_", + "dckr_pat_", + ]; + if !raw_prefixes.iter().any(|p| text.contains(p)) { return Cow::Borrowed(text); } - let bytes = text.as_bytes(); - let len = bytes.len(); - let mut result = String::with_capacity(len); - let mut i = 0; - - while i < len { - if bytes[i].is_ascii_whitespace() { - result.push(bytes[i] as char); - i += 1; - } else { - let start = i; - while i < len && !bytes[i].is_ascii_whitespace() { - i += 1; - } - let token = &text[start..i]; - if SECRET_PREFIXES.iter().any(|prefix| token.contains(prefix)) { - result.push_str("[REDACTED]"); - } else { - result.push_str(token); - } - } + let result = SECRET_REGEX.replace_all(text, "[REDACTED]"); + match result { + Cow::Borrowed(_) => Cow::Borrowed(text), + Cow::Owned(s) => Cow::Owned(s), } - - Cow::Owned(result) } /// Replace absolute filesystem paths with `[PATH]` to prevent information disclosure. 
@@ -60,30 +78,11 @@ pub fn sanitize_paths(text: &str) -> Cow<'_, str> { return Cow::Borrowed(text); } - let bytes = text.as_bytes(); - let len = bytes.len(); - let mut result = String::with_capacity(len); - let mut i = 0; - - while i < len { - if bytes[i].is_ascii_whitespace() { - result.push(bytes[i] as char); - i += 1; - } else { - let start = i; - while i < len && !bytes[i].is_ascii_whitespace() { - i += 1; - } - let token = &text[start..i]; - if PATH_PREFIXES.iter().any(|prefix| token.contains(prefix)) { - result.push_str("[PATH]"); - } else { - result.push_str(token); - } - } + let result = PATH_REGEX.replace_all(text, "[PATH]"); + match result { + Cow::Borrowed(_) => Cow::Borrowed(text), + Cow::Owned(s) => Cow::Owned(s), } - - Cow::Owned(result) } #[cfg(test)] @@ -206,11 +205,27 @@ mod tests { #[test] fn all_secret_prefixes_tested() { - for prefix in super::SECRET_PREFIXES { + for prefix in &[ + "sk-", + "sk_live_", + "sk_test_", + "AKIA", + "ghp_", + "gho_", + "-----BEGIN", + "xoxb-", + "xoxp-", + "AIza", + "ya29.", + "glpat-", + "hf_", + "npm_", + "dckr_pat_", + ] { let text = format!("token: {prefix}abc123"); let result = redact_secrets(&text); assert!(result.contains("[REDACTED]"), "Failed for prefix: {prefix}"); - assert!(!result.contains(prefix), "Prefix not redacted: {prefix}"); + assert!(!result.contains(*prefix), "Prefix not redacted: {prefix}"); } } @@ -250,6 +265,22 @@ mod tests { assert_eq!(result, "token: [REDACTED]"); } + #[test] + fn redacts_secret_in_url() { + let text = "https://api.example.com?key=sk-abc123xyz"; + let result = redact_secrets(text); + assert!(result.contains("[REDACTED]")); + assert!(!result.contains("sk-abc123xyz")); + } + + #[test] + fn redacts_secret_in_json() { + let text = r#"{"api_key":"sk-abc123def456"}"#; + let result = redact_secrets(text); + assert!(result.contains("[REDACTED]")); + assert!(!result.contains("sk-abc123def456")); + } + #[test] fn sanitize_home_path() { let text = "error at 
/home/user/project/src/main.rs:42"; @@ -271,4 +302,28 @@ mod tests { let result = sanitize_paths(text); assert!(matches!(result, Cow::Borrowed(_))); } + + #[test] + fn redacts_huggingface_token() { + let text = "HuggingFace token: hf_abcdefghijklmnopqrstuvwxyz"; + let result = redact_secrets(text); + assert!(result.contains("[REDACTED]")); + assert!(!result.contains("hf_")); + } + + #[test] + fn redacts_npm_token() { + let text = "NPM token npm_abc123XYZ"; + let result = redact_secrets(text); + assert!(result.contains("[REDACTED]")); + assert!(!result.contains("npm_abc")); + } + + #[test] + fn redacts_docker_pat() { + let text = "Docker token: dckr_pat_xxxxxxxxxxxx"; + let result = redact_secrets(text); + assert!(result.contains("[REDACTED]")); + assert!(!result.contains("dckr_pat_")); + } } diff --git a/crates/zeph-tools/src/scrape.rs b/crates/zeph-tools/src/scrape.rs index d0a3e2e6..631ff2df 100644 --- a/crates/zeph-tools/src/scrape.rs +++ b/crates/zeph-tools/src/scrape.rs @@ -1,3 +1,4 @@ +use std::net::{IpAddr, SocketAddr}; use std::time::Duration; use schemars::JsonSchema; @@ -50,24 +51,26 @@ impl ExtractMode { /// fetches the URL, and parses HTML with `scrape-core`. 
#[derive(Debug)] pub struct WebScrapeExecutor { - client: reqwest::Client, + timeout: Duration, max_body_bytes: usize, } impl WebScrapeExecutor { #[must_use] pub fn new(config: &ScrapeConfig) -> Self { - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(config.timeout)) - .redirect(reqwest::redirect::Policy::limited(3)) - .build() - .unwrap_or_default(); - Self { - client, + Self { timeout: Duration::from_secs(config.timeout), max_body_bytes: config.max_body_bytes, } } + + fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client { + let mut builder = reqwest::Client::builder() + .timeout(self.timeout) + .redirect(reqwest::redirect::Policy::limited(3)); + builder = builder.resolve_to_addrs(host, addrs); + builder.build().unwrap_or_default() + } } impl ToolExecutor for WebScrapeExecutor { @@ -136,8 +139,12 @@ impl WebScrapeExecutor { &self, instruction: &ScrapeInstruction, ) -> Result { - validate_url(&instruction.url)?; - let html = self.fetch_html(&instruction.url).await?; + let parsed = validate_url(&instruction.url)?; + let (host, addrs) = resolve_and_validate(&parsed).await?; + // Build a per-request client pinned to the validated addresses, eliminating + // TOCTOU between DNS validation and the actual HTTP connection. + let client = self.build_client(&host, &addrs); + let html = self.fetch_html(&client, &instruction.url).await?; let selector = instruction.select.clone(); let extract = ExtractMode::parse(&instruction.extract); let limit = instruction.limit.unwrap_or(10); @@ -146,9 +153,8 @@ impl WebScrapeExecutor { .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))? 
} - async fn fetch_html(&self, url: &str) -> Result<String, ToolError> { - let resp = self - .client + async fn fetch_html(&self, client: &reqwest::Client, url: &str) -> Result<String, ToolError> { + let resp = client .get(url) .send() .await @@ -183,7 +189,7 @@ fn extract_scrape_blocks(text: &str) -> Vec<&str> { crate::executor::extract_fenced_blocks(text, "scrape") } -fn validate_url(raw: &str) -> Result<(), ToolError> { +fn validate_url(raw: &str) -> Result<Url, ToolError> { let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked { command: format!("invalid URL: {raw}"), })?; @@ -205,20 +211,19 @@ }); } - Ok(()) + Ok(parsed) } -fn is_private_host(host: &url::Host<&str>) -> bool { - match host { - url::Host::Domain(d) => *d == "localhost", - url::Host::Ipv4(v4) => { +pub(crate) fn is_private_ip(ip: IpAddr) -> bool { + match ip { + IpAddr::V4(v4) => { v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified() || v4.is_broadcast() } - url::Host::Ipv6(v6) => { + IpAddr::V6(v6) => { if v6.is_loopback() || v6.is_unspecified() { return true; } @@ -247,6 +252,50 @@ } } +fn is_private_host(host: &url::Host<&str>) -> bool { + match host { + url::Host::Domain(d) => { + // Exact match or subdomain of localhost (e.g. foo.localhost) + // and .internal/.local TLDs used in cloud/k8s environments. + #[allow(clippy::case_sensitive_file_extension_comparisons)] + { + *d == "localhost" + || d.ends_with(".localhost") + || d.ends_with(".internal") + || d.ends_with(".local") + } + } + url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)), + url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)), + } +} + +/// Resolves DNS for the URL host, validates all resolved IPs against private ranges, +/// and returns the hostname and validated socket addresses. 
+/// +/// Returning the addresses allows the caller to pin the HTTP client to these exact +/// addresses, eliminating TOCTOU between DNS validation and the actual connection. +async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> { + let Some(host) = url.host_str() else { + return Ok((String::new(), vec![])); + }; + let port = url.port_or_known_default().unwrap_or(443); + let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}")) + .await + .map_err(|e| ToolError::Blocked { + command: format!("DNS resolution failed: {e}"), + })? + .collect(); + for addr in &addrs { + if is_private_ip(addr.ip()) { + return Err(ToolError::Blocked { + command: format!("SSRF protection: private IP {} for host {host}", addr.ip()), + }); + } + } + Ok((host.to_owned(), addrs)) +} + fn parse_and_extract( html: &str, selector: &str, @@ -837,4 +886,95 @@ mod tests { assert!(req.iter().any(|v| v.as_str() == Some("select"))); assert!(!req.iter().any(|v| v.as_str() == Some("extract"))); } + + // --- is_private_host: new domain checks (AUD-02) --- + + #[test] + fn subdomain_localhost_blocked() { + let host: url::Host<&str> = url::Host::Domain("foo.localhost"); + assert!(is_private_host(&host)); + } + + #[test] + fn internal_tld_blocked() { + let host: url::Host<&str> = url::Host::Domain("service.internal"); + assert!(is_private_host(&host)); + } + + #[test] + fn local_tld_blocked() { + let host: url::Host<&str> = url::Host::Domain("printer.local"); + assert!(is_private_host(&host)); + } + + #[test] + fn public_domain_not_blocked() { + let host: url::Host<&str> = url::Host::Domain("example.com"); + assert!(!is_private_host(&host)); + } + + // --- resolve_and_validate: private IP rejection --- + + #[tokio::test] + async fn resolve_loopback_rejected() { + // 127.0.0.1 resolves directly (literal IP in DNS query) + let url = url::Url::parse("https://127.0.0.1/path").unwrap(); + // validate_url catches this before resolve_and_validate, but test directly + let result = 
resolve_and_validate(&url).await; + assert!( + result.is_err(), + "loopback IP must be rejected by resolve_and_validate" + ); + let err = result.unwrap_err(); + assert!(matches!(err, crate::executor::ToolError::Blocked { .. })); + } + + #[tokio::test] + async fn resolve_private_10_rejected() { + let url = url::Url::parse("https://10.0.0.1/path").unwrap(); + let result = resolve_and_validate(&url).await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + crate::executor::ToolError::Blocked { .. } + )); + } + + #[tokio::test] + async fn resolve_private_192_rejected() { + let url = url::Url::parse("https://192.168.1.1/path").unwrap(); + let result = resolve_and_validate(&url).await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + crate::executor::ToolError::Blocked { .. } + )); + } + + #[tokio::test] + async fn resolve_ipv6_loopback_rejected() { + let url = url::Url::parse("https://[::1]/path").unwrap(); + let result = resolve_and_validate(&url).await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + crate::executor::ToolError::Blocked { .. } + )); + } + + #[tokio::test] + async fn resolve_no_host_returns_ok() { + // URL without a resolvable host — should pass through + let url = url::Url::parse("https://example.com/path").unwrap(); + // We can't do a live DNS test, but we can verify a URL with no host + let url_no_host = url::Url::parse("data:text/plain,hello").unwrap(); + // data: URLs have no host; resolve_and_validate should return Ok with empty addrs + let result = resolve_and_validate(&url_no_host).await; + assert!(result.is_ok()); + let (host, addrs) = result.unwrap(); + assert!(host.is_empty()); + assert!(addrs.is_empty()); + drop(url); + drop(url_no_host); + } }