diff --git a/examples/structured-outputs/Cargo.toml b/examples/structured-outputs/Cargo.toml
index 849098e4..0a77ed78 100644
--- a/examples/structured-outputs/Cargo.toml
+++ b/examples/structured-outputs/Cargo.toml
@@ -8,3 +8,5 @@ publish = false
 async-openai = {path = "../../async-openai"}
 serde_json = "1.0.127"
 tokio = { version = "1.39.3", features = ["full"] }
+schemars = "0.8.21"
+serde = "1.0.130"
diff --git a/examples/structured-outputs/src/main.rs b/examples/structured-outputs/src/main.rs
index 3948308d..f1e77db4 100644
--- a/examples/structured-outputs/src/main.rs
+++ b/examples/structured-outputs/src/main.rs
@@ -2,67 +2,96 @@ use std::error::Error;
 
 use async_openai::{
     types::{
-        ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
-        CreateChatCompletionRequestArgs, ResponseFormat, ResponseFormatJsonSchema,
+        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
+        ChatCompletionRequestUserMessage, CreateChatCompletionRequestArgs, ResponseFormat,
+        ResponseFormatJsonSchema,
     },
     Client,
 };
-use serde_json::json;
+use schemars::{schema_for, JsonSchema};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
 
-#[tokio::main]
-async fn main() -> Result<(), Box<dyn Error>> {
-    let client = Client::new();
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+#[serde(deny_unknown_fields)]
+pub struct Step {
+    pub output: String,
+    pub explanation: String,
+}
 
-    let schema = json!({
-        "type": "object",
-        "properties": {
-            "steps": {
-                "type": "array",
-                "items": {
-                    "type": "object",
-                    "properties": {
-                        "explanation": { "type": "string" },
-                        "output": { "type": "string" }
-                    },
-                    "required": ["explanation", "output"],
-                    "additionalProperties": false
-                }
-            },
-            "final_answer": { "type": "string" }
-        },
-        "required": ["steps", "final_answer"],
-        "additionalProperties": false
-    });
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+#[serde(deny_unknown_fields)]
+pub struct MathReasoningResponse {
+    pub final_answer: String,
+    pub steps: Vec<Step>,
+}
 
+pub async fn structured_output<T: DeserializeOwned + JsonSchema>(
+    messages: Vec<ChatCompletionRequestMessage>,
+) -> Result<Option<T>, Box<dyn Error>> {
+    let schema = schema_for!(T);
+    let schema_value = serde_json::to_value(&schema)?;
     let response_format = ResponseFormat::JsonSchema {
         json_schema: ResponseFormatJsonSchema {
             description: None,
             name: "math_reasoning".into(),
-            schema: Some(schema),
+            schema: Some(schema_value),
             strict: Some(true),
         },
     };
 
     let request = CreateChatCompletionRequestArgs::default()
         .max_tokens(512u32)
-        .model("gpt-4o-2024-08-06")
-        .messages([
-            ChatCompletionRequestSystemMessage::from(
-                "You are a helpful math tutor. Guide the user through the solution step by step.",
-            )
-            .into(),
-            ChatCompletionRequestUserMessage::from("how can I solve 8x + 7 = -23").into(),
-        ])
+        .model("gpt-4o-mini")
+        .messages(messages)
         .response_format(response_format)
         .build()?;
 
+    let client = Client::new();
     let response = client.chat().create(request).await?;
 
     for choice in response.choices {
         if let Some(content) = choice.message.content {
-            print!("{content}")
+            return Ok(Some(serde_json::from_str::<T>(&content)?));
         }
     }
 
+    Ok(None)
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn Error>> {
+    // Expecting output schema
+    // let schema = json!({
+    //     "type": "object",
+    //     "properties": {
+    //         "steps": {
+    //             "type": "array",
+    //             "items": {
+    //                 "type": "object",
+    //                 "properties": {
+    //                     "explanation": { "type": "string" },
+    //                     "output": { "type": "string" }
+    //                 },
+    //                 "required": ["explanation", "output"],
+    //                 "additionalProperties": false
+    //             }
+    //         },
+    //         "final_answer": { "type": "string" }
+    //     },
+    //     "required": ["steps", "final_answer"],
+    //     "additionalProperties": false
+    // });
+    if let Some(response) = structured_output::<MathReasoningResponse>(vec![
+        ChatCompletionRequestSystemMessage::from(
+            "You are a helpful math tutor. Guide the user through the solution step by step.",
+        )
+        .into(),
+        ChatCompletionRequestUserMessage::from("how can I solve 8x + 7 = -23").into(),
+    ])
+    .await?
+    {
+        println!("{}", serde_json::to_string(&response).unwrap());
+    }
+
     Ok(())
 }