From 723cf050eff1c3518da4c67003992929fa6ce984 Mon Sep 17 00:00:00 2001 From: FL03 Date: Fri, 23 Feb 2024 14:47:57 -0600 Subject: [PATCH] update Signed-off-by: FL03 --- acme/tests/autodiff.rs | 44 ++++++++++++++++-------------------- core/src/utils.rs | 2 +- derive/src/ast/mod.rs | 2 +- derive/src/cmp/mod.rs | 2 +- derive/src/cmp/params/mod.rs | 11 ++++----- derive/src/lib.rs | 7 +++--- derive/src/utils.rs | 2 +- 7 files changed, 31 insertions(+), 39 deletions(-) diff --git a/acme/tests/autodiff.rs b/acme/tests/autodiff.rs index 05cf38b4..6a23db39 100644 --- a/acme/tests/autodiff.rs +++ b/acme/tests/autodiff.rs @@ -7,7 +7,7 @@ #[cfg(test)] extern crate acme; -use acme::prelude::{autodiff, sigmoid}; +use acme::prelude::{autodiff, sigmoid, Sigmoid}; use approx::assert_abs_diff_eq; use num::traits::Float; use std::ops::Add; @@ -26,18 +26,6 @@ where x.neg().exp() / (T::one() + x.neg().exp()).powi(2) } -pub trait Sigmoid { - fn sigmoid(self) -> Self; -} - -impl Sigmoid for T -where - T: Float, -{ - fn sigmoid(self) -> Self { - (T::one() + self.neg().exp()).recip() - } -} trait Square { fn square(self) -> Self; } @@ -54,17 +42,8 @@ where #[test] fn test_autodiff() { let (x, y) = (1.0, 2.0); - // differentiating a function item w.r.t. a - assert_eq!( - autodiff!(a: fn addition(a: f64, b: f64) -> f64 { a + b }), - 1.0 - ); // differentiating a closure item w.r.t. x assert_eq!(autodiff!(x: | x: f64, y: f64 | x * y ), 2.0); - // differentiating a function call w.r.t. x - assert_eq!(autodiff!(x: add(x, y)), 1.0); - // differentiating a function call w.r.t. some variable - assert_eq!(autodiff!(a: add(x, y)), 0.0); // differentiating a method call w.r.t. the reciever (x) assert_eq!(autodiff!(x: x.add(y)), 1.0); // differentiating an expression w.r.t. x @@ -181,11 +160,28 @@ fn test_sigmoid() { ); } +#[ignore = "Function items are currently not supported"] +#[test] +fn test_fn_item() { + let (x, y) = (1_f64, 2_f64); + // differentiating a function item w.r.t. a + // assert_eq!( + // autodiff!(y: fn mul(x: A, y: B) -> C where A: std::ops::Mul { x * y }), + // 2_f64 + // ); + + assert_eq!(autodiff!(y: fn mul(x: f64, y: f64) -> f64 { x * y }), 2_f64); +} + #[ignore = "Currently, support for function calls is not fully implemented"] #[test] fn test_function_call() { - let x = 2_f64; - assert_eq!(autodiff!(x: sigmoid::(x)), sigmoid_prime(x)); + let (x, y) = (1_f64, 2_f64); + // differentiating a function call w.r.t. x + assert_eq!(autodiff!(x: add(x, y)), 1.0); + // differentiating a function call w.r.t. some variable + assert_eq!(autodiff!(a: add(x, y)), 0.0); + assert_eq!(autodiff!(y: sigmoid::(y)), sigmoid_prime(y)); } #[ignore = "Custom trait methods are not yet supported"] diff --git a/core/src/utils.rs b/core/src/utils.rs index a14b94e7..61b4fef4 100644 --- a/core/src/utils.rs +++ b/core/src/utils.rs @@ -22,4 +22,4 @@ where fn sigmoid(self) -> Self { (T::one() + self.neg().exp()).recip() } -} \ No newline at end of file +} diff --git a/derive/src/ast/mod.rs b/derive/src/ast/mod.rs index 97cc5a44..ff6ef7ff 100644 --- a/derive/src/ast/mod.rs +++ b/derive/src/ast/mod.rs @@ -1,4 +1,4 @@ /* Appellation: ast Contrib: FL03 -*/ \ No newline at end of file +*/ diff --git a/derive/src/cmp/mod.rs b/derive/src/cmp/mod.rs index 3ebcca91..eb18459f 100644 --- a/derive/src/cmp/mod.rs +++ b/derive/src/cmp/mod.rs @@ -3,4 +3,4 @@ Contrib: FL03 */ -pub mod params; \ No newline at end of file +pub mod params; diff --git a/derive/src/cmp/params/mod.rs b/derive/src/cmp/params/mod.rs index 25b54d60..d282f3f5 100644 --- a/derive/src/cmp/params/mod.rs +++ b/derive/src/cmp/params/mod.rs @@ -16,10 +16,8 @@ pub fn generate_keys(fields: &Fields, name: &Ident) -> TokenStream { fn handle_named_fields(fields: &FieldsNamed, name: &Ident) -> TokenStream { let FieldsNamed { named, .. } = fields; - let fields_str = named.iter().cloned().map(|field| { - field.ident.unwrap() - }); - let variants = named.iter().cloned().map(|field | { + let fields_str = named.iter().cloned().map(|field| field.ident.unwrap()); + let variants = named.iter().cloned().map(|field| { let ident = field.ident.unwrap(); let variant_ident = format_ident!("{}", capitalize_first(&ident.to_string())); Variant { @@ -30,12 +28,11 @@ fn handle_named_fields(fields: &FieldsNamed, name: &Ident) -> TokenStream { } }); - quote! { #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] - + pub enum #name { #(#variants),* } } -} \ No newline at end of file +} diff --git a/derive/src/lib.rs b/derive/src/lib.rs index be5752e3..638d0a51 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -8,7 +8,7 @@ extern crate proc_macro; use proc_macro::TokenStream; use quote::{format_ident, quote}; -use syn::{parse_macro_input, Data, DataStruct, DeriveInput,}; +use syn::{parse_macro_input, Data, DataStruct, DeriveInput}; pub(crate) mod ast; pub(crate) mod cmp; @@ -47,17 +47,16 @@ pub fn params(input: TokenStream) -> TokenStream { let DataStruct { fields, .. } = s; crate::cmp::params::generate_keys(fields, &store_name) - }, + } _ => panic!("Only structs are supported"), }; // Combine the generated code let generated_code = quote! { - + #param_keys_enum }; // Return the generated code as a TokenStream generated_code.into() } - diff --git a/derive/src/utils.rs b/derive/src/utils.rs index 5b4961c3..76255e20 100644 --- a/derive/src/utils.rs +++ b/derive/src/utils.rs @@ -10,4 +10,4 @@ pub fn capitalize_first(s: &str) -> String { .flat_map(|f| f.to_uppercase()) .chain(s.chars().skip(1)) .collect() -} \ No newline at end of file +}