Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add validate attribute to InstructMacro #10

Merged
merged 2 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/getting-started/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ version = "0.1.0"
edition = "2021"

[dependencies]
instructor-ai = "0.1.0"
instruct-macros = "0.1.1"
instructor-ai = { path = "../../instructor" }
instruct-macros = { path = "../../instruct-macros" }
openai-api-rs = "4.1.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
instruct-macros-types = "0.1.2"
instruct-macros-types = { path = "../../instruct-macros-types" }
16 changes: 14 additions & 2 deletions examples/getting-started/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::env;

use instruct_macros::InstructMacro;
use instruct_macros::{validate, InstructMacro};
use instruct_macros_types::{ParameterInfo, StructInfo};
use instructor_ai::from_openai;
use openai_api_rs::v1::{
Expand All @@ -18,11 +18,23 @@ fn main() {
// This represents a single user
struct UserInfo {
// This represents the name of the user
#[validate(custom = "validate_uppercase")]
name: String,
// This represents the age of the user
age: u8,
}

/// Custom validator referenced by `#[validate(custom = "validate_uppercase")]`.
///
/// Returns `Err` with an explanatory message when `s` contains at least one lowercase
/// character; otherwise returns `Ok` with the upper-cased copy of `s`.
#[validate]
fn validate_uppercase(s: &String) -> Result<String, String> {
    let has_lowercase = s.chars().any(char::is_lowercase);
    if !has_lowercase {
        return Ok(s.to_uppercase());
    }

    Err(format!(
        "Name '{}' should be entirely in uppercase. Examples: 'TIMOTHY', 'JANE SMITH'",
        s
    ))
}

let req = ChatCompletionRequest::new(
GPT3_5_TURBO.to_string(),
vec![chat_completion::ChatCompletionMessage {
Expand All @@ -38,6 +50,6 @@ fn main() {
.chat_completion::<UserInfo>(req, 3)
.unwrap();

println!("{}", result.name); // John Doe
println!("{}", result.name); // JOHN DOE
println!("{}", result.age); // 30
}
1 change: 1 addition & 0 deletions instruct-macros-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize};

/// Interface for types that can describe themselves and validate their own values.
///
/// NOTE(review): implementations appear to be generated by `#[derive(InstructMacro)]`
/// from the `instruct-macros` crate rather than written by hand; confirm before
/// implementing manually.
pub trait InstructMacro {
/// Returns the [`StructInfo`] description of the implementing type.
///
/// This is an associated function: it takes no `self` and needs no instance.
fn get_info() -> StructInfo;
/// Validates this instance.
///
/// Returns `Ok(())` when validation passes, or `Err` carrying a human-readable
/// message describing the failure.
fn validate(&self) -> Result<(), String>;
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
Expand Down
1 change: 1 addition & 0 deletions instruct-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ license = "MIT OR Apache-2.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
instruct-macros-types = { path = "../instruct-macros-types", version = "0.1.0" }
proc-macro2 = "1.0.86"

[dependencies.syn]
version = "1.0"
Expand Down
104 changes: 88 additions & 16 deletions instruct-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,48 @@
extern crate proc_macro;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields, Lit, Meta};
use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DeriveInput, Fields, Lit, Meta, NestedMeta};

#[proc_macro_derive(InstructMacro)]
pub fn instruct_macro_derive(input: TokenStream) -> TokenStream {
#[proc_macro_derive(InstructMacro, attributes(validate))]
pub fn instruct_validate_derive(input: TokenStream) -> TokenStream {
// Parse the input tokens into a syntax tree
let input = parse_macro_input!(input as DeriveInput);

// Used in the quasi-quotation below as `#name`
let name = &input.ident;

let fields = match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => &fields.named,
_ => panic!("Only named fields are supported"),
},
_ => panic!("Only structs are supported"),
};

let validation_fields: Vec<_> = fields
.iter()
.filter_map(|f| {
let field_name = &f.ident;
f.attrs
.iter()
.find(|attr| attr.path.is_ident("validate"))
.map(|attr| {
let meta = attr.parse_meta().expect("Unable to parse attribute");
parse_validation_attribute(field_name, &meta)
})
})
.collect();

// Extract struct-level comment
let struct_comment = input
.attrs
.iter()
.filter_map(|attr| {
if attr.path.is_ident("doc") {
match attr.parse_meta().ok()? {
Meta::NameValue(meta) => {
if let Lit::Str(lit) = meta.lit {
return Some(lit.value());
}
if let Ok(Meta::NameValue(meta)) = attr.parse_meta() {
if let Lit::Str(lit) = meta.lit {
return Some(lit.value());
}
_ => {}
}
}
None
Expand Down Expand Up @@ -55,13 +74,10 @@ pub fn instruct_macro_derive(input: TokenStream) -> TokenStream {
.iter()
.filter_map(|attr| {
if attr.path.is_ident("doc") {
match attr.parse_meta().ok()? {
Meta::NameValue(meta) => {
if let Lit::Str(lit) = meta.lit {
return Some(lit.value());
}
if let Ok(Meta::NameValue(meta)) = attr.parse_meta() {
if let Lit::Str(lit) = meta.lit {
return Some(lit.value());
}
_ => {}
}
}
None
Expand Down Expand Up @@ -91,9 +107,65 @@ pub fn instruct_macro_derive(input: TokenStream) -> TokenStream {
parameters,
}
}

fn validate(&self) -> Result<(), String> {
#(#validation_fields)*
Ok(())
}
}
};

// Hand the output tokens back to the compiler
TokenStream::from(expanded)
}

/// Expands one `#[validate(...)]` field attribute into validation statements.
///
/// Each `custom = "function_name"` entry becomes a call `function_name(&self.<field>)`.
/// When that call returns `Err(e)`, the generated code returns early with
/// `Err("Validation failed for field '<field>': <e>")`; the `Ok` payload of the
/// validator is ignored. The returned tokens are meant to be spliced into a method body
/// that has `self` in scope and returns `Result<(), String>`.
///
/// # Panics
///
/// Panics (surfacing as a compile error at the derive site) when:
/// - `meta` is not a list, i.e. not of the form `validate(...)`;
/// - an entry is not a `name = value` pair;
/// - the entry name is anything other than the single identifier `custom`;
/// - the `custom` value is not a string literal, or `format_ident!` rejects it as an
///   identifier.
fn parse_validation_attribute(
    field_name: &Option<syn::Ident>,
    meta: &Meta,
) -> proc_macro2::TokenStream {
    let Meta::List(list) = meta else {
        panic!("Unsupported meta")
    };

    list.nested
        .iter()
        .map(|nested| {
            let NestedMeta::Meta(Meta::NameValue(name_value)) = nested else {
                panic!("Unsupported nested attribute")
            };

            // `is_ident` is simply false for multi-segment keys such as `a::b = "..."`,
            // so they are reported as an unsupported validation type instead of
            // tripping an `unwrap` on `Path::get_ident`.
            if !name_value.path.is_ident("custom") {
                panic!("Unsupported validation type");
            }

            let Lit::Str(validator_name) = &name_value.lit else {
                panic!("Custom validator must be a string literal")
            };
            let validator = format_ident!("{}", validator_name.value());

            quote! {
                if let Err(e) = #validator(&self.#field_name) {
                    return Err(format!(
                        "Validation failed for field '{}': {}",
                        stringify!(#field_name),
                        e
                    ));
                }
            }
        })
        .collect()
}

/// Marker attribute for custom validator functions.
///
/// Apply it to a function that is referenced by name from a field-level
/// `#[validate(custom = "...")]` attribute. The macro parses the annotated item as a
/// function (any other kind of item is rejected with a compile error) and re-emits it
/// unchanged, so the function keeps its visibility, attributes, and doc comments.
/// Arguments passed to the attribute itself are ignored.
#[proc_macro_attribute]
pub fn validate(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let function = parse_macro_input!(item as syn::ItemFn);

    // Emit the whole parsed item, not just its signature and body, so that `pub` and
    // any other attributes on the function survive the expansion.
    TokenStream::from(quote! { #function })
}
41 changes: 40 additions & 1 deletion instruct-macros/tests/integration_test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
extern crate instruct_macros_types;

use instruct_macros::InstructMacro; // Add this line
use instruct_macros::{validate, InstructMacro};
use instruct_macros_types::{InstructMacro, ParameterInfo, StructInfo};

#[cfg(test)]
Expand All @@ -10,6 +10,7 @@ mod tests {
#[test]
fn test_string_conversion() {
#[derive(InstructMacro, Debug)]
#[allow(dead_code)]
struct TestStruct {
///This is a test field
field1: String,
Expand All @@ -35,4 +36,42 @@ mod tests {
};
assert!(info == desired_struct);
}

/// Checks that the derived `validate` method runs the custom field validator and
/// wraps its error message with the name of the offending field.
#[test]
fn test_validation_macro() {
    #[derive(InstructMacro, Debug)]
    pub struct UserInfo {
        #[validate(custom = "validate_uppercase")]
        pub name: String,
        pub age: u8,
    }

    // Rejects any name containing a lowercase character.
    #[validate]
    fn validate_uppercase(name: &String) -> Result<String, String> {
        match name.chars().find(|c| c.is_lowercase()) {
            Some(_) => Err(format!(
                "Name '{}' should be entirely in uppercase. Examples: 'TIMOTHY', 'JANE SMITH'",
                name
            )),
            None => Ok(name.to_uppercase()),
        }
    }

    // A mixed-case name must fail, and the message must identify the field.
    let mixed_case = UserInfo {
        name: "JoHn DoE".to_string(),
        age: 100,
    };
    let failure = mixed_case.validate().unwrap_err();
    assert_eq!(
        failure,
        "Validation failed for field 'name': Name 'JoHn DoE' should be entirely in uppercase. Examples: 'TIMOTHY', 'JANE SMITH'"
    );

    // An all-uppercase name must pass.
    let upper_case = UserInfo {
        name: "JOHN DOE".to_string(),
        age: 30,
    };
    let outcome = upper_case.validate();
    assert!(outcome.is_ok());
}
}
4 changes: 2 additions & 2 deletions instructor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ license = "MIT OR Apache-2.0"
changelog = "CHANGELOG.md"

[dependencies]
instruct-macros = "0.1.1"
instruct-macros = { path = "../instruct-macros" }
openai-api-rs = "4.1.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
instruct-macros-types = "0.1.2"
instruct-macros-types = { path = "../instruct-macros-types" }

[lib]
name = "instructor_ai"
Expand Down
13 changes: 12 additions & 1 deletion instructor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ impl InstructorClient {
let result = self._retry_sync::<T>(req.clone(), parsed_model.clone());
match result {
Ok(value) => {
match T::validate(&value) {
Ok(_) => {}
Err(e) => {
error_message =
Some(format!("Validation Error: {:?}. Please fix the issue", e));
continue;
}
}

return Ok(value);
}
Err(e) => {
Expand All @@ -84,7 +93,9 @@ impl InstructorClient {
}
}

panic!("Unable to derive model")
Err(APIError {
message: format!("Unable to derive model: {:?}", error_message),
})
}

fn _retry_sync<T>(
Expand Down