diff --git a/.github/dependabot.yml b/.github/dependabot.yml index c0ace329..fd5da2be 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -24,3 +24,7 @@ updates: directory: /macros schedule: interval: daily + - package-ecosystem: cargo + directory: /tensor + schedule: + interval: daily \ No newline at end of file diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 8bdb0d7e..e74d07cc 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -22,18 +22,17 @@ on: jobs: build: - name: Build and Test + name: Build strategy: matrix: platform: [ ubuntu-latest ] - toolchain: [ stable, nightly ] runs-on: ${{ matrix.platform }} steps: - uses: actions/checkout@v3 - name: setup (langspace) run: | rustup update - rustup default ${{ matrix.toolchain }} + rustup default nightly - name: Build id: rust-build run: cargo build -F full -r -v --workspace @@ -46,8 +45,31 @@ jobs: ~/.cargo/git target/release key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + test: + name: Test + strategy: + matrix: + platform: [ ubuntu-latest ] + toolchain: [ nightly ] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: setup (langspace) + run: | + rustup update + rustup default ${{ matrix.toolchain }} - name: Test + id: rust-test run: cargo test --all -F full -r -v + bench: + name: Bench + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: setup (langspace) + run: | + rustup update + rustup default nightly - name: Bench - if: matrix.toolchain == 'nightly' + id: rust-bench run: cargo bench --all -F full -r -v \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 7c496636..ef9a165a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,13 +13,14 @@ version = "0.3.0" # TODO - Update cargo package version default-members = [ "acme" ] -exclude = [ "xtask" ] +exclude = [ ] members = [ "acme", "core", "derive", - "macros" -, "tensor"] + "macros", + "tensor" +] resolver = "2" [workspace.dependencies] diff --git a/acme/Cargo.toml b/acme/Cargo.toml index baaf3932..8a14db14 100644 --- a/acme/Cargo.toml +++ b/acme/Cargo.toml @@ -12,7 +12,13 @@ readme.workspace = true repository.workspace = true version.workspace = true +[[example]] +name = "autodiff" +required-features = ["macros"] +[[test]] +name = "autodiff" +required-features = ["macros"] [features] default = ["core", "tensor"] @@ -24,19 +30,20 @@ full = [ ] core = [ - "acme-core" + "dep:acme-core" ] derive = [ - "acme-derive" + "dep:acme-derive", + "macros" ] macros = [ - "acme-macros" + "dep:acme-macros" ] tensor = [ - "acme-tensor" + "dep:acme-tensor" ] [lib] @@ -54,6 +61,8 @@ acme-macros = { features = [], optional = true, path = "../macros", version = "0 acme-tensor = { features = [], optional = true, path = "../tensor", version = "0.3.0" } [dev-dependencies] +approx = "0.5" +num = "0.4" [package.metadata.docs.rs] all-features = true diff --git a/acme/examples/autodiff.rs b/acme/examples/autodiff.rs new file mode 100644 index 00000000..d8afb0b1 --- /dev/null +++ b/acme/examples/autodiff.rs @@ -0,0 +1,33 @@ +/* + Appellation: autodiff + Contrib: FL03 +*/ +#![feature(fn_traits)] +extern crate acme; + +use acme::{autodiff, show_item}; +use acme::prelude::sigmoid; + +macro_rules! 
eval {
+    ($var:ident: $ex:expr) => {
+        println!("Eval: {:?}", $ex);
+        println!("Gradient: {:?}", autodiff!($var: $ex));
+    }
+}
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let x: f64 = 2.0;
+
+    eval!(x: x.tan());
+
+    eval!(x: x.sin());
+
+    eval!(x: x.cos().sin());
+    // show_item!(sigmoid::<f64>);
+    unsafe {
+        println!("{:?}", sigmoid::<f64>.call((2_f64,)));
+    }
+
+
+    Ok(())
+}
diff --git a/acme/examples/cgraph.rs b/acme/examples/cgraph.rs
index 74428569..d097dab9 100644
--- a/acme/examples/cgraph.rs
+++ b/acme/examples/cgraph.rs
@@ -4,20 +4,20 @@
 */
 extern crate acme;
 
-use acme::prelude::{Graph, Result};
+use acme::prelude::{Result, Scg};
 
 fn main() -> Result<()> {
-    let mut dcg = Graph::new();
-    let x = dcg.variable(1.0);
-    let y = dcg.variable(2.0);
+    let mut scg = Scg::new();
+    let x = scg.variable(1.0);
+    let y = scg.variable(2.0);
 
-    let z = dcg.add(x, y)?;
-    let w = dcg.mul(z, y)?;
+    let z = scg.add(x, y)?;
+    let w = scg.mul(z, y)?;
 
-    let eval = dcg.get_value(w).unwrap();
+    let eval = scg.get_value(w).unwrap();
     println!("{:?}", *eval);
 
-    let grad = dcg.backward();
+    let grad = scg.backward();
     println!("{:?}", grad);
 
     Ok(())
diff --git a/acme/src/lib.rs b/acme/src/lib.rs
index d7c0b101..a6fba212 100644
--- a/acme/src/lib.rs
+++ b/acme/src/lib.rs
@@ -7,7 +7,6 @@
 //! Acme is an autodifferentiation library for Rust. It is designed to be a
 //! flexible and powerful tool for building machine learning models and
 //! other differentiable programs.
-
 #[cfg(feature = "core")]
 pub use acme_core as core;
 #[cfg(feature = "derive")]
diff --git a/acme/tests/autodiff.rs b/acme/tests/autodiff.rs
new file mode 100644
index 00000000..4f61cdff
--- /dev/null
+++ b/acme/tests/autodiff.rs
@@ -0,0 +1,185 @@
+/*
+    Appellation: gradient
+    Contrib: FL03
+*/
+#![allow(unused_variables)]
+
+#[cfg(test)]
+extern crate acme;
+
+use acme::prelude::{autodiff, sigmoid};
+use approx::assert_abs_diff_eq;
+use num::traits::Float;
+use std::ops::Add;
+
+pub fn add<A, B, C>(a: A, b: B) -> C
+where
+    A: std::ops::Add<B, Output = C>,
+{
+    a + b
+}
+
+pub fn sigmoid_prime<T>(x: T) -> T
+where
+    T: Float,
+{
+    x.neg().exp() / (T::one() + x.neg().exp()).powi(2)
+}
+
+pub trait Sigmoid {
+    fn sigmoid(self) -> Self;
+}
+
+impl<T> Sigmoid for T
+where
+    T: Float,
+{
+    fn sigmoid(self) -> Self {
+        (T::one() + self.neg().exp()).recip()
+    }
+}
+trait Square {
+    fn square(self) -> Self;
+}
+
+impl<T> Square for T
+where
+    T: Copy + std::ops::Mul<Output = T>,
+{
+    fn square(self) -> Self {
+        self * self
+    }
+}
+
+#[test]
+fn test_autodiff() {
+    let (x, y) = (1.0, 2.0);
+    // differentiating a function item w.r.t. a
+    assert_eq!(
+        autodiff!(a: fn addition(a: f64, b: f64) -> f64 { a + b }),
+        1.0
+    );
+    // differentiating a closure item w.r.t. x
+    assert_eq!(autodiff!(x: | x: f64, y: f64 | x * y ), 2.0);
+    // differentiating a function call w.r.t. x
+    assert_eq!(autodiff!(x: add(x, y)), 1.0);
+    // differentiating a function call w.r.t. some variable
+    assert_eq!(autodiff!(a: add(x, y)), 0.0);
+    // differentiating a method call w.r.t. the receiver (x)
+    assert_eq!(autodiff!(x: x.add(y)), 1.0);
+    // differentiating an expression w.r.t. 
x + assert_eq!(autodiff!(x: x + y), 1.0); + assert_eq!(autodiff!(y: x += y), 1.0); +} + +#[test] +fn test_array() { + let x = [1.0, 2.0]; + let y = [2.0, 2.0]; + assert_eq!(autodiff!(x: x + y), 1.0); +} + +#[test] +fn test_add() { + let x = [1.0]; + let y = 2.0; + assert_eq!(autodiff!(x: x + y), 1.0); + assert_eq!(autodiff!(y: x += y), 1.0); +} + +#[test] +fn test_div() { + let x = 1.0; + let y = 2.0; + assert_eq!(autodiff!(x: x / y), 1.0 / 2.0); + assert_eq!(autodiff!(y: x / y), -1.0 / 4.0); + assert_eq!(autodiff!(x: x /= y), 1.0 / 2.0); + assert_eq!(autodiff!(y: x /= y), -1.0 / 4.0); +} + +#[test] +fn test_mul() { + let x = 1.0; + let y = 2.0; + assert_eq!(autodiff!(x: x * y), 2.0); + assert_eq!(autodiff!(y: x * y), 1.0); + assert_eq!(autodiff!(x: x *= y), 2.0); + assert_eq!(autodiff!(y: x *= y), 1.0); + assert_eq!(autodiff!(y: x * y + 3.0), 1.0); +} + +#[test] +fn test_sub() { + let x = 1.0; + let y = 2.0; + assert_eq!(autodiff!(x: x - y), 1.0); + assert_eq!(autodiff!(y: x - y), -1.0); + assert_eq!(autodiff!(x: x -= y), 1.0); + assert_eq!(autodiff!(y: x -= y), -1.0); +} + +#[test] +fn test_foil() { + let (x, y) = (1_f64, 2_f64); + + assert_eq!(autodiff!(x: (x + y) * (x + y)), 2_f64 * (x + y)); + assert_eq!(autodiff!(x: (x + y) * (x + y)), autodiff!(y: (x + y) * (x + y))); +} + +#[test] +fn test_chain_rule() { + let (x, y) = (1_f64, 2_f64); + + assert_eq!(autodiff!(x: y * (x + y)), 2.0); + assert_eq!(autodiff!(y: y * (x + y)), 5.0); + assert_eq!(autodiff!(x: (x + y) * y), 2.0); + assert_eq!(autodiff!(y: (x + y) * y), 5.0); +} + +#[test] +fn test_trig() { + let x: f64 = 2.0; + assert_eq!(autodiff!(x: x.cos()), -x.sin()); + assert_eq!(autodiff!(x: x.sin()), x.cos()); + assert_eq!(autodiff!(x: x.tan()), x.cos().square().recip()); +} + +#[test] +fn test_log() { + let x: f64 = 2.0; + + assert_eq!(autodiff!(x: x.ln()), 2_f64.recip()); + assert_eq!(autodiff!(x: (x + 1.0).ln()), 3_f64.recip()); +} + +#[test] +fn test_chained() { + let x: f64 = 2.0; + assert_abs_diff_eq!(autodiff!(x: x.sin() * x.cos()), 2_f64 * x.cos().square() - 1_f64, epsilon = 1e-8); + assert_eq!(autodiff!(x: x.sin().cos()), -x.cos() * x.sin().sin()); + assert_eq!(autodiff!(x: x.ln().ln()), (x * x.ln()).recip()); +} + +#[test] +fn test_sigmoid() { + let x = 2_f64; + assert_eq!(autodiff!(x: 1.0 / (1.0 + (-x).exp())), sigmoid_prime(x)); + assert_eq!(autodiff!(x: | x: f64 | 1.0 / (1.0 + (-x).exp())), sigmoid_prime(x)); + assert_eq!(autodiff!(x: fn sigmoid(x: f64) -> f64 { 1_f64 / (1_f64 + (-x).exp()) }), sigmoid_prime(x)); +} + +// #[ignore = "Currently, support for function calls is not fully implemented"] +#[test] +fn test_function_call() { + let x = 2_f64; + assert_eq!(autodiff!(x: sigmoid::(x)), sigmoid_prime(x)); +} + +#[ignore = "Custom trait methods are not yet supported"] +#[test] +fn test_method() { + let (x, y) = (1_f64, 2_f64); + assert_eq!(autodiff!(x: x.mul(y)), 2.0); + + assert_eq!(autodiff!(x: x.sigmoid()), sigmoid_prime(x)); +} \ No newline at end of file diff --git a/core/Cargo.toml b/core/Cargo.toml index 2b625b35..64a310a4 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -26,9 +26,10 @@ test = true [dependencies] anyhow.workspace = true -daggy = { features = ["serde-1"], version = "0.8" } +# daggy = { features = ["serde-1"], version = "0.8" } lazy_static = "1" num = "0.4" +petgraph = { features = ["serde-1"], version = "0.6" } serde.workspace = true serde_json.workspace = true strum.workspace = true diff --git a/core/src/cmp/id/id.rs b/core/src/cmp/id/id.rs index a40d1385..3076213c 100644 --- 
a/core/src/cmp/id/id.rs +++ b/core/src/cmp/id/id.rs @@ -3,25 +3,25 @@ Contrib: FL03 */ use super::AtomicId; -use daggy::NodeIndex; +use petgraph::prelude::NodeIndex; use serde::{Deserialize, Serialize}; #[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] pub struct Id { - id: usize, + id: AtomicId, index: NodeIndex, } impl Id { pub fn new(index: NodeIndex) -> Self { Self { - id: *AtomicId::new(), + id: AtomicId::new(), index, } } pub fn id(&self) -> usize { - self.id + *self.id } pub fn index(&self) -> NodeIndex { diff --git a/core/src/cmp/mod.rs b/core/src/cmp/mod.rs index 0e510389..bd9f4b3e 100644 --- a/core/src/cmp/mod.rs +++ b/core/src/cmp/mod.rs @@ -14,7 +14,7 @@ pub(crate) mod variables; pub mod id; -use daggy::NodeIndex; +use petgraph::prelude::NodeIndex; pub trait NodeConfig { type Eval; diff --git a/core/src/cmp/operators.rs b/core/src/cmp/operators.rs index 6b722e2c..2af9f5cb 100644 --- a/core/src/cmp/operators.rs +++ b/core/src/cmp/operators.rs @@ -2,10 +2,10 @@ Appellation: operators Contrib: FL03 */ -use daggy::NodeIndex; +use super::id::Id; pub struct Operator { - inputs: Vec, + inputs: Vec, name: String, } @@ -22,7 +22,7 @@ impl Operator { self } - pub fn inputs(&self) -> &[NodeIndex] { + pub fn inputs(&self) -> &[Id] { &self.inputs } diff --git a/core/src/errors/error.rs b/core/src/errors/error.rs index e6127f38..44457f8c 100644 --- a/core/src/errors/error.rs +++ b/core/src/errors/error.rs @@ -40,23 +40,23 @@ impl From> for Error { } } -impl From> for Error { - fn from(err: daggy::WouldCycle) -> Self { - Self::new(ErrorKind::Graph, err.to_string()) +impl From> for Error { + fn from(err: std::sync::TryLockError) -> Self { + Self::new(ErrorKind::Sync, err.to_string()) } } -impl From> for Error +impl From> for Error where E: Copy + std::fmt::Debug, { - fn from(err: daggy::petgraph::algo::Cycle) -> Self { - Self::new(ErrorKind::Graph, format!("{:?}", err.node_id())) + fn from(err: petgraph::algo::Cycle) -> Self { + Self::new(ErrorKind::Graph, format!("Cycle: {:?}", err.node_id())) } } -impl From> for Error { - fn from(err: std::sync::TryLockError) -> Self { - Self::new(ErrorKind::Sync, err.to_string()) +impl From for Error { + fn from(err: petgraph::algo::NegativeCycle) -> Self { + Self::new(ErrorKind::Graph, "Negative Cycle detected") } } diff --git a/core/src/exp/dynamic/graph.rs b/core/src/exp/dynamic/graph.rs index d94dea89..94ec0933 100644 --- a/core/src/exp/dynamic/graph.rs +++ b/core/src/exp/dynamic/graph.rs @@ -2,30 +2,33 @@ Appellation: graph Contrib: FL03 */ +use super::{DcgEdge, Node}; use crate::prelude::Result; use crate::stores::{GradientStore, Store}; -use daggy::petgraph::algo::toposort; -use daggy::{Dag, NodeIndex}; +use petgraph::algo::toposort; +use petgraph::prelude::{DiGraph, NodeIndex}; pub struct Dcg { - graph: Dag, + graph: DiGraph, DcgEdge>, } impl Dcg { pub fn new() -> Self { - Self { graph: Dag::new() } + Self { + graph: DiGraph::new(), + } } pub fn clear(&mut self) { self.graph.clear(); } - pub fn get(&self, index: NodeIndex) -> Option<&T> { + pub fn get(&self, index: NodeIndex) -> Option<&Node> { self.graph.node_weight(index) } pub fn variable(&mut self, value: T) -> NodeIndex { - self.graph.add_node(value) + self.graph.add_node(Node::new().with_value(value)) } } diff --git a/core/src/exp/dynamic/mod.rs b/core/src/exp/dynamic/mod.rs index c1660763..bb72bd82 100644 --- a/core/src/exp/dynamic/mod.rs +++ b/core/src/exp/dynamic/mod.rs @@ -26,7 +26,7 @@ mod tests { // let e = dag.add(c, a).unwrap(); - 
assert_eq!(*dag.get(a).unwrap(), 1.0); + assert_eq!(*dag.get(a).unwrap().value().unwrap(), 1.0); // assert_eq!(*dag.get(e).unwrap(), 2.0); } } diff --git a/core/src/exp/dynamic/node.rs b/core/src/exp/dynamic/node.rs index bb05067f..d056a72a 100644 --- a/core/src/exp/dynamic/node.rs +++ b/core/src/exp/dynamic/node.rs @@ -7,8 +7,10 @@ //! //! The edges connecting to any given node are considered to be inputs and help to determine the flow of information use crate::prelude::Ops; -use daggy::NodeIndex; +use petgraph::prelude::NodeIndex; +use serde::{Deserialize, Serialize}; +#[derive(Clone, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] pub struct Node { inputs: Vec, operation: Option, diff --git a/core/src/graphs/mod.rs b/core/src/graphs/mod.rs index b77d7e18..20f9d124 100644 --- a/core/src/graphs/mod.rs +++ b/core/src/graphs/mod.rs @@ -9,88 +9,5 @@ //! //! In a dynamic computational graph (DCG), the graph considers the nodes to be tensors and the edges to be operations. //! -pub use self::{edge::*, graph::*, node::*}; -pub(crate) mod edge; -pub(crate) mod graph; -pub(crate) mod node; - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_dag() { - let mut dag = Graph::new(); - let x = dag.variable(1_f64); - let y = dag.variable(2_f64); - // f(x, y) = x + y - let c = dag.add(x, y).unwrap(); - // verify the value of c to be the sum of x and y - assert_eq!(*dag.get_value(c).unwrap(), 3.0); - // f(x, y) = y * (x + y) - let d = dag.mul(c, y).unwrap(); - // verify the value of d to be the product of c and y - assert_eq!(*dag.get_value(d).unwrap(), 6.0); - - let gc = dag.gradient_at(c).unwrap(); - - assert_eq!(gc[&x], 1.0); - assert_eq!(gc[&y], 1.0); - - let gd = dag.backward().unwrap(); - - assert_eq!(gd[&x], 2.0); - assert_eq!(gd[&y], 5.0); - } - - #[test] - fn test_backward() { - let mut dag = Graph::new(); - let x = dag.variable(1_f64); - let y = dag.variable(2_f64); - - let c = dag.sub(x, y).unwrap(); - - let d = dag.mul(c, y).unwrap(); - - assert_eq!(*dag.get_value(c).unwrap(), -1.0); - assert_eq!(*dag.get_value(d).unwrap(), -2.0); - - let gc = dag.gradient_at(c).unwrap(); - - assert_eq!(gc[&x], 1.0); - assert_eq!(gc[&y], -1.0); - - let gd = dag.backward().unwrap(); - - assert_eq!(gd[&x], 2.0); - assert_eq!(gd[&y], -3.0); - } - - #[ignore = "Not yet implemented"] - #[test] - fn test_division() { - let mut dag = Graph::new(); - let one = dag.constant(1_f64); - let x = dag.variable(1_f64); - let y = dag.variable(2_f64); - - let c = dag.add(x, y).unwrap(); - - let d = dag.div(one, c).unwrap(); - - assert_eq!(*dag.get_value(c).unwrap(), 3.0); - assert_eq!(*dag.get_value(d).unwrap(), 1.0 / 3.0); - - let gc = dag.gradient_at(c).unwrap(); - - assert_eq!(gc[&x], 1.0); - assert_eq!(gc[&y], 1.0); - - let gd = dag.backward().unwrap(); - - assert_eq!(gd[&x], -1.0); - assert_eq!(gd[&y], -1.0); - } -} +pub mod scg; diff --git a/core/src/graphs/edge.rs b/core/src/graphs/scg/edge.rs similarity index 100% rename from core/src/graphs/edge.rs rename to core/src/graphs/scg/edge.rs diff --git a/core/src/graphs/graph.rs b/core/src/graphs/scg/graph.rs similarity index 93% rename from core/src/graphs/graph.rs rename to core/src/graphs/scg/graph.rs index 08e2cba9..137210b7 100644 --- a/core/src/graphs/graph.rs +++ b/core/src/graphs/scg/graph.rs @@ -4,21 +4,21 @@ */ use super::Node; use crate::prelude::{BinaryOp, BinaryOperation, Ops, Result}; -use daggy::petgraph::algo::toposort; -use daggy::{Dag, NodeIndex}; use num::traits::{NumAssign, NumOps, Signed}; +use 
petgraph::algo::toposort; +use petgraph::prelude::{DiGraph, NodeIndex}; use std::collections::HashMap; #[derive(Clone, Debug)] -pub struct Graph { - graph: Dag, +pub struct Scg { + graph: DiGraph, vals: HashMap, } -impl Graph { +impl Scg { pub fn new() -> Self { Self { - graph: Dag::new(), + graph: DiGraph::new(), vals: HashMap::new(), } } @@ -54,7 +54,7 @@ impl Graph { let v = self.graph.add_node(node.clone()); let edges = node.inputs().iter().map(|i| (*i, v)); let _val = self.vals.insert(v, result.unwrap_or_default()); - self.graph.extend_with_edges(edges)?; + self.graph.extend_with_edges(edges); Ok(v) } @@ -65,7 +65,7 @@ impl Graph { } } -impl Graph +impl Scg where T: Copy + Default + NumAssign + NumOps + Signed + 'static, { @@ -104,12 +104,12 @@ where -grad * out / (val * val) } } - BinaryOp::Mul => { + BinaryOp::Mul(_) => { let out = self.vals[&i]; let val = self.vals[input]; grad * out / val } - BinaryOp::Sub => { + BinaryOp::Sub(_) => { if j % 2 == 0 { grad } else { @@ -133,7 +133,7 @@ where } } -impl Graph +impl Scg where T: Copy + Default + NumOps + PartialOrd, { diff --git a/core/src/graphs/scg/mod.rs b/core/src/graphs/scg/mod.rs new file mode 100644 index 00000000..58ccf33c --- /dev/null +++ b/core/src/graphs/scg/mod.rs @@ -0,0 +1,92 @@ +/* + Appellation: scg + Contrib: FL03 +*/ +//! # Static Computational Graph +//! +//! +pub use self::{edge::*, graph::*, node::*}; + +pub(crate) mod edge; +pub(crate) mod graph; +pub(crate) mod node; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dag() { + let mut dag = Scg::new(); + let x = dag.variable(1_f64); + let y = dag.variable(2_f64); + // f(x, y) = x + y + let c = dag.add(x, y).unwrap(); + // verify the value of c to be the sum of x and y + assert_eq!(*dag.get_value(c).unwrap(), 3.0); + // f(x, y) = y * (x + y) + let d = dag.mul(c, y).unwrap(); + // verify the value of d to be the product of c and y + assert_eq!(*dag.get_value(d).unwrap(), 6.0); + + let gc = dag.gradient_at(c).unwrap(); + + assert_eq!(gc[&x], 1.0); + assert_eq!(gc[&y], 1.0); + + let gd = dag.backward().unwrap(); + + assert_eq!(gd[&x], 2.0); + assert_eq!(gd[&y], 5.0); + } + + #[test] + fn test_backward() { + let mut dag = Scg::new(); + let x = dag.variable(1_f64); + let y = dag.variable(2_f64); + + let c = dag.sub(x, y).unwrap(); + + let d = dag.mul(c, y).unwrap(); + + assert_eq!(*dag.get_value(c).unwrap(), -1.0); + assert_eq!(*dag.get_value(d).unwrap(), -2.0); + + let gc = dag.gradient_at(c).unwrap(); + + assert_eq!(gc[&x], 1.0); + assert_eq!(gc[&y], -1.0); + + let gd = dag.backward().unwrap(); + + assert_eq!(gd[&x], 2.0); + assert_eq!(gd[&y], -3.0); + } + + #[ignore = "Not yet implemented"] + #[test] + fn test_division() { + let mut dag = Scg::new(); + let one = dag.constant(1_f64); + let x = dag.variable(1_f64); + let y = dag.variable(2_f64); + + let c = dag.add(x, y).unwrap(); + + let d = dag.div(one, c).unwrap(); + + assert_eq!(*dag.get_value(c).unwrap(), 3.0); + assert_eq!(*dag.get_value(d).unwrap(), 1.0 / 3.0); + + let gc = dag.gradient_at(c).unwrap(); + + assert_eq!(gc[&x], 1.0); + assert_eq!(gc[&y], 1.0); + + let gd = dag.backward().unwrap(); + + assert_eq!(gd[&x], -1.0); + assert_eq!(gd[&y], -1.0); + } +} diff --git a/core/src/graphs/node.rs b/core/src/graphs/scg/node.rs similarity index 97% rename from core/src/graphs/node.rs rename to core/src/graphs/scg/node.rs index 432e202f..4949bf23 100644 --- a/core/src/graphs/node.rs +++ b/core/src/graphs/scg/node.rs @@ -8,7 +8,7 @@ //! 
The edges connecting to any given node are considered to be inputs and help to determine the flow of information use crate::cmp::id::AtomicId; use crate::ops::Ops; -use daggy::NodeIndex; +use petgraph::prelude::NodeIndex; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] diff --git a/core/src/lib.rs b/core/src/lib.rs index ca515b97..b18adb9c 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -26,10 +26,11 @@ pub mod stores; pub mod prelude { pub use crate::primitives::*; // pub use crate::specs::*; - // pub use crate::utils::*; + pub use crate::utils::*; pub use crate::cmp::*; pub use crate::errors::*; + pub use crate::graphs::scg::Scg; pub use crate::graphs::*; pub use crate::ops::*; pub use crate::stores::*; diff --git a/core/src/ops/arithmetic.rs b/core/src/ops/arithmetic.rs index 60fb730a..345b65a7 100644 --- a/core/src/ops/arithmetic.rs +++ b/core/src/ops/arithmetic.rs @@ -5,6 +5,11 @@ use serde::{Deserialize, Serialize}; use std::ops::{Add, Div, Mul, Sub}; +pub trait Trig { + fn sin(self) -> Self; + fn cos(self) -> Self; + fn tan(self) -> Self; +} #[derive( Clone, Copy, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize, )] diff --git a/core/src/ops/gradient.rs b/core/src/ops/gradient.rs new file mode 100644 index 00000000..6cd91eab --- /dev/null +++ b/core/src/ops/gradient.rs @@ -0,0 +1,31 @@ +/* + Appellation: gradient + Contrib: FL03 +*/ + +pub trait Differentiable { + type Derivative; + + fn diff(&self, args: T) -> Self::Derivative; +} + +pub trait Gradient { + type Gradient; + + fn grad(&self, args: T) -> Self::Gradient; +} + +// Mathematically, the gradient of a function is a vector of partial derivatives. + +pub struct Derivative { + pub wrt: T, + pub f: Box T>, +} + +impl Differentiable for Derivative { + type Derivative = T; + + fn diff(&self, args: T) -> Self::Derivative { + (self.f)(args) + } +} diff --git a/core/src/ops/kinds.rs b/core/src/ops/kinds.rs index e8de526d..27a82623 100644 --- a/core/src/ops/kinds.rs +++ b/core/src/ops/kinds.rs @@ -74,8 +74,8 @@ pub enum BinaryOp { Div(Division), Maximum, Minimum, - Mul, - Sub, + Mul(Multiplication), + Sub(Subtraction), } impl BinaryOp { @@ -96,16 +96,16 @@ impl BinaryOp { } pub fn mul() -> Self { - Self::Mul + Self::Mul(Multiplication) } pub fn sub() -> Self { - Self::Sub + Self::Sub(Subtraction) } pub fn is_commutative(&self) -> bool { match self { - Self::Add(_) | Self::Mul => true, + Self::Add(_) | Self::Mul(_) => true, _ => false, } } @@ -135,8 +135,8 @@ where rhs } } - Self::Mul => lhs * rhs, - Self::Sub => lhs - rhs, + Self::Mul(_) => lhs * rhs, + Self::Sub(_) => lhs - rhs, } } } diff --git a/core/src/ops/mod.rs b/core/src/ops/mod.rs index 856807e1..317671f6 100644 --- a/core/src/ops/mod.rs +++ b/core/src/ops/mod.rs @@ -5,13 +5,20 @@ //! # Operations //! //! 
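+//! Operation traits shared across the crate (`Expressive`, `Backward`, `Compute`,
+//! `Evaluate`, `BinaryOperation`) together with the `arithmetic`, `gradient`,
+//! `kinds`, and `operator` submodules re-exported below.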
-pub use self::{arithmetic::*, kinds::*}; +pub use self::{arithmetic::*, gradient::*, kinds::*, operator::*}; pub(crate) mod arithmetic; +pub(crate) mod gradient; pub(crate) mod kinds; +pub(crate) mod operator; use crate::prelude::Result; -use std::marker::Tuple; + +pub trait Expressive { + type Graph; + + fn expand(&self) -> Self::Graph; +} pub trait Backward { type Store; @@ -25,13 +32,6 @@ pub trait Compute { fn compute(&self, args: T) -> Self::Output; } -pub trait Differentiable { - type Derivative; - - fn eval(&self, at: T) -> T; - fn derivative(&self, at: T) -> Self::Derivative; -} - pub trait Evaluate { type Output; @@ -46,25 +46,6 @@ impl Evaluate for f64 { } } -pub trait Gradient { - type Gradient; - - fn grad(&self, args: T) -> Self::Gradient; -} - -pub trait Operand -where - Args: Tuple, -{ - type Output; - - fn name(&self) -> &str; - - fn eval(&self, args: Args) -> Self::Output; - - fn grad(&self, args: Self::Output) -> Option; -} - pub trait BinaryOperation { type Output; diff --git a/core/src/ops/operator.rs b/core/src/ops/operator.rs new file mode 100644 index 00000000..fb7052eb --- /dev/null +++ b/core/src/ops/operator.rs @@ -0,0 +1,18 @@ +/* + Appellation: operator + Contrib: FL03 +*/ +use std::marker::Tuple; + +pub trait Operand +where + Args: Tuple, +{ + type Output; + + fn name(&self) -> &str; + + fn eval(&self, args: Args) -> Self::Output; + + fn grad(&self, args: Self::Output) -> Option; +} diff --git a/core/src/stores/gradient.rs b/core/src/stores/gradient.rs index 830380ee..a80c8451 100644 --- a/core/src/stores/gradient.rs +++ b/core/src/stores/gradient.rs @@ -3,7 +3,7 @@ Contrib: FL03 */ use super::Store; -use daggy::NodeIndex; +use petgraph::prelude::NodeIndex; use std::any::Any; use std::collections::BTreeMap; diff --git a/core/src/stores/mod.rs b/core/src/stores/mod.rs index 62da5ef6..8d0b4582 100644 --- a/core/src/stores/mod.rs +++ b/core/src/stores/mod.rs @@ -2,9 +2,12 @@ Appellation: stores Contrib: FL03 */ -pub use self::gradient::*; +pub use self::{gradient::*, stack::*}; pub(crate) mod gradient; +pub(crate) mod stack; + +use std::collections::{BTreeMap, HashMap}; pub trait Store { fn get(&self, key: &K) -> Option<&V>; @@ -16,38 +19,33 @@ pub trait Store { fn remove(&mut self, key: &K) -> Option; } -// impl Store for BTreeMap where K: Ord { -// fn get(&self, key: &K) -> Option<&V> { -// BTreeMap::get(self, &key) -// } - -// fn get_mut(&mut self, key: &K) -> Option<&mut V> { -// BTreeMap::get_mut(self, &key) -// } - -// fn insert(&mut self, key: K, value: V) { -// BTreeMap::insert(self, key, value); -// } - -// fn remove(&mut self, key: &K) -> Option { -// BTreeMap::remove(self, &key) -// } -// } - -// impl Store for HashMap where K: Eq + std::hash::Hash { -// fn get(&self, key: &K) -> Option<&V> { -// HashMap::get(self, &key) -// } - -// fn get_mut(&mut self, key: &K) -> Option<&mut V> { -// HashMap::get_mut(self, &key) -// } - -// fn insert(&mut self, key: K, value: V) { -// HashMap::insert(self, key, value); -// } - -// fn remove(&mut self, key: &K) -> Option { -// HashMap::remove(self, &key) -// } -// } +pub trait OrInsert { + fn or_insert(&mut self, key: K, value: V) -> &mut V; +} + +macro_rules! 
impl_store { + ($t:ty, where $($preds:tt)* ) => { + + impl Store for $t where $($preds)* { + fn get(&self, key: &K) -> Option<&V> { + <$t>::get(self, &key) + } + + fn get_mut(&mut self, key: &K) -> Option<&mut V> { + <$t>::get_mut(self, &key) + } + + fn insert(&mut self, key: K, value: V) -> Option { + <$t>::insert(self, key, value) + } + + fn remove(&mut self, key: &K) -> Option { + <$t>::remove(self, &key) + } + } + + }; +} + +impl_store!(BTreeMap, where K: Ord); +impl_store!(HashMap, where K: Eq + std::hash::Hash); diff --git a/core/src/stores/stack.rs b/core/src/stores/stack.rs new file mode 100644 index 00000000..4a4a98e7 --- /dev/null +++ b/core/src/stores/stack.rs @@ -0,0 +1,8 @@ +/* + Appellation: stack + Contrib: FL03 +*/ + +pub struct Stack { + pub(crate) store: Vec<(K, V)>, +} diff --git a/core/src/utils.rs b/core/src/utils.rs index 752dabaf..21054947 100644 --- a/core/src/utils.rs +++ b/core/src/utils.rs @@ -2,3 +2,11 @@ Appellation: utils Contrib: FL03 */ +use num::Float; + +pub fn sigmoid(x: T) -> T +where + T: Float, +{ + (T::one() + x.neg().exp()).recip() +} \ No newline at end of file diff --git a/derive/examples/params.rs b/derive/examples/params.rs new file mode 100644 index 00000000..555aa731 --- /dev/null +++ b/derive/examples/params.rs @@ -0,0 +1,18 @@ +/* + Appellation: params + Contrib: FL03 +*/ +extern crate acme_derive as acme; + +use acme::Params; + +fn main() -> Result<(), Box> { + let params = LinearParams { weight: 1.0 }; + let wk = LinearParamsKey::Weight; + Ok(()) +} + +#[derive(Params)] +pub struct LinearParams { + pub weight: f64, +} diff --git a/derive/src/lib.rs b/derive/src/lib.rs index a23a6753..b9f8b5a5 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -7,8 +7,81 @@ //! extern crate proc_macro; use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, Data, DeriveInput, Fields, Variant}; + +fn capitalize_first(s: &str) -> String { + s.chars() + .take(1) + .flat_map(|f| f.to_uppercase()) + .chain(s.chars().skip(1)) + .collect() +} #[proc_macro_derive(AnswerFn)] pub fn derive_answer_fn(_item: TokenStream) -> TokenStream { "fn answer() -> u32 { 42 }".parse().unwrap() } + +#[proc_macro_derive(HelperAttr, attributes(helper))] +pub fn derive_helper_attr(_item: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[proc_macro_derive(Params, attributes(param))] +pub fn params(input: TokenStream) -> TokenStream { + // Parse the input tokens into a syntax tree + let input = parse_macro_input!(input as DeriveInput); + + // Get the name of the struct + let struct_name = &input.ident; + let store_name = format_ident!("{}Key", struct_name); + + // Generate the parameter struct definition + let param_struct = match &input.data { + Data::Struct(s) => match &s.fields { + _ => {} + }, + _ => panic!("Only structs are supported"), + }; + + // Generate the parameter keys enum + let param_keys_enum = match &input.data { + Data::Struct(s) => match &s.fields { + Fields::Named(fields) => { + let field_names = fields.named.iter().map(|f| &f.ident); + let varaints = field_names.clone().map(|ident| { + let ident_str = ident.as_ref().unwrap().to_string(); + let ident_str = format_ident!("{}", capitalize_first(&ident_str)); + Variant { + attrs: vec![], + ident: ident_str, + fields: Fields::Unit, + discriminant: None, + } + }); + let varaints_str = varaints.clone().map(|v| v.ident); + + quote! 
{ + #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd,)] + pub enum #store_name { + #( + #varaints, + )* + } + } + } + _ => panic!("Only named fields are supported"), + }, + _ => panic!("Only structs are supported"), + }; + + // Combine the generated code + let generated_code = quote! { + // #param_struct + #param_keys_enum + }; + + // Return the generated code as a TokenStream + generated_code.into() +} diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 7f37b997..ed97cb9b 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -17,11 +17,15 @@ test = true [build-dependencies] [dependencies] -proc-macro2 = "1" +num = "0.4" +petgraph = { features = [], version = "0.6" } +proc-macro2 = { features = ["nightly", "span-locations"], version = "1" } quote = "1" -syn = { features = ["full"], version = "2" } + +syn = { features = ["extra-traits", "fold", "full"], version = "2" } [dev-dependencies] +approx = "0.5" [package.metadata.docs.rs] all-features = true diff --git a/macros/examples/sample.rs b/macros/examples/sample.rs index f1f989ad..f9745a4a 100644 --- a/macros/examples/sample.rs +++ b/macros/examples/sample.rs @@ -4,14 +4,10 @@ */ extern crate acme_macros as macros; -use macros::*; +use macros::show_streams; fn main() -> Result<(), Box> { foo(); - let x = 1.0; - let y = 2.0; - let z = partial!(y: x + y;); - println!("Partial Derivative: {:?}", z); Ok(()) } diff --git a/macros/src/ad/autodiff.rs b/macros/src/ad/autodiff.rs new file mode 100644 index 00000000..e05ace53 --- /dev/null +++ b/macros/src/ad/autodiff.rs @@ -0,0 +1,22 @@ +/* + Appellation: autodiff + Contrib: FL03 +*/ +use super::handle::expr::handle_expr; +use super::handle::item::handle_item; +use crate::ast::partials::*; +use proc_macro2::TokenStream; +use syn::Ident; + +pub fn generate_autodiff(partial: &PartialAst) -> TokenStream { + let PartialAst { expr, var, .. } = partial; + let grad = handle_input(&expr, &var); + grad +} + +fn handle_input(input: &PartialFn, var: &Ident) -> TokenStream { + match input { + PartialFn::Expr(inner) => handle_expr(&inner, var), + PartialFn::Item(inner) => handle_item(&inner.clone().into(), var), + } +} diff --git a/macros/src/ad/handle/block.rs b/macros/src/ad/handle/block.rs new file mode 100644 index 00000000..6d7fa3a0 --- /dev/null +++ b/macros/src/ad/handle/block.rs @@ -0,0 +1,18 @@ +/* + Appellation: block + Contrib: FL03 +*/ +use super::stmt::handle_stmt; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{Block, Ident}; + +pub fn handle_block(block: &Block, var: &Ident) -> TokenStream { + let Block { stmts, .. } = block; + let mut grad = quote! { 0.0 }; + for stmt in stmts { + let stmt = handle_stmt(stmt, var); + grad = quote! { #grad + #stmt }; + } + grad +} diff --git a/macros/src/ad/handle/expr/binary.rs b/macros/src/ad/handle/expr/binary.rs new file mode 100644 index 00000000..e7558676 --- /dev/null +++ b/macros/src/ad/handle/expr/binary.rs @@ -0,0 +1,149 @@ +/* + Appellation: binary + Contrib: FL03 +*/ +use super::handle_expr; +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::{BinOp, Expr, ExprBinary, ExprParen, Ident, Token}; + +pub fn handle_binary(expr: &ExprBinary, var: &Ident) -> TokenStream { + let ExprBinary { + left, right, op, .. + } = expr; + + // Compute the partial derivative of the left expression w.r.t. the variable + let dl = handle_expr(&left, var); + // Compute the partial derivative of the right expression w.r.t. 
the variable + let dr = handle_expr(&right, var); + + // Apply the chain rule based on the operator + match op { + // Differentiate addition + BinOp::Add(_) | BinOp::AddAssign(_) => { + quote! { + #dl + #dr + } + } + // Differentiate division using the quotient rule + BinOp::Div(_) | BinOp::DivAssign(_) => { + quote! { + (#right * #dl - #left * #dr) / (#right * #right) + } + } + // Differentiate multiplication + BinOp::Mul(_) | BinOp::MulAssign(_) => { + if let Expr::Paren(pl) = *left.clone() { + if let Expr::Paren(pr) = *right.clone() { + foil(&pl, &pr, var) + } else { + foil_expr(right, &pl, var) + } + } else if let Expr::Paren(pr) = *right.clone() { + if let Expr::Paren(pl) = *left.clone() { + foil(&pl, &pr, var) + } else { + foil_expr(left, &pr, var) + } + } else { + quote! { + #dl * #right + #dr * #left + } + } + } + // Differentiate subtraction + BinOp::Sub(_) | BinOp::SubAssign(_) => { + quote! { + #dl - #dr + } + } + _ => panic!("Unsupported operation!"), + } +} + +fn foil_expr(a: &Expr, b: &ExprParen, var: &Ident) -> TokenStream { + let ExprParen { expr, .. } = b; + let box_a = Box::new(a.clone()); + let star = Token![*](Span::call_site()); + if let Expr::Binary(inner) = *expr.clone() { + // Multiply the first term of the first expression by the first term of the second expression + let pleft = ExprBinary { + attrs: Vec::new(), + left: box_a.clone(), + op: BinOp::Mul(star), + right: inner.left, + }; + let pright = ExprBinary { + attrs: Vec::new(), + left: box_a, + op: BinOp::Mul(star), + right: inner.right, + }; + // Create a new expression with the two new terms; (a + b) * c = a * c + b * c + let new_expr = ExprBinary { + attrs: Vec::new(), + left: Box::new(pleft.into()), + op: inner.op, + right: Box::new(pright.into()), + }; + + // let _dl = handle_expr(&pleft.into(), var); + // let _dr = handle_expr(&pright.into(), var); + return handle_expr(&new_expr.into(), var); + } + panic!("FOILER") +} + +// (a + b) * (c + d) = a * c + a * d + b * c + b * d +// (a + b) * (c - d) = a * c - a * d + b * c - b * d +fn foil(a: &ExprParen, b: &ExprParen, var: &Ident) -> TokenStream { + let ExprParen { expr: expr_a, .. } = a; + let ExprParen { expr: expr_b, .. } = b; + let star = Token![*](Span::call_site()); + if let Expr::Binary(inner_a) = *expr_a.clone() { + if let Expr::Binary(inner_b) = *expr_b.clone() { + let al = ExprBinary { + attrs: Vec::new(), + left: inner_a.left.clone(), + op: BinOp::Mul(star.clone()), + right: inner_b.left.clone(), + }; + let ar = ExprBinary { + attrs: Vec::new(), + left: inner_a.left.clone(), + op: BinOp::Mul(star.clone()), + right: inner_b.right.clone(), + }; + let bl = ExprBinary { + attrs: Vec::new(), + left: inner_a.right.clone(), + op: BinOp::Mul(star.clone()), + right: inner_b.left.clone(), + }; + let br = ExprBinary { + attrs: Vec::new(), + left: inner_a.right.clone(), + op: BinOp::Mul(star.clone()), + right: inner_b.right.clone(), + }; + let pleft = ExprBinary { + attrs: Vec::new(), + left: Box::new(al.into()), + op: inner_a.op, + right: Box::new(ar.into()), + }; + let pright = ExprBinary { + attrs: Vec::new(), + left: Box::new(bl.into()), + op: inner_a.op, + right: Box::new(br.into()), + }; + let dl = handle_expr(&pleft.into(), var); + let dr = handle_expr(&pright.into(), var); + return quote! 
{ + #dl + #dr + } + } + } + panic!("FOILER") +} \ No newline at end of file diff --git a/macros/src/ad/handle/expr/method.rs b/macros/src/ad/handle/expr/method.rs new file mode 100644 index 00000000..d7edbbd0 --- /dev/null +++ b/macros/src/ad/handle/expr/method.rs @@ -0,0 +1,97 @@ +/* + Appellation: method + Contrib: FL03 +*/ +use super::handle_expr; +use crate::ad::ops::{Methods, UnaryMethod}; +use proc_macro2::TokenStream; +use quote::quote; +use syn::spanned::Spanned; +use std::str::FromStr; +use syn::{Expr, ExprCall, ExprMethodCall, Ident}; +use syn::ExprPath; + +pub fn handle_call(expr: &ExprCall, var: &Ident) -> TokenStream { + let ExprCall { args, func, .. } = expr; + let mut grad = quote! { 0.0 }; + for arg in args { + let arg = handle_expr(&arg, var); + grad = quote! { #grad + #arg }; + } + if let Expr::Path(path) = &**func { + println!("{:?}", expr.span().unwrap().source_file().path()); + if let Some(block) = expr.span().source_text() { + println!("********\n\n\t\tFunction\n{:?}\nArgs:\n{:?}\n{:?}\n\n********", func, args, &block); + } + } + // + let df = handle_expr(&func, var); + + + quote! { #df + #grad } +} + +pub fn handle_method(expr: &ExprMethodCall, var: &Ident) -> TokenStream { + let ExprMethodCall { + args, + method, + receiver, + .. + } = expr; + let method_name = method.clone().to_string(); + let dr = handle_expr(&receiver, var); + if let Ok(method) = Methods::from_str(&method_name) { + let dm = match method { + Methods::Unary(method) => handle_unary_method(&method, &receiver, var), + }; + + return quote! { #dm * #dr }; + } + let mut grad = quote! { 0.0 }; + for arg in args { + let da = handle_expr(&arg, var); + grad = quote! { #grad + #da }; + } + quote! { #dr + #grad } +} + +pub fn handle_unary_method(method: &UnaryMethod, recv: &Expr, _var: &Ident) -> TokenStream { + match method { + UnaryMethod::Abs => quote! { #recv / #recv.abs() }, + UnaryMethod::Cos => quote! { -#recv.sin() }, + UnaryMethod::Cosh => quote! { #recv.sinh() }, + UnaryMethod::Exp => { + quote! { + if #recv.is_sign_negative() { + -#recv.exp() + } else { + #recv.exp() + } + } + } + UnaryMethod::Inverse | UnaryMethod::Recip => quote! { -#recv.powi(-2) }, + UnaryMethod::Ln => quote! { #recv.recip() }, + UnaryMethod::Sin => quote! { #recv.cos() }, + UnaryMethod::Sinh => quote! { #recv.cosh() }, + UnaryMethod::Sqrt => quote! { (2.0 * #recv.sqrt()).recip() }, + UnaryMethod::Tan => quote! { #recv.cos().powi(2).recip() }, + UnaryMethod::Tanh => quote! { #recv.cosh().powi(2).recip() }, + } +} + +pub fn extract_block_logic(expr: &ExprCall) -> Option { + // Get the span of the function call expression + let span = expr.span(); + let source = span.clone().unwrap().source_file(); + + if let Expr::Path(inner) = &*expr.func { + let ExprPath { path, .. 
} = inner; + // Get the span of the last segment of the path + let span = path.segments.last().unwrap().ident.span(); + + + + } + + None +} diff --git a/macros/src/ad/handle/expr/mod.rs b/macros/src/ad/handle/expr/mod.rs new file mode 100644 index 00000000..019b784d --- /dev/null +++ b/macros/src/ad/handle/expr/mod.rs @@ -0,0 +1,56 @@ +/* + Appellation: expr + Contrib: FL03 +*/ +pub use self::{binary::*, method::*, unary::*}; + +pub(crate) mod binary; +pub(crate) mod method; +pub(crate) mod unary; + +use proc_macro2::TokenStream; +use quote::quote; +use syn::{Expr, Ident}; + +pub fn handle_expr(expr: &Expr, variable: &Ident) -> TokenStream { + match expr { + // Handle differentiable arrays + Expr::Array(inner) => { + let grad = inner.elems.iter().map(|e| handle_expr(e, variable)); + quote! { [#(#grad),*] } + } + // Handle differentiable binary operations + Expr::Binary(inner) => handle_binary(inner, variable), + // Handle differentiable function calls + Expr::Call(inner) => handle_call(inner, variable), + // Handle differentiable closures + Expr::Closure(inner) => handle_expr(&inner.body, variable), + // Differentiate constants + Expr::Const(_) => quote! { 0.0 }, + // Differentiate groups + Expr::Group(inner) => handle_expr(&inner.expr, variable), + // Differentiate literals + Expr::Lit(_) => quote! { 0.0 }, + // Differentiate method calls + Expr::MethodCall(inner) => handle_method(inner, variable), + // Differentiate parenthesized expressions + Expr::Paren(inner) => handle_expr(&inner.expr, variable), + // Differentiate variable expressions + Expr::Path(inner) => { + let syn::ExprPath { path, .. } = inner; + if path.segments.len() != 1 { + panic!("Unsupported path!"); + } + if path.segments[0].ident == *variable { + quote! { 1.0 } + } else { + quote! { 0.0 } + } + } + Expr::Reference(inner) => handle_expr(&inner.expr, variable), + // Differentiate unary expressions + Expr::Unary(inner) => handle_unary(inner, variable), + // Differentiate other expressions + _ => panic!("Unsupported expression!"), + } +} diff --git a/macros/src/ad/handle/expr/unary.rs b/macros/src/ad/handle/expr/unary.rs new file mode 100644 index 00000000..39ecd86f --- /dev/null +++ b/macros/src/ad/handle/expr/unary.rs @@ -0,0 +1,18 @@ +/* + Appellation: unary + Contrib: FL03 +*/ +use super::handle_expr; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{ExprUnary, Ident, UnOp}; + +pub fn handle_unary(expr: &ExprUnary, variable: &Ident) -> TokenStream { + let dv = handle_expr(&expr.expr, variable); + match expr.op { + UnOp::Neg(_) => { + quote! { -#dv } + } + _ => panic!("Unsupported unary operator!"), + } +} diff --git a/macros/src/ad/handle/item.rs b/macros/src/ad/handle/item.rs new file mode 100644 index 00000000..3c796619 --- /dev/null +++ b/macros/src/ad/handle/item.rs @@ -0,0 +1,18 @@ +/* + Appellation: item + Contrib: FL03 +*/ +use super::block::handle_block; +use proc_macro2::TokenStream; +use syn::{Ident, Item, ItemFn}; + +pub fn handle_item(item: &Item, var: &Ident) -> TokenStream { + match item { + Item::Fn(inner) => { + let ItemFn { block, .. } = inner; + handle_block(&block, var) + } + + _ => panic!("Unsupported item!"), + } +} diff --git a/macros/src/ad/handle/mod.rs b/macros/src/ad/handle/mod.rs new file mode 100644 index 00000000..1bf24c2a --- /dev/null +++ b/macros/src/ad/handle/mod.rs @@ -0,0 +1,11 @@ +/* + Appellation: handle + Contrib: FL03 +*/ +//! # Autodifferentiation (AD) +//! 
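+//! Each submodule below handles one kind of `syn` node when building the
+//! derivative `TokenStream`: `block` folds a block's statements, `expr`
+//! dispatches on expression variants (binary, call, closure, method, path,
+//! unary, ...), `item` unwraps `fn` items into their blocks, and `stmt`
+//! routes statements back to the expression and item handlers.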
+ +pub mod block; +pub mod expr; +pub mod item; +pub mod stmt; diff --git a/macros/src/ad/handle/stmt.rs b/macros/src/ad/handle/stmt.rs new file mode 100644 index 00000000..390c42bf --- /dev/null +++ b/macros/src/ad/handle/stmt.rs @@ -0,0 +1,23 @@ +/* + Appellation: item + Contrib: FL03 +*/ +use super::expr::handle_expr; +use super::item::handle_item; +use proc_macro2::TokenStream; +use syn::{Ident, Local, Stmt}; + +pub fn handle_stmt(stmt: &Stmt, var: &Ident) -> TokenStream { + match stmt { + Stmt::Local(local) => { + let Local { init, .. } = local; + if let Some(tmp) = init { + return handle_expr(&tmp.expr, var); + } + panic!("Local variable not initialized!") + } + Stmt::Item(item) => handle_item(item, var), + Stmt::Expr(expr, _) => handle_expr(expr, var), + _ => panic!("Unsupported statement!"), + } +} diff --git a/macros/src/ad/mod.rs b/macros/src/ad/mod.rs new file mode 100644 index 00000000..c547d018 --- /dev/null +++ b/macros/src/ad/mod.rs @@ -0,0 +1,12 @@ +/* + Appellation: ad + Contrib: FL03 +*/ +//! # Autodifferentiation (AD) +//! +pub use self::autodiff::generate_autodiff; + +pub(crate) mod autodiff; + +pub mod handle; +pub mod ops; diff --git a/macros/src/ad/ops/mod.rs b/macros/src/ad/ops/mod.rs new file mode 100644 index 00000000..a3265eb6 --- /dev/null +++ b/macros/src/ad/ops/mod.rs @@ -0,0 +1,28 @@ +/* + Appellation: ops + Contrib: FL03 +*/ +//! # Operations +//! +pub use self::unary::*; + +pub(crate) mod unary; + +use std::str::FromStr; + +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Methods { + Unary(UnaryMethod), +} + +impl FromStr for Methods { + type Err = Box; + + fn from_str(s: &str) -> std::result::Result { + if let Ok(method) = UnaryMethod::from_str(s) { + return Ok(Methods::Unary(method)); + } + + Err("Method not found".into()) + } +} diff --git a/macros/src/ad/ops/unary.rs b/macros/src/ad/ops/unary.rs new file mode 100644 index 00000000..e900621b --- /dev/null +++ b/macros/src/ad/ops/unary.rs @@ -0,0 +1,99 @@ +use crate::kw; +use proc_macro2::Span; +use std::str::FromStr; +use syn::parse::{Parse, ParseStream, Result}; + +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum UnaryMethod { + Abs, + Cos, + Cosh, + Exp, + Inverse, + Ln, + Recip, + Sin, + Sinh, + Sqrt, + Tan, + Tanh, +} + +impl FromStr for UnaryMethod { + type Err = Box; + + fn from_str(s: &str) -> std::result::Result { + match s { + "abs" => Ok(UnaryMethod::Abs), + "cos" | "cosine" => Ok(UnaryMethod::Cos), + "cosh" => Ok(UnaryMethod::Cosh), + "exp" => Ok(UnaryMethod::Exp), + "inv" | "inverse" => Ok(UnaryMethod::Inverse), + "ln" => Ok(UnaryMethod::Ln), + "recip" => Ok(UnaryMethod::Recip), + "sin" | "sine" => Ok(UnaryMethod::Sin), + "sinh" => Ok(UnaryMethod::Sinh), + "sqrt" | "square_root" => Ok(UnaryMethod::Sqrt), + "tan" | "tangent" => Ok(UnaryMethod::Tan), + "tanh" => Ok(UnaryMethod::Tanh), + _ => Err("Method not found".into()), + } + } +} + +impl Parse for UnaryMethod { + fn parse(input: ParseStream) -> Result { + if input.peek(syn::Token![.]) { + if input.peek2(syn::Ident) { + let method = input.parse::()?; + if let Ok(method) = UnaryMethod::from_str(method.to_string().as_str()) { + return Ok(method); + } + } + } + Err(input.error("Expected a method call")) + } +} + +pub enum UnaryOps { + Cosine(kw::cos), + Exp(kw::e), + Ln(kw::ln), + Sine(kw::sin), + Tan(kw::tan), + Std(syn::UnOp), +} + +impl Parse for UnaryOps { + fn parse(input: ParseStream) -> Result { + if input.peek(syn::Token![.]) { + if input.peek2(syn::Ident) { + let method = 
input.parse::()?; + let span = Span::call_site(); + match method.to_string().as_str() { + "cos" => return Ok(UnaryOps::Cosine(kw::cos(span))), + "exp" => return Ok(UnaryOps::Exp(kw::e(span))), + "ln" => return Ok(UnaryOps::Ln(kw::ln(span))), + "sin" => return Ok(UnaryOps::Sine(kw::sin(span))), + "tan" => return Ok(UnaryOps::Tan(kw::tan(span))), + _ => return Err(input.error("Method not found")), + } + } + if input.peek2(kw::cos) { + input.parse::().map(UnaryOps::Cosine) + } else if input.peek2(kw::sin) { + input.parse::().map(UnaryOps::Sine) + } else if input.peek2(kw::tan) { + input.parse::().map(UnaryOps::Tan) + } else if input.peek2(kw::ln) { + input.parse::().map(UnaryOps::Ln) + } else if input.peek2(kw::e) { + input.parse::().map(UnaryOps::Exp) + } else { + Err(input.error("Expected a method call")) + } + } else { + input.parse::().map(UnaryOps::Std) + } + } +} diff --git a/macros/src/ast/mod.rs b/macros/src/ast/mod.rs new file mode 100644 index 00000000..7da61062 --- /dev/null +++ b/macros/src/ast/mod.rs @@ -0,0 +1,6 @@ +/* + Appellation: ast + Contrib: FL03 +*/ + +pub mod partials; diff --git a/macros/src/ast/partials.rs b/macros/src/ast/partials.rs new file mode 100644 index 00000000..5146ddad --- /dev/null +++ b/macros/src/ast/partials.rs @@ -0,0 +1,82 @@ +/* + Appellation: partials + Contrib: FL03 +*/ +use syn::parse::{Parse, ParseStream, Result}; +use syn::punctuated::Punctuated; +use syn::{Attribute, Block, Expr, Ident, ItemFn, Signature, Token, Type, Visibility}; + +pub struct Partial { + pub expr: Expr, + pub var: Ident, +} + +impl Parse for Partial { + fn parse(input: ParseStream) -> Result { + let variable = input.parse()?; + input.parse::()?; + let expr = input.parse()?; + Ok(Partial { + expr, + var: variable, + }) + } +} + +pub struct Partials { + pub expr: Expr, + pub split: Token![:], + pub vars: Punctuated, +} + +impl Parse for Partials { + fn parse(input: ParseStream) -> Result { + let vars = input.parse_terminated(Type::parse, Token![,])?; + let split = input.parse::()?; + let expr = input.parse()?; + Ok(Self { expr, split, vars }) + } +} + +pub struct StructuredPartial { + +} + +pub struct PartialFnCall { + pub attrs: Vec, + pub body: Box, + pub sig: Signature, + pub vis: Visibility, +} + +pub enum PartialFn { + Expr(Expr), + Item(ItemFn), +} + +impl Parse for PartialFn { + fn parse(input: ParseStream) -> Result { + if let Ok(item) = input.parse() { + Ok(Self::Item(item)) + } else if let Ok(expr) = input.parse() { + Ok(Self::Expr(expr)) + } else { + Err(input.error("Expected a function call or method call")) + } + } +} + +pub struct PartialAst { + pub expr: PartialFn, + pub split: Token![:], + pub var: Ident, +} + +impl Parse for PartialAst { + fn parse(input: ParseStream) -> Result { + let var = input.parse()?; + let split = input.parse::()?; + let expr = input.parse()?; + Ok(Self { expr, split, var }) + } +} diff --git a/macros/src/cmp/graph.rs b/macros/src/cmp/graph.rs new file mode 100644 index 00000000..af2273d0 --- /dev/null +++ b/macros/src/cmp/graph.rs @@ -0,0 +1,10 @@ +/* + Appellation: graph + Contrib: FL03 +*/ +use syn::Expr; + +pub struct Node { + id: usize, + expr: Box, +} diff --git a/macros/src/cmp/mod.rs b/macros/src/cmp/mod.rs index d4fe4c08..eb8caa83 100644 --- a/macros/src/cmp/mod.rs +++ b/macros/src/cmp/mod.rs @@ -1,3 +1,8 @@ -pub use self::partials::*; +/* + Appellation: cmp + Contrib: FL03 +*/ +pub use self::{graph::*, store::*}; -pub(crate) mod partials; +pub(crate) mod graph; +pub(crate) mod store; diff --git a/macros/src/cmp/partials.rs 
b/macros/src/cmp/partials.rs deleted file mode 100644 index f89f2691..00000000 --- a/macros/src/cmp/partials.rs +++ /dev/null @@ -1,21 +0,0 @@ -/* - Appellation: partials - Contrib: FL03 -*/ -use syn::parse::{Parse, ParseStream, Result}; -use syn::{Expr, Ident, Token}; - -pub struct PartialDerivative { - pub expr: Expr, - pub variable: Ident, -} - -impl Parse for PartialDerivative { - fn parse(input: ParseStream) -> Result { - let variable = input.parse()?; - input.parse::()?; - let expr = input.parse()?; - input.parse::()?; - Ok(PartialDerivative { expr, variable }) - } -} diff --git a/macros/src/cmp/store.rs b/macros/src/cmp/store.rs new file mode 100644 index 00000000..bb0293a5 --- /dev/null +++ b/macros/src/cmp/store.rs @@ -0,0 +1,117 @@ +/* + Appellation: store + Contrib: FL03 +*/ +use proc_macro2::TokenStream; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use syn::Expr; + +pub struct GradientStore { + pub(crate) store: HashMap, +} + +impl GradientStore +where + K: Eq + std::hash::Hash, +{ + pub fn new() -> Self { + Self { + store: HashMap::new(), + } + } + + pub fn entry(&mut self, k: K) -> Entry { + self.store.entry(k) + } + + pub fn get(&self, k: &K) -> Option<&TokenStream> { + self.store.get(k) + } + + pub fn get_mut(&mut self, k: &K) -> Option<&mut TokenStream> { + self.store.get_mut(k) + } + + pub fn insert(&mut self, k: K, v: TokenStream) -> Option { + self.store.insert(k, v) + } + + pub fn or_insert(&mut self, k: K, v: TokenStream) -> &mut TokenStream { + self.entry(k).or_insert(v) + } + + pub fn remove(&mut self, k: &K) -> Option { + self.store.remove(k) + } + + pub fn retain(&mut self, f: F) + where + F: FnMut(&K, &mut TokenStream) -> bool, + { + self.store.retain(f); + } +} + +impl GradientStore { + pub fn retain_vars(&mut self) { + self.retain(|k, _v| matches!(k, Expr::Path(_))); + } +} + +impl std::ops::Deref for GradientStore { + type Target = HashMap; + + fn deref(&self) -> &Self::Target { + &self.store + } +} + +impl std::ops::DerefMut for GradientStore { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.store + } +} + +impl std::ops::Index<&K> for GradientStore +where + K: Eq + std::hash::Hash, +{ + type Output = TokenStream; + + fn index(&self, k: &K) -> &Self::Output { + self.get(k).expect("Key not found") + } +} + +impl std::ops::IndexMut<&K> for GradientStore +where + K: Eq + std::hash::Hash, +{ + fn index_mut(&mut self, k: &K) -> &mut Self::Output { + self.get_mut(k).expect("Key not found") + } +} + +impl IntoIterator for GradientStore { + type Item = (K, TokenStream); + type IntoIter = std::collections::hash_map::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.store.into_iter() + } +} + +impl FromIterator<(K, TokenStream)> for GradientStore +where + K: Eq + std::hash::Hash, +{ + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + Self { + store: HashMap::from_iter(iter), + } + } +} diff --git a/macros/src/gradient.rs b/macros/src/gradient.rs new file mode 100644 index 00000000..72f1a2bb --- /dev/null +++ b/macros/src/gradient.rs @@ -0,0 +1,109 @@ +/* + Appellation: gradient + Contrib: FL03 +*/ +use crate::cmp::GradientStore; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{Expr, ExprBinary, ExprUnary}; + +pub fn compute_grad(expr: &Expr) -> TokenStream { + // Initialize an empty HashMap to hold the gradient values + let mut store = GradientStore::new(); + // begin by computing the gradient of the expression w.r.t. itself + // store.insert(expr.clone(), quote! 
{ 1.0 }); + + // Generate code to compute the gradient of the expression w.r.t. each variable + handle_expr(expr, &mut store); + + store.retain_vars(); + + let values = store + .into_iter() + .map(|(k, v)| { + quote! { (#k, #v) } + }) + .collect::>(); + // Convert the gradient values into a token stream + quote! { [#(#values),*] } +} + +pub fn handle_expr(expr: &Expr, store: &mut GradientStore) -> TokenStream { + match expr { + Expr::Binary(inner) => { + let df = binary_grad(inner, store); + df + } + // Handle constants + Expr::Const(_) => quote! { 0.0 }, + // Handle literals + Expr::Lit(_) => quote! { 0.0 }, + Expr::Paren(inner) => handle_expr(&inner.expr, store), + // Handle path variables (identifiers) + Expr::Path(inner) => { + let path = &inner.path; + // Only considers single-segment paths; i.e., x in the expression let x = ___; + if path.segments.len() != 1 { + panic!("Unsupported path!"); + } + let grad = quote! { 1.0 }; + // store.insert(node, grad.clone()); + grad + } + // Handle references (borrowed variables denoted with & or &mut) + Expr::Reference(inner) => handle_expr(&inner.expr, store), + // Handle unary expressions (e.g., negation, natural log, etc.) + Expr::Unary(inner) => { + // Compute the gradient of the expression + let df = handle_unary(inner, store); + + df + } + // Handle other expressions + _ => panic!("Unsupported expression!"), + } +} + +fn binary_grad(expr: &ExprBinary, store: &mut GradientStore) -> TokenStream { + use syn::BinOp; + // create a cloned reference to the expression + let node: Expr = expr.clone().into(); + // let grad = store.entry(node).or_insert(quote! { 0.0 }).clone(); + let grad = store.remove(&node).unwrap_or(quote! { 0.0 }); + let ExprBinary { + left, op, right, .. + } = expr; + + // Recursivley compute the gradient of the left and right children + let dl = handle_expr(left, store); + let dr = handle_expr(right, store); + match op { + BinOp::Add(_) => { + let gl = store.or_insert(*left.clone(), quote! { 0.0 }); + *gl = quote! { #gl + #dl }; + let gr = store.or_insert(*right.clone(), quote! { 0.0 }); + *gr = quote! { #gr + #dr }; + } + BinOp::Mul(_) => { + let gl = store.or_insert(*left.clone(), quote! { 0.0 }); + *gl = quote! { #gl + #right * #dl }; + let gr = store.or_insert(*right.clone(), quote! { 0.0 }); + *gr = quote! { #gr + #left * #dr }; + } + _ => panic!("Unsupported binary operator!"), + }; + grad +} + +fn handle_unary(expr: &ExprUnary, store: &mut GradientStore) -> TokenStream { + use syn::UnOp; + handle_expr(&expr.expr, store); + let dv = &store[&expr.expr.clone()]; + let df = match expr.op { + UnOp::Neg(_) => { + quote! 
{ -#dv } + } + _ => panic!("Unsupported unary operator!"), + }; + df +} diff --git a/macros/src/graph.rs b/macros/src/graph.rs new file mode 100644 index 00000000..ee1c36da --- /dev/null +++ b/macros/src/graph.rs @@ -0,0 +1,193 @@ +/* + Appellation: graph + Contrib: FL03 +*/ +use petgraph::{ + algo::toposort, + prelude::{DiGraph, NodeIndex}, +}; +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use std::collections::HashMap; +use syn::{Expr, ExprBinary}; + +pub struct Context { + graph: DiGraph, +} + +impl Context { + pub fn new() -> Self { + Context { + graph: DiGraph::new(), + } + } + + pub fn add_node(&mut self, expr: Expr) -> NodeIndex { + self.graph.add_node(expr) + } + + pub fn add_edge(&mut self, src: NodeIndex, target: NodeIndex) { + self.graph.add_edge(src, target, ()); + } + + pub fn backward(&self) -> HashMap { + let sorted = toposort(&self.graph, None).expect("The graph is cyclic"); + let target = sorted.last().unwrap().clone(); + + let mut stack = Vec::<(NodeIndex, TokenStream)>::new(); + stack.push((target, quote! { 1.0 })); + let mut store = HashMap::::from_iter(stack.clone()); + + // Iterate through the edges of the graph to compute gradients + while let Some((i, grad)) = stack.pop() { + // get the current node + let node = &self.graph[i]; + + match node { + Expr::Binary(expr_binary) => { + // Compute the gradient of the left child + let left = self + .graph + .neighbors_directed(i, petgraph::Direction::Outgoing) + .next() + .unwrap(); + let left_grad = quote! { #grad * #expr_binary.right }; + stack.push((left, left_grad)); + + // Compute the gradient of the right child + let right = self + .graph + .neighbors_directed(i, petgraph::Direction::Outgoing) + .last() + .unwrap(); + let right_grad = quote! { #grad * #expr_binary.left }; + stack.push((right, right_grad)); + } + Expr::Unary(expr_unary) => { + // Compute the gradient of the child + let child = self + .graph + .neighbors_directed(i, petgraph::Direction::Outgoing) + .next() + .unwrap(); + let child_grad = quote! { #grad * #expr_unary.expr }; + stack.push((child, child_grad)); + } + _ => { + // Do nothing + } + } + } + + store + } + + pub fn traverse(&mut self, expr: &Expr) { + let c = self.add_node(expr.clone()); + + match expr { + Expr::Binary(inner) => { + let ExprBinary { left, right, .. } = inner; + // Add edges for left and right children + let a = self.add_node(*left.clone()); + let b = self.add_node(*right.clone()); + self.add_edge(a, c); + self.add_edge(b, c); + + // Recursive traversal for left and right children + self.traverse(left); + self.traverse(right); + } + + Expr::Unary(inner) => { + // Add an edge for the child + let a = self.add_node(*inner.expr.clone()); + self.add_edge(a, c); + + // Recursive traversal for the child + self.traverse(&inner.expr); + } + _ => {} + } + } +} + +fn handle_expr(expr: &Expr) -> Grad { + match expr { + Expr::Binary(inner) => handle_binary(inner).into(), + _ => panic!("Unsupported expression!"), + } +} + +fn handle_binary(expr: &ExprBinary) -> BinaryGrad { + use syn::BinOp; + let ExprBinary { + left, op, right, .. + } = expr.clone(); + + let dl = handle_expr(&left); + let dr = handle_expr(&right); + match op { + BinOp::Add(_) => { + // Implement addition handling + BinaryGrad { + left: quote! { #dl }, + right: quote! { #dr }, + } + } + BinOp::Mul(_) => { + // Implement multiplication handling + BinaryGrad { + left: quote! { #dl * #right }, + right: quote! 
+ +pub struct BinaryGrad { + pub left: TokenStream, + pub right: TokenStream, +} + +impl ToTokens for BinaryGrad { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.left.to_tokens(tokens); + self.right.to_tokens(tokens); + } +} + +pub enum Grad { + Binary(BinaryGrad), + Unary(TokenStream), + Verbatim(TokenStream), +} + +impl From<BinaryGrad> for Grad { + fn from(grad: BinaryGrad) -> Self { + Grad::Binary(grad) + } +} + +impl From<TokenStream> for Grad { + fn from(grad: TokenStream) -> Self { + Grad::Verbatim(grad) + } +} + +impl ToTokens for Grad { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + Grad::Binary(grad) => { + grad.to_tokens(tokens); + } + Grad::Unary(grad) => { + grad.to_tokens(tokens); + } + Grad::Verbatim(grad) => { + grad.to_tokens(tokens); + } + } + } +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 29102305..6034a0a4 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -5,23 +5,21 @@ //! # acme-macros //! //! +#![feature(proc_macro_span,)] extern crate proc_macro; -use proc_macro::TokenStream; -use proc_macro2::TokenStream as Ts; -use quote::quote; -use syn::{parse_macro_input, Expr, Ident}; + +pub(crate) mod ad; +pub(crate) mod ast; pub(crate) mod cmp; -use cmp::PartialDerivative; +pub(crate) mod gradient; +pub(crate) mod graph; -#[proc_macro] -pub fn express(item: TokenStream) -> TokenStream { - let input = Ts::from(item); - // let output = parse!(input as Expr); - println!("item: \"{:?}\"", &input.to_string()); - TokenStream::from(quote! { #input }) -} +use ast::partials::*; +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, Expr}; #[proc_macro_attribute] pub fn show_streams(attr: TokenStream, item: TokenStream) -> TokenStream { @@ -31,199 +29,75 @@ pub fn show_streams(attr: TokenStream, item: TokenStream) -> TokenStream { } #[proc_macro] -pub fn differentiate(input: TokenStream) -> TokenStream { - // Parse the input expression into a syntax tree - let expr = parse_macro_input!(input as Expr); - - // Generate code to perform automatic differentiation - let result = match_differentiate(&expr); - - // Return the generated code as a token stream - TokenStream::from(result) -} - -fn match_differentiate(expr: &Expr) -> Ts { - match expr { - Expr::Assign(expr_assign) => { - let left = &expr_assign.left; - let right = &expr_assign.right; - - // Differentiate the right subexpression - let right_diff = match_differentiate(right); - - // Return the differentiated expression - quote! { - { - let right_diff = #right_diff; - #left = right_diff; - } - } - } - Expr::Binary(expr_binary) => { - let left = &expr_binary.left; - let right = &expr_binary.right; - let op = &expr_binary.op; - - // Differentiate left and right subexpressions - let left_diff = match_differentiate(left); - let right_diff = match_differentiate(right); - - // Apply the chain rule based on the operator - match op { - // Differentiate addition and subtraction - syn::BinOp::Add(_plus) => { - quote! { - { - let left_diff = #left_diff; - let right_diff = #right_diff; - left_diff + right_diff - } - } - } - // Differentiate multiplication and division - syn::BinOp::Mul(_) => { - quote! { - { - let left_diff = #left_diff; - let right_diff = #right_diff; - left_diff * #right + #left * right_diff - } - } - } - _ => panic!("Unsupported operator!"), - } - } - // Differentiate literal expressions (constants) - Expr::Const(_) => quote! { 0.0 }, - // Differentiate literal expressions (constants) - Expr::Lit(_) => quote! { 0.0 }, - Expr::Reference(_) => quote! { 1.0 }, - Expr::Path(_) => quote! { 1.0 }, - _ => panic!("Unsupported expression!"), - }
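Because BinaryGrad and Grad implement ToTokens, they can be interpolated straight into a quote! invocation. The snippet below shows that splice using proc-macro2 types directly so it runs outside a proc-macro context; the Pair struct and the token fragments are illustrative only.

    use proc_macro2::TokenStream;
    use quote::{quote, ToTokens};

    struct Pair {
        left: TokenStream,
        right: TokenStream,
    }

    impl ToTokens for Pair {
        // emitted in order, just like BinaryGrad::to_tokens
        fn to_tokens(&self, tokens: &mut TokenStream) {
            self.left.to_tokens(tokens);
            self.right.to_tokens(tokens);
        }
    }

    fn main() {
        let pair = Pair {
            left: quote! { 1.0 },
            right: quote! { * dx },
        };
        // #pair expands through the ToTokens impl above
        let expanded = quote! { let grad = #pair; };
        println!("{expanded}"); // prints roughly: let grad = 1.0 * dx ;
    }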
+pub fn show_item(item: TokenStream) -> TokenStream { + println!("item: \"{:?}\"", syn::parse_macro_input!(item as syn::ItemFn)); + quote! { }.into() } #[proc_macro] -pub fn partial(input: TokenStream) -> TokenStream { - // Parse the input token stream into a syntax tree representing the expression and variable - let PartialDerivative { expr, variable } = parse_macro_input!(input as PartialDerivative); +pub fn autodiff(input: TokenStream) -> TokenStream { + // Parse the input expression into a syntax tree + let expr = parse_macro_input!(input as PartialAst); - // Generate code to perform partial differentiation - let result = match_partial_differentiate(&expr, &variable); + // Generate code to compute the gradient + let result = ad::generate_autodiff(&expr); // Return the generated code as a token stream TokenStream::from(result) } -fn match_partial_differentiate(expr: &Expr, variable: &Ident) -> proc_macro2::TokenStream { - match expr { - Expr::Binary(expr_binary) => { - let left = &expr_binary.left; - let right = &expr_binary.right; - let op = &expr_binary.op; - - // Differentiate left and right subexpressions - let left_diff = match_partial_differentiate(left, variable); - let right_diff = match_partial_differentiate(right, variable); - - // Apply the chain rule based on the operator - match op { - // Differentiate addition - syn::BinOp::Add(_) => { - quote! { - { - if #left == #variable { - #left_diff - } else { - #right_diff - } - } - } - } - // Differentiate multiplication - syn::BinOp::Mul(_) => { - quote! { - { - let left_diff = #left_diff; - let right_diff = #right_diff; - #left * right_diff + left_diff * #right - } - } - } - _ => panic!("Unsupported operation!"), - } - } - // Differentiate variable expressions - Expr::Path(expr_path) - if expr_path.path.segments.len() == 1 - && expr_path.path.segments[0].ident == *variable => - { - quote! { 1.0 } // The derivative of the variable with respect to itself is 1 - } - // Differentiate other expressions - _ => quote! { 0.0 }, // The derivative of anything else is 0 - } +#[proc_macro] +pub fn compute(input: TokenStream) -> TokenStream { + use graph::Context; + // Parse the input expression into a syntax tree + let expr = parse_macro_input!(input as Expr); + + // Build a computational graph representing the expression + let mut graph = Context::new(); + graph.traverse(&expr); + + // Generate code to compute gradients and return them as an array of (node index, gradient) pairs + let grad = graph.backward(); + let grads = grad + .into_iter() + .map(|(k, v)| { + let k = k.index(); + quote! { (#k, #v) } + }) + .collect::<Vec<_>>(); + quote! { [#(#grads),*] }.into() }
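For orientation, these are hand-written equivalents of the values the autodiff! and partial! expansions are expected to produce for two of the expressions exercised in the test suites further down; the expansions shown are illustrative rather than the macros' literal output.

    fn main() {
        let (x, y) = (1.0_f64, 2.0_f64);

        // partial!(x: x * y): product rule, the derivative w.r.t. x is y
        let dx = y;
        assert_eq!(dx, 2.0);

        // partial!(y: y * (x + y)): product + chain rule, (x + y) + y = x + 2 * y
        let dy = (x + y) + y;
        assert_eq!(dy, 5.0);
    }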
#[proc_macro] -pub fn gradient(input: TokenStream) -> TokenStream { +pub fn grad(input: TokenStream) -> TokenStream { // Parse the input expression into a syntax tree let expr = parse_macro_input!(input as Expr); // Generate code to compute the gradient - let result = compute_gradient(&expr); + let result = gradient::compute_grad(&expr); // Return the generated code as a token stream TokenStream::from(result) } -fn compute_gradient(expr: &Expr) -> Ts { - // Initialize an empty Vec to hold the gradient values - let mut gradient_values = Vec::new(); - - // Generate code to compute the gradient of the expression with respect to each variable - generate_gradient(expr, &mut gradient_values); +#[proc_macro] +pub fn partial(input: TokenStream) -> TokenStream { + // Parse the input token stream into a structured syntax tree + let partial = parse_macro_input!(input as Partial); - // Convert the gradient values into a token stream - let gradient_array = quote! { [#(#gradient_values),*] }; + // Generate code to perform partial differentiation + let result = ad::handle::expr::handle_expr(&partial.expr, &partial.var); // Return the generated code as a token stream - gradient_array + TokenStream::from(result) } -fn generate_gradient(expr: &Expr, gradient_values: &mut Vec<Ts>) { - match expr { - // Handle binary expressions (e.g., addition, multiplication) - Expr::Binary(expr_binary) => { - let left = &expr_binary.left; - let right = &expr_binary.right; - - // Recursively compute gradient for left and right subexpressions - generate_gradient(left, gradient_values); - generate_gradient(right, gradient_values); - } - - // Handle literals (constants) - Expr::Const(_) => { - // For constants, add 0 to the gradient vector - gradient_values.push(quote! { 0.0 }); - } - Expr::Lit(_) => { - // For literals, add 0 to the gradient vector - gradient_values.push(quote! { 0.0 }); - } - // Handle variables (identifiers) - Expr::Path(expr_path) => { - if expr_path.path.segments.len() != 1 { - panic!("Unsupported path!"); - } - let _path = &expr_path.path; - // For variables, add 1 to the gradient vector - gradient_values.push(quote! { 1.0 }); - } - Expr::Reference(_) => { - gradient_values.push(quote! { 1.0 }); - } - _ => panic!("Unsupported expression!"), - }
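The kw module that follows declares custom keywords with syn::custom_keyword!; a typical way to consume such keywords in a syn::parse::Parse impl is sketched below. The SinCall grammar, the local sin keyword module, and the input string are illustrative, and syn's default parsing features are assumed.

    use syn::parse::{Parse, ParseStream};

    mod kw {
        syn::custom_keyword!(sin);
    }

    /// Parses input of the form `sin x`.
    struct SinCall {
        _kw: kw::sin,
        arg: syn::Ident,
    }

    impl Parse for SinCall {
        fn parse(input: ParseStream) -> syn::Result<Self> {
            Ok(Self {
                _kw: input.parse::<kw::sin>()?,
                arg: input.parse()?,
            })
        }
    }

    fn main() -> syn::Result<()> {
        let parsed: SinCall = syn::parse_str("sin x")?;
        assert_eq!(parsed.arg, "x");
        Ok(())
    }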
+pub(crate) mod kw { + syn::custom_keyword!(grad); + + syn::custom_keyword!(cos); + syn::custom_keyword!(e); + syn::custom_keyword!(ln); + syn::custom_keyword!(sin); + syn::custom_keyword!(tan); +} diff --git a/macros/tests/gradient.rs b/macros/tests/gradient.rs index b8f7fa96..b296dba9 100644 --- a/macros/tests/gradient.rs +++ b/macros/tests/gradient.rs @@ -5,12 +5,71 @@ #[cfg(test)] extern crate acme_macros as macros; -use macros::gradient; +use macros::grad; #[test] -fn test_gradient() { +fn test_grad_addition() { let x = 1.0; let y = 2.0; - assert_eq!(gradient!(x + y), [1.0; 2]); - assert_eq!(gradient!(x + y + 1.0), [1.0, 1.0, 0.0]); + let df = grad!(x + y); + // let df = BTreeMap::from_iter(df); + assert_eq!( + df.into_iter().filter(|(k, _v)| k == &x).collect::<Vec<_>>(), + [(x, 1.0)] + ); + assert_eq!( + df.into_iter().filter(|(k, _v)| k == &y).collect::<Vec<_>>(), + [(y, 1.0)] + ); + let z = 3.0; + let df = grad!(x + y + z); + assert_eq!( + df.into_iter().filter(|(k, _v)| k == &x).collect::<Vec<_>>(), + [(x, 1.0)] + ); + assert_eq!( + df.into_iter().filter(|(k, _v)| k == &y).collect::<Vec<_>>(), + [(y, 1.0)] + ); + assert_eq!( + df.into_iter().filter(|(k, _v)| k == &z).collect::<Vec<_>>(), + [(z, 1.0)] + ); +} + +#[test] +fn test_grad_multiply() { + let x = 1.0; + let y = 2.0; + let df = grad!(x * y); + assert_eq!( + df.into_iter().filter(|(k, _v)| k == &x).collect::<Vec<_>>(), + [(x, 2.0)] + ); + assert_eq!( + df.into_iter().filter(|(k, _v)| k == &y).collect::<Vec<_>>(), + [(y, 1.0)] + ); + let df = grad!(x * y + 3.0); + assert_eq!( + df.into_iter().filter(|(k, _v)| k == &x).collect::<Vec<_>>(), + [(x, 2.0)] + ); + assert_eq!( + df.into_iter().filter(|(k, _v)| k == &y).collect::<Vec<_>>(), + [(y, 1.0)] + ); +} + +#[ignore = "Needs to be fixed"] +#[test] +fn test_grad_mixed() { + let x = 1.0; + let y = 2.0; + let df = grad!(y * (x + y)); + // assert_eq!(df.into_iter().filter(|(k, _v)| k == &x).collect::<Vec<_>>(), [(x, 2.0)]); + assert_eq!( + df.into_iter().filter(|(k, _v)| k == &y).collect::<Vec<_>>(), + [(y, 5.0)] + ); }
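The expected values asserted in the gradient tests above and the partial tests below can be cross-checked numerically: a central finite difference approximates the same partial derivatives, for example d/dy [y * (x + y)] = x + 2y = 5 at (1, 2). The helper below is ours, not part of the crate.

    fn central_diff(f: impl Fn(f64) -> f64, at: f64) -> f64 {
        let h = 1e-6;
        (f(at + h) - f(at - h)) / (2.0 * h)
    }

    fn main() {
        let (x, y) = (1.0_f64, 2.0_f64);

        // grad!(x * y): d/dx = y and d/dy = x
        assert!((central_diff(|v| v * y, x) - 2.0).abs() < 1e-6);
        assert!((central_diff(|v| x * v, y) - 1.0).abs() < 1e-6);

        // partial!(y: y * (x + y)): d/dy = x + 2 * y
        assert!((central_diff(|v| v * (x + v), y) - 5.0).abs() < 1e-6);
    }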
diff --git a/macros/tests/partial.rs b/macros/tests/partial.rs index 4d67be96..7b41f479 100644 --- a/macros/tests/partial.rs +++ b/macros/tests/partial.rs @@ -7,10 +7,57 @@ extern crate acme_macros as macros; use macros::partial; +macro_rules! partials { + ($($x:ident),* : $f:expr) => { + { + let mut store = Vec::new(); + $( + store.push(partial!($x: $f)); + )* + store + } + }; +} + +#[test] +fn test_add() { + let (x, y) = (1_f64, 2_f64); + assert_eq!(partial!(x: x + y), 1.0); + assert_eq!(partial!(y: x += y), 1.0); + assert_eq!(partials!(x, y: x + y + 3.0), [1.0; 2]); +} + +#[test] +fn test_div() { + let (x, y) = (1_f64, 2_f64); + + assert_eq!(partial!(x: x / y), 1.0 / 2.0); + assert_eq!(partial!(y: x / y), -1.0 / 4.0); +} + #[test] -fn test_partial() { - let x = 1.0; - let y = 2.0; - assert_eq!(partial!(x: x + y;), 1.0); - assert_eq!(partial!(y: x + y;), 1.0); +fn test_mul() { + let (x, y) = (1_f64, 2_f64); + + assert_eq!(partial!(x: x * y), 2.0); + assert_eq!(partial!(y: x * y), 1.0); + assert_eq!(partial!(y: x * y + 3.0), 1.0); +} + +#[test] +fn test_sub() { + let (x, y) = (1_f64, 2_f64); + + assert_eq!(partial!(x: x - y), 1.0); + assert_eq!(partial!(y: x - y), -1.0); +} + +#[test] +fn test_chain_rule() { + let (x, y) = (1_f64, 2_f64); + + assert_eq!(partial!(x: y * (x + y)), y); + assert_eq!(partial!(y: y * (x + y)), 2_f64 * y + x); + assert_eq!(partial!(x: (x + y) * y), y); + assert_eq!(partial!(y: (x + y) * y), 2_f64 * y + x); } diff --git a/scripts/win/setup.cmd b/scripts/win/setup.cmd new file mode 100644 index 00000000..0be7df19 --- /dev/null +++ b/scripts/win/setup.cmd @@ -0,0 +1,2 @@ +rustup default nightly +set RUSTFLAGS="--cfg procmacro2_semver_exempt" \ No newline at end of file diff --git a/tensor/src/data/mod.rs b/tensor/src/data/mod.rs index 990ae609..7c29a97b 100644 --- a/tensor/src/data/mod.rs +++ b/tensor/src/data/mod.rs @@ -7,4 +7,35 @@ pub use self::scalar::*; pub(crate) mod scalar; #[cfg(test)] -mod tests {} +mod tests { + // use super::*; + + macro_rules! Scalar { + (complex) => { + Scalar!(cf64) + }; + (float) => { + Scalar!(f64) + }; + (cf64) => { + Complex<f64> + }; + (cf32) => { + Complex<f32> + }; + (f64) => { + f64 + }; + (f32) => { + f32 + }; + + } + + #[test] + fn test_scalar() { + let a: Scalar!(f64); + a = 3.0; + assert_eq!(a, 3_f64); + } +} diff --git a/tensor/src/data/scalar.rs b/tensor/src/data/scalar.rs index 4ebdcb52..bce4b6e6 100644 --- a/tensor/src/data/scalar.rs +++ b/tensor/src/data/scalar.rs @@ -153,13 +153,13 @@ where } macro_rules! impl_scalar { - ($t:ty) => { - impl Scalar for $t { - type Complex = Complex<$t>; - type Real = $t; + ($re:ty) => { + impl Scalar for $re { + type Complex = Complex<$re>; + type Real = $re; fn conj(&self) -> Self::Complex { - Complex::new(*self, -<$t>::default()) + Complex::new(*self, -<$re>::default()) } fn re(&self) -> Self::Real { @@ -167,47 +167,47 @@ macro_rules!
impl_scalar { } fn cos(self) -> Self { - <$t>::cos(self) + <$re>::cos(self) } fn cosh(self) -> Self { - <$t>::cosh(self) + <$re>::cosh(self) } fn exp(self) -> Self { - <$t>::exp(self) + <$re>::exp(self) } fn ln(self) -> Self { - <$t>::ln(self) + <$re>::ln(self) } fn pow(self, exp: Self) -> Self { - <$t>::powf(self, exp) + <$re>::powf(self, exp) } fn powc(self, exp: Self::Complex) -> Self::Complex { - Complex::new(self, <$t>::default()).powc(exp) + Complex::new(self, <$re>::default()).powc(exp) } fn powf(self, exp: Self::Real) -> Self { - <$t>::powf(self, exp) + <$re>::powf(self, exp) } fn powi(self, exp: i32) -> Self { - <$t>::powi(self, exp) + <$re>::powi(self, exp) } fn sin(self) -> Self { - <$t>::sin(self) + <$re>::sin(self) } fn sinh(self) -> Self { - <$t>::sinh(self) + <$re>::sinh(self) } fn sqrt(self) -> Self { - <$t>::sqrt(self) + <$re>::sqrt(self) } fn square(self) -> Self::Real { @@ -215,11 +215,11 @@ macro_rules! impl_scalar { } fn tan(self) -> Self { - <$t>::tan(self) + <$re>::tan(self) } fn tanh(self) -> Self { - <$t>::tanh(self) + <$re>::tanh(self) } } };
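A standalone check of the behaviour impl_scalar! wires up for the real types: conj of a real x is x - 0i, and powc promotes the receiver to a complex base before exponentiation. The snippet uses the num crate's Complex type directly and the values are illustrative.

    use num::Complex;

    fn main() {
        let x: f64 = 2.0;

        // conj(&self): Complex::new(*self, -<$re>::default())
        let conj = Complex::new(x, -f64::default());
        assert_eq!(conj, Complex::new(2.0, -0.0));

        // powc(self, exp): Complex::new(self, <$re>::default()).powc(exp)
        let z = Complex::new(x, f64::default()).powc(Complex::new(2.0, 0.0));
        assert!((z.re - 4.0).abs() < 1e-10);
        assert!(z.im.abs() < 1e-10);
    }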