Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Joe McCain III <jo3mccain@icloud.com>
  • Loading branch information
FL03 committed Apr 2, 2024
1 parent 46eaab8 commit deed2a7
Show file tree
Hide file tree
Showing 21 changed files with 102 additions and 56 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ keywords = ["acme", "autodiff", "mathematics", "tensor"]
license = "Apache-2.0"
repository = "https://github.com/FL03/acme"
readme = "README.md"
version = "0.3.0"
# version = "0.3.0-nightly.4"
# version = "0.3.0"
version = "0.3.0-nightly.4"

[workspace]
default-members = [
Expand Down
20 changes: 10 additions & 10 deletions acme/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,16 @@ required-features = ["macros"]
[build-dependencies]

[dependencies]
acme-core = { path = "../core", version = "0.3.0" }
acme-derive = { optional = true, path = "../derive", version = "0.3.0" }
acme-graphs = { optional = true, path = "../graphs", version = "0.3.0" }
acme-macros = { optional = true, path = "../macros", version = "0.3.0" }
acme-tensor = { optional = true, path = "../tensor", version = "0.3.0" }
# acme-core = { path = "../core", version = "0.3.0-nightly.4" }
# acme-derive = { optional = true, path = "../derive", version = "0.3.0-nightly.4" }
# acme-graphs = { optional = true, path = "../graphs", version = "0.3.0-nightly.4" }
# acme-macros = { optional = true, path = "../macros", version = "0.3.0-nightly.4" }
# acme-tensor = { optional = true, path = "../tensor", version = "0.3.0-nightly.4" }
# acme-core = { path = "../core", version = "0.3.0" }
# acme-derive = { optional = true, path = "../derive", version = "0.3.0" }
# acme-graphs = { optional = true, path = "../graphs", version = "0.3.0" }
# acme-macros = { optional = true, path = "../macros", version = "0.3.0" }
# acme-tensor = { optional = true, path = "../tensor", version = "0.3.0" }
acme-core = { path = "../core", version = "0.3.0-nightly.4" }
acme-derive = { optional = true, path = "../derive", version = "0.3.0-nightly.4" }
acme-graphs = { optional = true, path = "../graphs", version = "0.3.0-nightly.4" }
acme-macros = { optional = true, path = "../macros", version = "0.3.0-nightly.4" }
acme-tensor = { optional = true, path = "../tensor", version = "0.3.0-nightly.4" }

[dev-dependencies]
approx = "0.5"
Expand Down
1 change: 0 additions & 1 deletion acme/benches/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,3 @@ fn bench_iter_rev(b: &mut Bencher) {
let tensor = Tensor::linspace(0f64, n as f64, n);
b.iter(|| tensor.strided().rev().take(n))
}

2 changes: 0 additions & 2 deletions core/src/ops/binary/kinds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,3 @@ impl BinaryOp {
}
}
}


4 changes: 2 additions & 2 deletions graphs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ strum.workspace = true

[dependencies.acme-core]
path = "../core"
version = "0.3.0"
# version = "0.3.0-nightly.4"
# version = "0.3.0"
version = "0.3.0-nightly.4"

