From cff364d29d25c0907e5056fdd05dabbb991b66cc Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 26 Jun 2024 00:02:54 -0400 Subject: [PATCH 1/2] feat: add validate attribute to InstructMacro --- examples/getting-started/Cargo.toml | 6 +- examples/getting-started/src/main.rs | 16 +++- instruct-macros-types/src/lib.rs | 1 + instruct-macros/Cargo.toml | 1 + instruct-macros/src/lib.rs | 104 ++++++++++++++++++---- instruct-macros/tests/integration_test.rs | 41 ++++++++- instructor/Cargo.toml | 4 +- instructor/src/lib.rs | 13 ++- 8 files changed, 161 insertions(+), 25 deletions(-) diff --git a/examples/getting-started/Cargo.toml b/examples/getting-started/Cargo.toml index bab0e27..cf9a7c4 100644 --- a/examples/getting-started/Cargo.toml +++ b/examples/getting-started/Cargo.toml @@ -4,9 +4,9 @@ version = "0.1.0" edition = "2021" [dependencies] -instructor-ai = "0.1.0" -instruct-macros = "0.1.1" +instructor-ai = { path = "../../instructor" } +instruct-macros = { path = "../../instruct-macros" } openai-api-rs = "4.1.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -instruct-macros-types = "0.1.2" +instruct-macros-types = { path = "../../instruct-macros-types" } diff --git a/examples/getting-started/src/main.rs b/examples/getting-started/src/main.rs index defb390..b448578 100644 --- a/examples/getting-started/src/main.rs +++ b/examples/getting-started/src/main.rs @@ -1,6 +1,6 @@ use std::env; -use instruct_macros::InstructMacro; +use instruct_macros::{validate, InstructMacro}; use instruct_macros_types::{ParameterInfo, StructInfo}; use instructor_ai::from_openai; use openai_api_rs::v1::{ @@ -18,11 +18,23 @@ fn main() { // This represents a single user struct UserInfo { // This represents the name of the user + #[validate(custom = "validate_uppercase")] name: String, // This represents the age of the user age: u8, } + #[validate] + fn validate_uppercase(s: &String) -> Result { + if s.chars().any(|c| c.is_lowercase()) { + return Err(format!( + "Name '{}' should be entirely in uppercase. Examples: 'TIMOTHY', 'JANE SMITH'", + s + )); + } + Ok(s.to_uppercase()) + } + let req = ChatCompletionRequest::new( GPT3_5_TURBO.to_string(), vec![chat_completion::ChatCompletionMessage { @@ -38,6 +50,6 @@ fn main() { .chat_completion::(req, 3) .unwrap(); - println!("{}", result.name); // John Doe + println!("{}", result.name); // JOHN DOE println!("{}", result.age); // 30 } diff --git a/instruct-macros-types/src/lib.rs b/instruct-macros-types/src/lib.rs index 3e81d14..cfbb15b 100644 --- a/instruct-macros-types/src/lib.rs +++ b/instruct-macros-types/src/lib.rs @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize}; pub trait InstructMacro { fn get_info() -> StructInfo; + fn validate(&self) -> Result<(), String>; } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] diff --git a/instruct-macros/Cargo.toml b/instruct-macros/Cargo.toml index 0bf5faa..aa3352c 100644 --- a/instruct-macros/Cargo.toml +++ b/instruct-macros/Cargo.toml @@ -18,6 +18,7 @@ license = "MIT OR Apache-2.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" instruct-macros-types = { path = "../instruct-macros-types", version = "0.1.0" } +proc-macro2 = "1.0.86" [dependencies.syn] version = "1.0" diff --git a/instruct-macros/src/lib.rs b/instruct-macros/src/lib.rs index 0e6ebff..beb02ca 100644 --- a/instruct-macros/src/lib.rs +++ b/instruct-macros/src/lib.rs @@ -1,29 +1,48 @@ extern crate proc_macro; use proc_macro::TokenStream; -use quote::quote; -use syn::{parse_macro_input, Data, DeriveInput, Fields, Lit, Meta}; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, Data, DeriveInput, Fields, Lit, Meta, NestedMeta}; -#[proc_macro_derive(InstructMacro)] -pub fn instruct_macro_derive(input: TokenStream) -> TokenStream { +#[proc_macro_derive(InstructMacro, attributes(validate))] +pub fn instruct_validate_derive(input: TokenStream) -> TokenStream { // Parse the input tokens into a syntax tree let input = parse_macro_input!(input as DeriveInput); // Used in the quasi-quotation below as `#name` let name = &input.ident; + let fields = match &input.data { + Data::Struct(data) => match &data.fields { + Fields::Named(fields) => &fields.named, + _ => panic!("Only named fields are supported"), + }, + _ => panic!("Only structs are supported"), + }; + + let validation_fields: Vec<_> = fields + .iter() + .filter_map(|f| { + let field_name = &f.ident; + f.attrs + .iter() + .find(|attr| attr.path.is_ident("validate")) + .map(|attr| { + let meta = attr.parse_meta().expect("Unable to parse attribute"); + parse_validation_attribute(field_name, &meta) + }) + }) + .collect(); + // Extract struct-level comment let struct_comment = input .attrs .iter() .filter_map(|attr| { if attr.path.is_ident("doc") { - match attr.parse_meta().ok()? { - Meta::NameValue(meta) => { - if let Lit::Str(lit) = meta.lit { - return Some(lit.value()); - } + if let Ok(Meta::NameValue(meta)) = attr.parse_meta() { + if let Lit::Str(lit) = meta.lit { + return Some(lit.value()); } - _ => {} } } None @@ -55,13 +74,10 @@ pub fn instruct_macro_derive(input: TokenStream) -> TokenStream { .iter() .filter_map(|attr| { if attr.path.is_ident("doc") { - match attr.parse_meta().ok()? { - Meta::NameValue(meta) => { - if let Lit::Str(lit) = meta.lit { - return Some(lit.value()); - } + if let Ok(Meta::NameValue(meta)) = attr.parse_meta() { + if let Lit::Str(lit) = meta.lit { + return Some(lit.value()); } - _ => {} } } None @@ -91,9 +107,65 @@ pub fn instruct_macro_derive(input: TokenStream) -> TokenStream { parameters, } } + + fn validate(&self) -> Result<(), String> { + #(#validation_fields)* + Ok(()) + } } }; // Hand the output tokens back to the compiler TokenStream::from(expanded) } + +/// Parses the validation attribute and generates corresponding validation code. +/// +/// This function processes custom validation attributes, expanding them into function calls +/// that perform the specified validation. It supports custom validators that take a reference +/// to the field type and return a Result with a string error type. +fn parse_validation_attribute( + field_name: &Option, + meta: &Meta, +) -> proc_macro2::TokenStream { + let Meta::List(list) = meta else { panic!("Unsupported meta") }; + + list.nested.iter().map(|nm| { + let NestedMeta::Meta(Meta::NameValue(nv)) = nm else { panic!("Unsupported nested attribute") }; + let ident = &nv.path; + let lit = &nv.lit; + + match ident.get_ident().unwrap().to_string().as_str() { + "custom" => { + let Lit::Str(s) = lit else { panic!("Custom validator must be a string literal") }; + let func = format_ident!("{}", s.value()); + quote! { + if let Err(e) = #func(&self.#field_name) { + return Err(format!("Validation failed for field '{}': {}", stringify!(#field_name), e)); + } + } + }, + _ => panic!("Unsupported validation type"), + } + }).collect() +} + +/// Custom attribute macro for field validation in structs. +/// +/// This procedural macro attribute is designed to be applied to structs, +/// enabling custom validation for their fields. When the `validate` method +/// is called on an instance of the decorated struct, it triggers the specified +/// custom validation functions for each annotated field. +#[proc_macro_attribute] +pub fn validate(_attr: TokenStream, item: TokenStream) -> TokenStream { + let input = parse_macro_input!(item as syn::ItemFn); + let syn::ItemFn { sig, block, .. } = input; + + let expanded = quote! { + #sig { + #block + } + }; + + TokenStream::from(expanded) +} \ No newline at end of file diff --git a/instruct-macros/tests/integration_test.rs b/instruct-macros/tests/integration_test.rs index 2f03e80..7abd9c6 100644 --- a/instruct-macros/tests/integration_test.rs +++ b/instruct-macros/tests/integration_test.rs @@ -1,6 +1,6 @@ extern crate instruct_macros_types; -use instruct_macros::InstructMacro; // Add this line +use instruct_macros::{validate, InstructMacro}; use instruct_macros_types::{InstructMacro, ParameterInfo, StructInfo}; #[cfg(test)] @@ -10,6 +10,7 @@ mod tests { #[test] fn test_string_conversion() { #[derive(InstructMacro, Debug)] + #[allow(dead_code)] struct TestStruct { ///This is a test field field1: String, @@ -35,4 +36,42 @@ mod tests { }; assert!(info == desired_struct); } + + #[test] + fn test_validation_macro() { + #[derive(InstructMacro, Debug)] + pub struct UserInfo { + #[validate(custom = "validate_uppercase")] + pub name: String, + pub age: u8, + } + + #[validate] + fn validate_uppercase(name: &String) -> Result { + if name.chars().any(|c| c.is_lowercase()) { + return Err(format!( + "Name '{}' should be entirely in uppercase. Examples: 'TIMOTHY', 'JANE SMITH'", + name + )); + } + Ok(name.to_uppercase()) + } + + let user_info = UserInfo { + name: "JoHn DoE".to_string(), + age: 100, + }; + + assert_eq!( + user_info.validate().unwrap_err(), + "Validation failed for field 'name': Name 'JoHn DoE' should be entirely in uppercase. Examples: 'TIMOTHY', 'JANE SMITH'" + ); + + let user_info = UserInfo { + name: "JOHN DOE".to_string(), + age: 30, + }; + + assert!(user_info.validate().is_ok()); + } } diff --git a/instructor/Cargo.toml b/instructor/Cargo.toml index fbbe50c..1d29e8b 100644 --- a/instructor/Cargo.toml +++ b/instructor/Cargo.toml @@ -13,11 +13,11 @@ license = "MIT OR Apache-2.0" changelog = "CHANGELOG.md" [dependencies] -instruct-macros = "0.1.1" +instruct-macros = { path = "../instruct-macros" } openai-api-rs = "4.1.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -instruct-macros-types = "0.1.2" +instruct-macros-types = { path = "../instruct-macros-types" } [lib] name = "instructor_ai" diff --git a/instructor/src/lib.rs b/instructor/src/lib.rs index 01e0c1f..a0874f3 100644 --- a/instructor/src/lib.rs +++ b/instructor/src/lib.rs @@ -74,6 +74,15 @@ impl InstructorClient { let result = self._retry_sync::(req.clone(), parsed_model.clone()); match result { Ok(value) => { + match T::validate(&value) { + Ok(_) => {} + Err(e) => { + error_message = + Some(format!("Validation Error: {:?}. Please fix the issue", e)); + continue; + } + } + return Ok(value); } Err(e) => { @@ -84,7 +93,9 @@ impl InstructorClient { } } - panic!("Unable to derive model") + Err(APIError { + message: format!("Unable to derive model: {:?}", error_message), + }) } fn _retry_sync( From fc3e16f56235f7d99b5f0071649f39506e03151f Mon Sep 17 00:00:00 2001 From: Ivan Leo Date: Thu, 4 Jul 2024 14:04:56 +0800 Subject: [PATCH 2/2] fixed up readme --- README.md | 20 ++++++++------------ instructor/tests/test_retries.rs | 24 +++++++++--------------- 2 files changed, 17 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 9a3dd61..0a970f0 100644 --- a/README.md +++ b/README.md @@ -49,25 +49,21 @@ let instructor_client = from_openai(client); // This represents a single user struct UserInfo { // This represents the name of the user - #[serde(deserialize_with = "validate_uppercase")] + #[validate(custom = "validate_uppercase")] name: String, // This represents the age of the user age: u8, } -fn validate_uppercase<'de, D>(de: D) -> Result -where - D: Deserializer<'de>, -{ - let s = String::deserialize(de)?; - println!("{}", s); - if s.chars().any(|c| c.is_lowercase()) { - return Err(de::Error::custom(format!( +#[validate] +fn validate_uppercase(name: &String) -> Result { + if name.chars().any(|c| c.is_lowercase()) { + return Err(format!( "Name '{}' should be entirely in uppercase. Examples: 'TIMOTHY', 'JANE SMITH'", - s - ))); + name + )); } - Ok(s.to_uppercase()) + Ok(name.to_uppercase()) } let req = ChatCompletionRequest::new( diff --git a/instructor/tests/test_retries.rs b/instructor/tests/test_retries.rs index 7dfc26c..9d906a1 100644 --- a/instructor/tests/test_retries.rs +++ b/instructor/tests/test_retries.rs @@ -1,13 +1,11 @@ extern crate instruct_macros; extern crate instruct_macros_types; -use instruct_macros::InstructMacro; +use instruct_macros::{validate, InstructMacro}; use instruct_macros_types::{ParameterInfo, StructInfo}; use instructor_ai::from_openai; use openai_api_rs::v1::api::Client; -use serde::{de, Deserializer}; - #[cfg(test)] mod tests { use std::env; @@ -29,25 +27,21 @@ mod tests { // This represents a single user struct UserInfo { // This represents the name of the user - #[serde(deserialize_with = "validate_uppercase")] + #[validate(custom = "validate_uppercase")] name: String, // This represents the age of the user age: u8, } - fn validate_uppercase<'de, D>(de: D) -> Result - where - D: Deserializer<'de>, - { - let s = String::deserialize(de)?; - println!("{}", s); - if s.chars().any(|c| c.is_lowercase()) { - return Err(de::Error::custom(format!( + #[validate] + fn validate_uppercase(name: &String) -> Result { + if name.chars().any(|c| c.is_lowercase()) { + return Err(format!( "Name '{}' should be entirely in uppercase. Examples: 'TIMOTHY', 'JANE SMITH'", - s - ))); + name + )); } - Ok(s.to_uppercase()) + Ok(name.to_uppercase()) } let req = ChatCompletionRequest::new(