Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: FL03 <jo3mccain@icloud.com>
  • Loading branch information
FL03 committed Feb 23, 2024
1 parent b209321 commit 723cf05
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 39 deletions.
44 changes: 20 additions & 24 deletions acme/tests/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#[cfg(test)]
extern crate acme;

use acme::prelude::{autodiff, sigmoid};
use acme::prelude::{autodiff, sigmoid, Sigmoid};
use approx::assert_abs_diff_eq;
use num::traits::Float;
use std::ops::Add;
Expand All @@ -26,18 +26,6 @@ where
x.neg().exp() / (T::one() + x.neg().exp()).powi(2)
}

pub trait Sigmoid {
fn sigmoid(self) -> Self;
}

impl<T> Sigmoid for T
where
T: Float,
{
fn sigmoid(self) -> Self {
(T::one() + self.neg().exp()).recip()
}
}
trait Square {
fn square(self) -> Self;
}
Expand All @@ -54,17 +42,8 @@ where
#[test]
fn test_autodiff() {
let (x, y) = (1.0, 2.0);
// differentiating a function item w.r.t. a
assert_eq!(
autodiff!(a: fn addition(a: f64, b: f64) -> f64 { a + b }),
1.0
);
// differentiating a closure item w.r.t. x
assert_eq!(autodiff!(x: | x: f64, y: f64 | x * y ), 2.0);
// differentiating a function call w.r.t. x
assert_eq!(autodiff!(x: add(x, y)), 1.0);
// differentiating a function call w.r.t. some variable
assert_eq!(autodiff!(a: add(x, y)), 0.0);
// differentiating a method call w.r.t. the reciever (x)
assert_eq!(autodiff!(x: x.add(y)), 1.0);
// differentiating an expression w.r.t. x
Expand Down Expand Up @@ -181,11 +160,28 @@ fn test_sigmoid() {
);
}

#[ignore = "Function items are currently not supported"]
#[test]
fn test_fn_item() {
let (x, y) = (1_f64, 2_f64);
// differentiating a function item w.r.t. a
// assert_eq!(
// autodiff!(y: fn mul<A, B, C>(x: A, y: B) -> C where A: std::ops::Mul<B, Output = C> { x * y }),
// 2_f64
// );

assert_eq!(autodiff!(y: fn mul(x: f64, y: f64) -> f64 { x * y }), 2_f64);
}

#[ignore = "Currently, support for function calls is not fully implemented"]
#[test]
fn test_function_call() {
let x = 2_f64;
assert_eq!(autodiff!(x: sigmoid::<f64>(x)), sigmoid_prime(x));
let (x, y) = (1_f64, 2_f64);
// differentiating a function call w.r.t. x
assert_eq!(autodiff!(x: add(x, y)), 1.0);
// differentiating a function call w.r.t. some variable
assert_eq!(autodiff!(a: add(x, y)), 0.0);
assert_eq!(autodiff!(y: sigmoid::<f64>(y)), sigmoid_prime(y));
}

#[ignore = "Custom trait methods are not yet supported"]
Expand Down
2 changes: 1 addition & 1 deletion core/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ where
fn sigmoid(self) -> Self {
(T::one() + self.neg().exp()).recip()
}
}
}
2 changes: 1 addition & 1 deletion derive/src/ast/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*
Appellation: ast <module>
Contrib: FL03 <jo3mccain@icloud.com>
*/
*/
2 changes: 1 addition & 1 deletion derive/src/cmp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
Contrib: FL03 <jo3mccain@icloud.com>
*/

pub mod params;
pub mod params;
11 changes: 4 additions & 7 deletions derive/src/cmp/params/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ pub fn generate_keys(fields: &Fields, name: &Ident) -> TokenStream {

fn handle_named_fields(fields: &FieldsNamed, name: &Ident) -> TokenStream {
let FieldsNamed { named, .. } = fields;
let fields_str = named.iter().cloned().map(|field| {
field.ident.unwrap()
});
let variants = named.iter().cloned().map(|field | {
let fields_str = named.iter().cloned().map(|field| field.ident.unwrap());
let variants = named.iter().cloned().map(|field| {
let ident = field.ident.unwrap();
let variant_ident = format_ident!("{}", capitalize_first(&ident.to_string()));
Variant {
Expand All @@ -30,12 +28,11 @@ fn handle_named_fields(fields: &FieldsNamed, name: &Ident) -> TokenStream {
}
});


quote! {
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]

pub enum #name {
#(#variants),*
}
}
}
}
7 changes: 3 additions & 4 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
extern crate proc_macro;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DataStruct, DeriveInput,};
use syn::{parse_macro_input, Data, DataStruct, DeriveInput};

pub(crate) mod ast;
pub(crate) mod cmp;
Expand Down Expand Up @@ -47,17 +47,16 @@ pub fn params(input: TokenStream) -> TokenStream {
let DataStruct { fields, .. } = s;

crate::cmp::params::generate_keys(fields, &store_name)
},
}
_ => panic!("Only structs are supported"),
};

// Combine the generated code
let generated_code = quote! {

#param_keys_enum
};

// Return the generated code as a TokenStream
generated_code.into()
}

2 changes: 1 addition & 1 deletion derive/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ pub fn capitalize_first(s: &str) -> String {
.flat_map(|f| f.to_uppercase())
.chain(s.chars().skip(1))
.collect()
}
}

0 comments on commit 723cf05

Please sign in to comment.