diff --git a/CHANGELOG.md b/CHANGELOG.md index c211b14e..3dc4d57c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ## [Unreleased] +### Added +- Structured LLM output via `chat_typed()` on `LlmProvider` trait with JSON schema enforcement (#456) +- OpenAI/Compatible native `response_format: json_schema` structured output (#457) +- Claude structured output via forced tool use pattern (#458) +- `Extractor` utility for typed data extraction from LLM responses (#459) + ## [0.10.0] - 2026-02-18 ### Fixed diff --git a/Cargo.lock b/Cargo.lock index 778c1fae..0eb224dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8454,6 +8454,7 @@ dependencies = [ "hf-hub", "ollama-rs", "reqwest 0.13.2", + "schemars 1.2.1", "serde", "serde_json", "thiserror 2.0.18", diff --git a/crates/zeph-llm/Cargo.toml b/crates/zeph-llm/Cargo.toml index af301bcb..16c7ce3f 100644 --- a/crates/zeph-llm/Cargo.toml +++ b/crates/zeph-llm/Cargo.toml @@ -24,6 +24,7 @@ hf-hub = { workspace = true, optional = true } ollama-rs.workspace = true reqwest = { workspace = true, features = ["json", "rustls", "stream"] } serde = { workspace = true, features = ["derive"] } +schemars.workspace = true serde_json.workspace = true tokenizers = { workspace = true, optional = true } tokio = { workspace = true, features = ["rt", "sync", "time"] } diff --git a/crates/zeph-llm/src/any.rs b/crates/zeph-llm/src/any.rs index e91651df..8a1da5cd 100644 --- a/crates/zeph-llm/src/any.rs +++ b/crates/zeph-llm/src/any.rs @@ -7,6 +7,9 @@ use crate::mock::MockProvider; use crate::ollama::OllamaProvider; use crate::openai::OpenAiProvider; use crate::orchestrator::ModelOrchestrator; +use schemars::JsonSchema; +use serde::de::DeserializeOwned; + use crate::provider::{ChatResponse, ChatStream, LlmProvider, Message, StatusTx, ToolDefinition}; use crate::router::RouterProvider; @@ -54,6 +57,16 @@ impl AnyProvider { } } + /// # Errors + /// + /// Returns an 
error if the provider fails or the response cannot be parsed. + pub async fn chat_typed_erased<T>(&self, messages: &[Message]) -> Result<T, LlmError> + where + T: DeserializeOwned + JsonSchema, + { + delegate_provider!(self, |p| p.chat_typed::<T>(messages).await) + } + /// Propagate a status sender to the inner provider (where supported). pub fn set_status_tx(&mut self, tx: StatusTx) { match self { @@ -110,6 +123,10 @@ impl LlmProvider for AnyProvider { delegate_provider!(self, |p| p.name()) } + fn supports_structured_output(&self) -> bool { + delegate_provider!(self, |p| p.supports_structured_output()) + } + fn supports_tool_use(&self) -> bool { delegate_provider!(self, |p| p.supports_tool_use()) } @@ -416,4 +433,48 @@ mod tests { let debug = format!("{provider:?}"); assert!(debug.contains("OpenAi")); } + + #[cfg(feature = "mock")] + #[tokio::test] + async fn chat_typed_erased_dispatches_to_mock() { + #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)] + struct TestOutput { + value: String, + } + + let mock = + crate::mock::MockProvider::with_responses(vec![r#"{"value": "from_mock"}"#.into()]); + let provider = AnyProvider::Mock(mock); + let messages = vec![Message::from_legacy(Role::User, "test")]; + let result: TestOutput = provider.chat_typed_erased(&messages).await.unwrap(); + assert_eq!( + result, + TestOutput { + value: "from_mock".into() + } + ); + } + + #[test] + fn any_openai_supports_structured_output() { + let provider = AnyProvider::OpenAi(crate::openai::OpenAiProvider::new( + "key".into(), + "https://api.openai.com/v1".into(), + "gpt-4o".into(), + 1024, + None, + None, + )); + assert!(provider.supports_structured_output()); + } + + #[test] + fn any_ollama_does_not_support_structured_output() { + let provider = AnyProvider::Ollama(OllamaProvider::new( + "http://localhost:11434", + "test".into(), + "embed".into(), + )); + assert!(!provider.supports_structured_output()); + } } diff --git a/crates/zeph-llm/src/claude.rs
index 1f848464..2592b8c6 100644 --- a/crates/zeph-llm/src/claude.rs +++ b/crates/zeph-llm/src/claude.rs @@ -254,6 +254,84 @@ impl LlmProvider for ClaudeProvider { "claude" } + fn supports_structured_output(&self) -> bool { + true + } + + async fn chat_typed(&self, messages: &[Message]) -> Result + where + T: serde::de::DeserializeOwned + schemars::JsonSchema, + Self: Sized, + { + let schema = schemars::schema_for!(T); + let schema_value = + serde_json::to_value(&schema).map_err(|e| LlmError::StructuredParse(e.to_string()))?; + let type_name = std::any::type_name::() + .rsplit("::") + .next() + .unwrap_or("Output"); + + let tool_name = format!("submit_{type_name}"); + let tool = ToolDefinition { + name: tool_name.clone(), + description: format!("Submit the structured {type_name} result"), + parameters: schema_value, + }; + + let (system, chat_messages) = split_messages_structured(messages); + let api_tool = AnthropicTool { + name: &tool.name, + description: &tool.description, + input_schema: &tool.parameters, + }; + + let system_blocks = system.map(|s| split_system_into_blocks(&s)); + let body = TypedToolRequestBody { + model: &self.model, + max_tokens: self.max_tokens, + system: system_blocks, + messages: &chat_messages, + tools: &[api_tool], + tool_choice: ToolChoice { + r#type: "tool", + name: &tool_name, + }, + }; + + let response = self + .client + .post(API_URL) + .header("x-api-key", &self.api_key) + .header("anthropic-version", ANTHROPIC_VERSION) + .header("anthropic-beta", ANTHROPIC_BETA) + .header("content-type", "application/json") + .json(&body) + .send() + .await?; + + let status = response.status(); + let text = response.text().await.map_err(LlmError::Http)?; + + if !status.is_success() { + return Err(LlmError::Other(format!( + "Claude API request failed (status {status})" + ))); + } + + let resp: ToolApiResponse = serde_json::from_str(&text)?; + + for block in resp.content { + if let AnthropicContentBlock::ToolUse { input, .. 
} = block { + return serde_json::from_value::(input) + .map_err(|e| LlmError::StructuredParse(e.to_string())); + } + } + + Err(LlmError::StructuredParse( + "no tool_use block in response".into(), + )) + } + fn supports_tool_use(&self) -> bool { true } @@ -506,6 +584,23 @@ fn split_system_into_blocks(system: &str) -> Vec { blocks } +#[derive(Serialize)] +struct TypedToolRequestBody<'a> { + model: &'a str, + max_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + system: Option>, + messages: &'a [StructuredApiMessage], + tools: &'a [AnthropicTool<'a>], + tool_choice: ToolChoice<'a>, +} + +#[derive(Serialize)] +struct ToolChoice<'a> { + r#type: &'a str, + name: &'a str, +} + #[derive(Serialize)] struct AnthropicTool<'a> { name: &'a str, diff --git a/crates/zeph-llm/src/compatible.rs b/crates/zeph-llm/src/compatible.rs index b6b4e390..35639c1d 100644 --- a/crates/zeph-llm/src/compatible.rs +++ b/crates/zeph-llm/src/compatible.rs @@ -79,6 +79,18 @@ impl LlmProvider for CompatibleProvider { self.leaked_name } + fn supports_structured_output(&self) -> bool { + self.inner.supports_structured_output() + } + + async fn chat_typed(&self, messages: &[Message]) -> Result + where + T: serde::de::DeserializeOwned + schemars::JsonSchema, + Self: Sized, + { + self.inner.chat_typed(messages).await + } + fn supports_tool_use(&self) -> bool { self.inner.supports_tool_use() } diff --git a/crates/zeph-llm/src/error.rs b/crates/zeph-llm/src/error.rs index b358b915..50afcbf3 100644 --- a/crates/zeph-llm/src/error.rs +++ b/crates/zeph-llm/src/error.rs @@ -37,6 +37,9 @@ pub enum LlmError { #[error("candle error: {0}")] Candle(#[from] candle_core::Error), + #[error("structured output parse failed: {0}")] + StructuredParse(String), + #[error("{0}")] Other(String), } diff --git a/crates/zeph-llm/src/extractor.rs b/crates/zeph-llm/src/extractor.rs new file mode 100644 index 00000000..3298910a --- /dev/null +++ b/crates/zeph-llm/src/extractor.rs @@ -0,0 +1,148 @@ +use 
schemars::JsonSchema; +use serde::de::DeserializeOwned; + +use crate::LlmError; +use crate::provider::{LlmProvider, Message, Role}; + +pub struct Extractor<'a, P: LlmProvider> { + provider: &'a P, + preamble: Option<String>, +} + +impl<'a, P: LlmProvider> Extractor<'a, P> { + pub fn new(provider: &'a P) -> Self { + Self { + provider, + preamble: None, + } + } + + #[must_use] + pub fn with_preamble(mut self, preamble: impl Into<String>) -> Self { + self.preamble = Some(preamble.into()); + self + } + + /// # Errors + /// + /// Returns an error if the provider fails or the response cannot be parsed. + pub async fn extract<T>(&self, input: &str) -> Result<T, LlmError> + where + T: DeserializeOwned + JsonSchema, + { + let mut messages = Vec::new(); + if let Some(ref preamble) = self.preamble { + messages.push(Message::from_legacy(Role::System, preamble.clone())); + } + messages.push(Message::from_legacy(Role::User, input)); + self.provider.chat_typed::<T>(&messages).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::provider::{ChatStream, LlmProvider, Message}; + + struct StubProvider { + response: String, + } + + impl LlmProvider for StubProvider { + async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> { + Ok(self.response.clone()) + } + + async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> { + let response = self.chat(messages).await?; + Ok(Box::pin(tokio_stream::once(Ok(response)))) + } + + fn supports_streaming(&self) -> bool { + false + } + + async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> { + Err(LlmError::EmbedUnsupported { provider: "stub" }) + } + + fn supports_embeddings(&self) -> bool { + false + } + + fn name(&self) -> &'static str { + "stub" + } + } + + #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)] + struct TestOutput { + value: String, + } + + #[tokio::test] + async fn extract_without_preamble() { + let provider = StubProvider { + response: r#"{"value": "result"}"#.into(), + }; + let extractor = Extractor::new(&provider); + let
result: TestOutput = extractor.extract("test input").await.unwrap(); + assert_eq!( + result, + TestOutput { + value: "result".into() + } + ); + } + + #[tokio::test] + async fn extract_with_preamble() { + let provider = StubProvider { + response: r#"{"value": "with_preamble"}"#.into(), + }; + let extractor = Extractor::new(&provider).with_preamble("Analyze this"); + let result: TestOutput = extractor.extract("test input").await.unwrap(); + assert_eq!( + result, + TestOutput { + value: "with_preamble".into() + } + ); + } + + #[tokio::test] + async fn extract_error_propagation() { + struct FailProvider; + + impl LlmProvider for FailProvider { + async fn chat(&self, _messages: &[Message]) -> Result { + Err(LlmError::Unavailable) + } + + async fn chat_stream(&self, _messages: &[Message]) -> Result { + Err(LlmError::Unavailable) + } + + fn supports_streaming(&self) -> bool { + false + } + + async fn embed(&self, _text: &str) -> Result, LlmError> { + Err(LlmError::Unavailable) + } + + fn supports_embeddings(&self) -> bool { + false + } + + fn name(&self) -> &'static str { + "fail" + } + } + + let provider = FailProvider; + let extractor = Extractor::new(&provider); + let result = extractor.extract::("test").await; + assert!(matches!(result, Err(LlmError::Unavailable))); + } +} diff --git a/crates/zeph-llm/src/lib.rs b/crates/zeph-llm/src/lib.rs index 8099d944..471734e1 100644 --- a/crates/zeph-llm/src/lib.rs +++ b/crates/zeph-llm/src/lib.rs @@ -6,6 +6,7 @@ pub mod candle_provider; pub mod claude; pub mod compatible; pub mod error; +pub mod extractor; #[cfg(feature = "mock")] pub mod mock; pub mod ollama; @@ -15,4 +16,5 @@ pub mod provider; pub mod router; pub use error::LlmError; +pub use extractor::Extractor; pub use provider::LlmProvider; diff --git a/crates/zeph-llm/src/openai.rs b/crates/zeph-llm/src/openai.rs index 88e8258d..263b777b 100644 --- a/crates/zeph-llm/src/openai.rs +++ b/crates/zeph-llm/src/openai.rs @@ -305,6 +305,66 @@ impl LlmProvider for OpenAiProvider 
{ self.last_cache.lock().ok().and_then(|g| *g) } + fn supports_structured_output(&self) -> bool { + true + } + + async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError> + where + T: serde::de::DeserializeOwned + schemars::JsonSchema, + Self: Sized, + { + let schema = schemars::schema_for!(T); + let schema_value = + serde_json::to_value(&schema).map_err(|e| LlmError::StructuredParse(e.to_string()))?; + let type_name = std::any::type_name::<T>() + .rsplit("::") + .next() + .unwrap_or("Output"); + + let api_messages = convert_messages(messages); + let body = TypedChatRequest { + model: &self.model, + messages: &api_messages, + max_tokens: self.max_tokens, + response_format: ResponseFormat { + r#type: "json_schema", + json_schema: JsonSchemaFormat { + name: type_name, + schema: schema_value, + strict: true, + }, + }, + }; + + let response = self + .client + .post(format!("{}/chat/completions", self.base_url)) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await?; + + let status = response.status(); + let text = response.text().await.map_err(LlmError::Http)?; + + if !status.is_success() { + return Err(LlmError::Other(format!( + "OpenAI API request failed (status {status})" + ))); + } + + let resp: OpenAiChatResponse = serde_json::from_str(&text)?; + let content = resp + .choices + .first() + .map(|c| c.message.content.as_str()) + .ok_or(LlmError::EmptyResponse { provider: "openai" })?; + + serde_json::from_str::<T>(content).map_err(|e| LlmError::StructuredParse(e.to_string())) + } + fn supports_tool_use(&self) -> bool { + true + } @@ -708,6 +768,27 @@ struct EmbeddingData { embedding: Vec<f32>, } +#[derive(Serialize)] +struct TypedChatRequest<'a> { + model: &'a str, + messages: &'a [ApiMessage<'a>], + max_tokens: u32, + response_format: ResponseFormat<'a>, +} + +#[derive(Serialize)] +struct ResponseFormat<'a> { + r#type: &'a str, + json_schema: JsonSchemaFormat<'a>, +} + +#[derive(Serialize)]
+struct JsonSchemaFormat<'a> { + name: &'a str, + schema: serde_json::Value, + strict: bool, +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/zeph-llm/src/provider.rs b/crates/zeph-llm/src/provider.rs index aef62b3c..c76b7d80 100644 --- a/crates/zeph-llm/src/provider.rs +++ b/crates/zeph-llm/src/provider.rs @@ -239,6 +239,70 @@ pub trait LlmProvider: Send + Sync { fn last_cache_usage(&self) -> Option<(u64, u64)> { None } + + /// Whether this provider supports native structured output. + fn supports_structured_output(&self) -> bool { + false + } + + /// Send messages and parse the response into a typed value `T`. + /// + /// Default implementation injects JSON schema into the system prompt and retries once + /// on parse failure. Providers with native structured output should override this. + #[allow(async_fn_in_trait)] + async fn chat_typed(&self, messages: &[Message]) -> Result + where + T: serde::de::DeserializeOwned + schemars::JsonSchema, + Self: Sized, + { + let schema = schemars::schema_for!(T); + let schema_json = serde_json::to_string_pretty(&schema) + .map_err(|e| LlmError::StructuredParse(e.to_string()))?; + let type_name = std::any::type_name::() + .rsplit("::") + .next() + .unwrap_or("Output"); + + let mut augmented = messages.to_vec(); + let instruction = format!( + "Respond with a valid JSON object matching this schema. \ + Output ONLY the JSON, no markdown fences or extra text.\n\n\ + Type: {type_name}\nSchema:\n```json\n{schema_json}\n```" + ); + augmented.insert(0, Message::from_legacy(Role::System, instruction)); + + let raw = self.chat(&augmented).await?; + let cleaned = strip_json_fences(&raw); + match serde_json::from_str::(cleaned) { + Ok(val) => Ok(val), + Err(first_err) => { + augmented.push(Message::from_legacy(Role::Assistant, &raw)); + augmented.push(Message::from_legacy( + Role::User, + format!( + "Your response was not valid JSON. Error: {first_err}. \ + Please output ONLY valid JSON matching the schema." 
+ ), + )); + let retry_raw = self.chat(&augmented).await?; + let retry_cleaned = strip_json_fences(&retry_raw); + serde_json::from_str::(retry_cleaned).map_err(|e| { + LlmError::StructuredParse(format!("parse failed after retry: {e}")) + }) + } + } + } +} + +/// Strip markdown code fences from LLM output. Only handles outer fences; +/// JSON containing trailing triple backticks in string values may be +/// incorrectly trimmed (acceptable for MVP — see review R2). +fn strip_json_fences(s: &str) -> &str { + s.trim() + .trim_start_matches("```json") + .trim_start_matches("```") + .trim_end_matches("```") + .trim() } #[cfg(test)] @@ -720,4 +784,171 @@ mod tests { panic!("expected ToolOutput"); } } + + // --- M27: strip_json_fences tests --- + + #[test] + fn strip_json_fences_plain_json() { + assert_eq!(strip_json_fences(r#"{"a": 1}"#), r#"{"a": 1}"#); + } + + #[test] + fn strip_json_fences_with_json_fence() { + assert_eq!(strip_json_fences("```json\n{\"a\": 1}\n```"), r#"{"a": 1}"#); + } + + #[test] + fn strip_json_fences_with_plain_fence() { + assert_eq!(strip_json_fences("```\n{\"a\": 1}\n```"), r#"{"a": 1}"#); + } + + #[test] + fn strip_json_fences_whitespace() { + assert_eq!(strip_json_fences(" \n "), ""); + } + + #[test] + fn strip_json_fences_empty() { + assert_eq!(strip_json_fences(""), ""); + } + + #[test] + fn strip_json_fences_outer_whitespace() { + assert_eq!( + strip_json_fences(" ```json\n{\"a\": 1}\n``` "), + r#"{"a": 1}"# + ); + } + + #[test] + fn strip_json_fences_only_opening_fence() { + assert_eq!(strip_json_fences("```json\n{\"a\": 1}"), r#"{"a": 1}"#); + } + + // --- M27: chat_typed tests --- + + #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)] + struct TestOutput { + value: String, + } + + struct SequentialStub { + responses: std::sync::Mutex>>, + } + + impl SequentialStub { + fn new(responses: Vec>) -> Self { + Self { + responses: std::sync::Mutex::new(responses), + } + } + } + + impl LlmProvider for SequentialStub { + async 
fn chat(&self, _messages: &[Message]) -> Result { + let mut responses = self.responses.lock().unwrap(); + if responses.is_empty() { + return Err(LlmError::Other("no more responses".into())); + } + responses.remove(0) + } + + async fn chat_stream(&self, messages: &[Message]) -> Result { + let response = self.chat(messages).await?; + Ok(Box::pin(tokio_stream::once(Ok(response)))) + } + + fn supports_streaming(&self) -> bool { + false + } + + async fn embed(&self, _text: &str) -> Result, LlmError> { + Err(LlmError::EmbedUnsupported { + provider: "sequential-stub", + }) + } + + fn supports_embeddings(&self) -> bool { + false + } + + fn name(&self) -> &'static str { + "sequential-stub" + } + } + + #[tokio::test] + async fn chat_typed_happy_path() { + let provider = StubProvider { + response: r#"{"value": "hello"}"#.into(), + }; + let messages = vec![Message::from_legacy(Role::User, "test")]; + let result: TestOutput = provider.chat_typed(&messages).await.unwrap(); + assert_eq!( + result, + TestOutput { + value: "hello".into() + } + ); + } + + #[tokio::test] + async fn chat_typed_retry_succeeds() { + let provider = SequentialStub::new(vec![ + Ok("not valid json".into()), + Ok(r#"{"value": "ok"}"#.into()), + ]); + let messages = vec![Message::from_legacy(Role::User, "test")]; + let result: TestOutput = provider.chat_typed(&messages).await.unwrap(); + assert_eq!(result, TestOutput { value: "ok".into() }); + } + + #[tokio::test] + async fn chat_typed_both_fail() { + let provider = SequentialStub::new(vec![Ok("bad json".into()), Ok("still bad".into())]); + let messages = vec![Message::from_legacy(Role::User, "test")]; + let result = provider.chat_typed::(&messages).await; + let err = result.unwrap_err(); + assert!(err.to_string().contains("parse failed after retry")); + } + + #[tokio::test] + async fn chat_typed_chat_error_propagates() { + let provider = SequentialStub::new(vec![Err(LlmError::Unavailable)]); + let messages = vec![Message::from_legacy(Role::User, "test")]; + 
let result = provider.chat_typed::(&messages).await; + assert!(matches!(result, Err(LlmError::Unavailable))); + } + + #[tokio::test] + async fn chat_typed_strips_fences() { + let provider = StubProvider { + response: "```json\n{\"value\": \"fenced\"}\n```".into(), + }; + let messages = vec![Message::from_legacy(Role::User, "test")]; + let result: TestOutput = provider.chat_typed(&messages).await.unwrap(); + assert_eq!( + result, + TestOutput { + value: "fenced".into() + } + ); + } + + #[test] + fn supports_structured_output_default_false() { + let provider = StubProvider { + response: String::new(), + }; + assert!(!provider.supports_structured_output()); + } + + #[test] + fn structured_parse_error_display() { + let err = LlmError::StructuredParse("test error".into()); + assert_eq!( + err.to_string(), + "structured output parse failed: test error" + ); + } }