[package.metadata.docs.rs]
all-features = true
Expand Down
3 changes: 1 addition & 2 deletions graphs/src/ops/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ macro_rules! impl_binary_op {
$(
impl_binary_op!($op, $bound, $operator);
)*

};
($op:ident, $bound:ident, $operator:tt) => {
operator!($op);
Expand Down Expand Up @@ -128,7 +128,6 @@ operators!(Arithmetic; {Add: Addition => add, Div: Division => div, Mul: Multipl

impl_binary_op!((Addition, Add, +), (Division, Div, /), (Multiplication, Mul, *), (Remainder, Rem, %), (Subtraction, Sub, -));


impl Arithmetic {
pub fn new(op: Arithmetic) -> Self {
op
Expand Down
1 change: 0 additions & 1 deletion graphs/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ pub trait Operator {
}

impl Operator for Box<dyn Operator> {

fn name(&self) -> String {
self.as_ref().name()
}
Expand Down
4 changes: 2 additions & 2 deletions tensor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ strum = { features = ["derive"], version = "0.26" }

[dependencies.acme-core]
path = "../core"
version = "0.3.0"
# version = "0.3.0-nightly.4"
# version = "0.3.0"
version = "0.3.0-nightly.4"

[dev-dependencies]
approx = "0.5"
Expand Down
1 change: 0 additions & 1 deletion tensor/src/actions/iter/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,3 @@ pub struct BaseIter<'a, T> {
data: &'a [T],
index: usize,
}

7 changes: 6 additions & 1 deletion tensor/src/actions/iter/strides.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@ impl<'a> DoubleEndedIterator for Strided<'a> {
} else {
return None;
};
let position = self.shape.iter().zip(pos.iter()).map(|(s, p)| s - p).collect();
let position = self
.shape
.iter()
.zip(pos.iter())
.map(|(s, p)| s - p)
.collect();
let scope = self.index(&position);
println!("{:?}", &position);
Some((position, scope))
Expand Down
6 changes: 2 additions & 4 deletions tensor/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@ pub trait BackendStorage {

#[allow(unused_imports)]
pub(crate) mod prelude {
pub use super::{Backend, BackendStorage};
pub use super::devices::Device;
pub use super::{Backend, BackendStorage};
}

#[cfg(test)]
mod tests {

}
mod tests {}
1 change: 0 additions & 1 deletion tensor/src/data/container.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ where
dbg!("Implement a custom iter for ContainerBase");
self.as_slice_memory_order().unwrap().iter()
}


pub fn layout(&self) -> &Layout {
&self.layout
Expand Down
5 changes: 4 additions & 1 deletion tensor/src/impls/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
use crate::prelude::Scalar;
use crate::tensor::TensorBase;

impl<T> TensorBase<T> where T: Scalar {
impl<T> TensorBase<T>
where
T: Scalar,
{
pub fn sum(&self) -> T {
self.data().iter().copied().sum()
}
Expand Down
76 changes: 58 additions & 18 deletions tensor/src/impls/ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,16 @@ use acme::ops::binary::BinaryOp;
use core::ops;
use num::traits::Pow;


pub(crate) fn broadcast_scalar_op<F, T>(lhs: &TensorBase<T>, rhs: &TensorBase<T>, op: BinaryOp, f: F) -> TensorBase<T> where F: Fn(T, T) -> T, T: Copy + Default {
pub(crate) fn broadcast_scalar_op<F, T>(
lhs: &TensorBase<T>,
rhs: &TensorBase<T>,
op: BinaryOp,
f: F,
) -> TensorBase<T>
where
F: Fn(T, T) -> T,
T: Copy + Default,
{
let mut lhs = lhs.clone();
let mut rhs = rhs.clone();
if lhs.is_scalar() {
Expand All @@ -19,16 +27,27 @@ pub(crate) fn broadcast_scalar_op<F, T>(lhs: &TensorBase<T>, rhs: &TensorBase<T>
rhs = rhs.broadcast(lhs.shape());
}
let shape = lhs.shape().clone();
let store = lhs.data().iter().zip(rhs.data().iter()).map(|(a, b)| f(*a, *b)).collect();
let store = lhs
.data()
.iter()
.zip(rhs.data().iter())
.map(|(a, b)| f(*a, *b))
.collect();
let op = TensorExpr::binary(lhs, rhs, op);
from_vec_with_op(false, op, shape, store)
}

fn check_shapes_or_scalar<T>(lhs: &TensorBase<T>, rhs: &TensorBase<T>) where T: Clone + Default {
fn check_shapes_or_scalar<T>(lhs: &TensorBase<T>, rhs: &TensorBase<T>)
where
T: Clone + Default,
{
let is_scalar = lhs.is_scalar() || rhs.is_scalar();
debug_assert!(is_scalar || lhs.shape() == rhs.shape(), "Shape Mismatch: {:?} != {:?}", lhs.shape(), rhs.shape());


debug_assert!(
is_scalar || lhs.shape() == rhs.shape(),
"Shape Mismatch: {:?} != {:?}",
lhs.shape(),
rhs.shape()
);
}

macro_rules! check {
Expand All @@ -39,32 +58,50 @@ macro_rules! check {
};
}

impl<T> TensorBase<T> where T: Scalar {
impl<T> TensorBase<T>
where
T: Scalar,
{
pub fn apply_binary(&self, other: &Self, op: BinaryOp) -> Self {
check_shapes_or_scalar(self, other);
let shape = self.shape();
let store = self.data().iter().zip(other.data().iter()).map(|(a, b)| *a + *b).collect();
let store = self
.data()
.iter()
.zip(other.data().iter())
.map(|(a, b)| *a + *b)
.collect();
let op = TensorExpr::binary(self.clone(), other.clone(), op);
from_vec_with_op(false, op, shape, store)
}

pub fn apply_binaryf<F>(&self, other: &Self, op: BinaryOp, f: F) -> Self where F: Fn(T, T) -> T {
pub fn apply_binaryf<F>(&self, other: &Self, op: BinaryOp, f: F) -> Self
where
F: Fn(T, T) -> T,
{
check_shapes_or_scalar(self, other);
let shape = self.shape();
let store = self.data().iter().zip(other.data().iter()).map(|(a, b)| f(*a, *b)).collect();
let store = self
.data()
.iter()
.zip(other.data().iter())
.map(|(a, b)| f(*a, *b))
.collect();
let op = TensorExpr::binary(self.clone(), other.clone(), op);
from_vec_with_op(false, op, shape, store)
}
}

impl<T> TensorBase<T> where T: Scalar {
impl<T> TensorBase<T>
where
T: Scalar,
{
pub fn pow(&self, exp: T) -> Self {
let shape = self.shape();
let store = self.data().iter().copied().map(|a| a.pow(exp)).collect();
let op = TensorExpr::binary_scalar(self.clone(), exp, BinaryOp::Pow);
from_vec_with_op(false, op, shape, store)
}

}

impl<T> Pow<T> for TensorBase<T>
Expand Down Expand Up @@ -199,7 +236,7 @@ macro_rules! impl_binary_op {
}
}
};

}

macro_rules! impl_assign_op {
Expand Down Expand Up @@ -243,7 +280,7 @@ macro_rules! impl_binary_method {
let op = TensorExpr::binary_scalar(self.clone(), other.clone(), BinaryOp::$variant);
from_vec_with_op(false, op, shape, store)
}

};
(tensor: $variant:ident, $method:ident, $op:tt) => {
pub fn $method(&self, other: &Self) -> Self {
Expand All @@ -253,7 +290,7 @@ macro_rules! impl_binary_method {
let op = TensorExpr::binary(self.clone(), other.clone(), BinaryOp::$variant);
from_vec_with_op(false, op, shape, store)
}

};
}

Expand All @@ -265,7 +302,10 @@ impl_assign_op!(MulAssign, mul_assign, Mul, *);
impl_assign_op!(RemAssign, rem_assign, Rem, %);
impl_assign_op!(SubAssign, sub_assign, Sub, -);

impl<T> TensorBase<T> where T: Scalar {
impl<T> TensorBase<T>
where
T: Scalar,
{
impl_binary_method!(tensor: Add, add, +);
impl_binary_method!(scalar: Add, add_scalar, +);
}
}
2 changes: 1 addition & 1 deletion tensor/src/ops/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use acme::prelude::{BinaryOp, UnaryOp};

pub type BoxTensor<T = f64> = Box<TensorBase<T>>;

#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd,)]
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum TensorExpr<T> {
Binary(BoxTensor<T>, BoxTensor<T>, BinaryOp),
Expand Down
2 changes: 0 additions & 2 deletions tensor/src/shape/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@ unsafe impl Send for Shape {}

unsafe impl Sync for Shape {}


impl From<()> for Shape {
fn from(_: ()) -> Self {
Self::default()
Expand Down Expand Up @@ -415,7 +414,6 @@ impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
}
}


// macro_rules! tuple_vec {
// ($($n:tt),*) => {
// vec![$($n,)*]
Expand Down
1 change: 0 additions & 1 deletion tensor/src/stats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ pub trait SummaryStatistics<T> {
}

pub trait TensorStats<T>: SummaryStatistics<T> {

/// Compute the mean along the specified axis.
fn mean_axis(&self, axis: Axis) -> T;
}
Expand Down
2 changes: 1 addition & 1 deletion tensor/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ pub(crate) mod prelude {
pub use super::kinds::*;
pub use super::order::*;
pub use super::tensors::*;
}
}
12 changes: 11 additions & 1 deletion tensor/src/types/tensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@ use serde::{Deserialize, Serialize};
use strum::{Display, EnumCount, EnumDiscriminants, EnumIs, EnumIter, EnumString, VariantNames};

#[derive(Clone, Debug, EnumDiscriminants, Eq, PartialEq)]
#[strum_discriminants(derive(Display, EnumCount, EnumIs, EnumIter, EnumString, Hash, Ord, PartialOrd, VariantNames))]
#[strum_discriminants(derive(
Display,
EnumCount,
EnumIs,
EnumIter,
EnumString,
Hash,
Ord,
PartialOrd,
VariantNames
))]
#[strum_discriminants(name(TensorType))]
#[cfg_attr(feature = "serde", strum_discriminants(derive(Deserialize, Serialize)))]
pub enum Tensors<T> {
Expand Down
2 changes: 1 addition & 1 deletion tensor/tests/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ fn test_product() {
let shape = (2, 2).into_shape();
let a = Tensor::fill(shape, 2f64);
assert_eq!(a.product(), 16.0);
}
}
2 changes: 1 addition & 1 deletion tensor/tests/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@ fn test_product() {
let shape = (2, 2).into_shape();
let a = Tensor::fill(shape, 2f64);
assert_eq!(a.product(), 16.0);
}
}

0 comments on commit deed2a7

Please sign in to comment.