
Commit

update
Signed-off-by: Joe McCain III <jo3mccain@icloud.com>
FL03 committed Apr 2, 2024
1 parent 083fe28 commit 46eaab8
Showing 21 changed files with 472 additions and 296 deletions.
1 change: 1 addition & 0 deletions acme/Cargo.toml
@@ -100,6 +100,7 @@ acme-tensor = { optional = true, path = "../tensor", version = "0.3.0" }

[dev-dependencies]
approx = "0.5"
lazy_static = "1"
num = "0.4"
rand = "0.8"

20 changes: 17 additions & 3 deletions acme/benches/tensor.rs
@@ -6,15 +6,29 @@
extern crate acme;
extern crate test;

use acme::prelude::{IntoShape, Tensor};
use acme::prelude::{IntoShape, Shape, Tensor};
use lazy_static::lazy_static;
use test::Bencher;

lazy_static! {
static ref SHAPE_3D: Shape = SHAPE_3D_PATTERN.into_shape();
}

const SHAPE_3D_PATTERN: (usize, usize, usize) = (100, 10, 1);

#[bench]
fn tensor_iter(b: &mut Bencher) {
let shape = (20, 20, 20).into_shape();
fn bench_iter(b: &mut Bencher) {
let shape = SHAPE_3D.clone();
let n = shape.size();
let tensor = Tensor::linspace(0f64, n as f64, n);
b.iter(|| tensor.strided().take(n))
}

#[bench]
fn bench_iter_rev(b: &mut Bencher) {
let shape = SHAPE_3D.clone();
let n = shape.size();
let tensor = Tensor::linspace(0f64, n as f64, n);
b.iter(|| tensor.strided().rev().take(n))
}
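
For context, a plain (non-bench) sketch of the traversal these two benchmarks time, using only acme::prelude::{IntoShape, Tensor} and assuming the strided iterator yields one item per element; the #[bench] functions themselves depend on the unstable test crate, so they run under the nightly toolchain (e.g. cargo +nightly bench).

use acme::prelude::{IntoShape, Tensor};

fn main() {
    // same 3-D pattern as SHAPE_3D_PATTERN: 100 * 10 * 1 = 1_000 elements
    let shape = (100, 10, 1).into_shape();
    let n = shape.size();
    let tensor = Tensor::linspace(0f64, n as f64, n);

    // forward traversal, as timed by bench_iter
    let forward = tensor.strided().take(n).count();
    // reverse traversal, as timed by bench_iter_rev
    let reverse = tensor.strided().rev().take(n).count();

    assert_eq!(forward, n);
    assert_eq!(forward, reverse);
}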

2 changes: 2 additions & 0 deletions core/src/ops/binary/kinds.rs
@@ -56,3 +56,5 @@ impl BinaryOp {
}
}
}


84 changes: 43 additions & 41 deletions graphs/src/ops/arithmetic.rs
@@ -2,7 +2,7 @@
Appellation: arithmetic <mod>
Contrib: FL03 <jo3mccain@icloud.com>
*/
use super::BinaryOperation;
use super::{BinaryOperation, Operator};
use num::traits::NumOps;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@@ -24,6 +24,12 @@ macro_rules! operator {
stringify!($op).to_lowercase()
}
}

impl Operator for $op {
fn name(&self) -> String {
self.name()
}
}
};
($($op:ident),*) => {
$(
@@ -34,10 +40,7 @@
}

macro_rules! operators {
(class $group:ident; {$($op:ident: $variant:ident),*}) => {
$(
operator!($op);
)*
($group:ident; {$($variant:ident: $op:ident => $method:ident),*}) => {
#[derive(
Clone,
Copy,
@@ -65,11 +68,35 @@
$variant($op),
)*
}

impl $group {
$(
pub fn $method() -> Self {
Self::$variant($op::new())
}
)*

pub fn name(&self) -> String {
match self {
$(
$group::$variant(op) => op.name(),
)*
}
}
}
};
}

macro_rules! impl_binary_op {
($(($op:ident, $bound:ident, $operator:tt)),*) => {
$(
impl_binary_op!($op, $bound, $operator);
)*

};
($op:ident, $bound:ident, $operator:tt) => {
operator!($op);

impl<A, B, C> BinaryOperation<A, B> for $op
where
A: core::ops::$bound<B, Output = C>,
@@ -82,6 +109,8 @@
}
};
(expr $op:ident, $bound:ident, $exp:expr) => {
operator!($op);

impl<A, B, C> BinaryOperation<A, B> for $op
where
A: core::ops::$bound<B, Output = C>,
@@ -95,45 +124,21 @@
};
}

// operator!(Addition, Division, Multiplication, Subtraction);
operators!(class Arithmetic; {Addition: Add, Division: Div, Multiplication: Mul, Remainder: Rem, Subtraction: Sub});

impl_binary_op!(Addition, Add, +);
operators!(Arithmetic; {Add: Addition => add, Div: Division => div, Mul: Multiplication => mul, Rem: Remainder => rem, Sub: Subtraction => sub});

impl_binary_op!(Division, Div, /);
impl_binary_op!((Addition, Add, +), (Division, Div, /), (Multiplication, Mul, *), (Remainder, Rem, %), (Subtraction, Sub, -));

impl_binary_op!(Multiplication, Mul, *);

impl_binary_op!(Remainder, Rem, %);

impl_binary_op!(Subtraction, Sub, -);

impl Arithmetic {
pub fn new(op: Arithmetic) -> Self {
op
}

pub fn add() -> Self {
Self::Add(Addition::new())
}

pub fn div() -> Self {
Self::Div(Division::new())
}

pub fn mul() -> Self {
Self::Mul(Multiplication::new())
}

pub fn sub() -> Self {
Self::Sub(Subtraction::new())
}

pub fn op<A, B, C>(&self) -> Box<dyn BinaryOperation<A, B, Output = C>>
pub fn into_op<A, B, C>(self) -> Box<dyn BinaryOperation<A, B, Output = C>>
where
A: NumOps<B, C>,
{
match self.clone() {
match self {
Arithmetic::Add(op) => Box::new(op),
Arithmetic::Div(op) => Box::new(op),
Arithmetic::Mul(op) => Box::new(op),
@@ -142,14 +147,11 @@ impl Arithmetic {
}
}

pub fn name(&self) -> String {
match self {
Arithmetic::Add(op) => op.name(),
Arithmetic::Div(op) => op.name(),
Arithmetic::Mul(op) => op.name(),
Arithmetic::Rem(op) => op.name(),
Arithmetic::Sub(op) => op.name(),
}
pub fn op<A, B, C>(&self) -> Box<dyn BinaryOperation<A, B, Output = C>>
where
A: NumOps<B, C>,
{
self.into_op()
}

pub fn eval<A, B, C>(&self, lhs: A, rhs: B) -> C
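
The enum body above is macro-generated, so the resulting surface is easiest to read from a usage example. A hypothetical in-crate test, with the constructor and method names taken from the operators! invocation above; the concrete numeric types and expected name strings are illustrative assumptions:

#[cfg(test)]
mod sketch {
    use super::*;

    #[test]
    fn arithmetic_variants() {
        // constructors generated per variant: add, div, mul, rem, sub
        assert_eq!(Arithmetic::add().name(), "addition");
        assert_eq!(Arithmetic::rem().name(), "remainder");

        // eval dispatches to the wrapped operator's BinaryOperation impl
        assert_eq!(Arithmetic::add().eval(2i32, 3i32), 5);

        // into_op consumes the enum and yields a boxed BinaryOperation
        let op = Arithmetic::mul().into_op::<i32, i32, i32>();
        assert_eq!(op.eval(4, 5), 20);
    }
}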
36 changes: 33 additions & 3 deletions graphs/src/ops/mod.rs
@@ -10,14 +10,44 @@ pub use self::{arithmetic::*, kinds::*};
pub(crate) mod arithmetic;
pub(crate) mod kinds;

pub trait BinaryOperation<A, B> {
pub trait BinaryOperation<A, B = A> {
type Output;

fn eval(&self, lhs: A, rhs: B) -> Self::Output;
}

impl<S, A, B, C> BinaryOperation<A, B> for S
where
S: Fn(A, B) -> C,
{
type Output = C;

fn eval(&self, lhs: A, rhs: B) -> Self::Output {
self(lhs, rhs)
}
}

impl<A, B, C> BinaryOperation<A, B> for Box<dyn BinaryOperation<A, B, Output = C>> {
type Output = C;

fn eval(&self, lhs: A, rhs: B) -> Self::Output {
self.as_ref().eval(lhs, rhs)
}
}

pub trait Operator {
type Output;
fn boxed(self) -> Box<dyn Operator>
where
Self: Sized + 'static,
{
Box::new(self)
}
fn name(&self) -> String;
}

impl Operator for Box<dyn Operator> {

fn kind(&self) -> String;
fn name(&self) -> String {
self.as_ref().name()
}
}
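
Taken together, the default type parameter (B = A) and the blanket impl mean any two-argument closure now satisfies BinaryOperation. A self-contained sketch: the trait and impl are copied from this hunk so it compiles on its own, and the apply helper and closures are illustrative.

pub trait BinaryOperation<A, B = A> {
    type Output;

    fn eval(&self, lhs: A, rhs: B) -> Self::Output;
}

impl<S, A, B, C> BinaryOperation<A, B> for S
where
    S: Fn(A, B) -> C,
{
    type Output = C;

    fn eval(&self, lhs: A, rhs: B) -> Self::Output {
        self(lhs, rhs)
    }
}

// B defaults to A, so a single parameter covers the homogeneous case
fn apply<A, Op: BinaryOperation<A>>(op: Op, lhs: A, rhs: A) -> Op::Output {
    op.eval(lhs, rhs)
}

fn main() {
    // plain closures satisfy the trait through the blanket impl
    assert_eq!((|a: i32, b: i32| a + b).eval(2, 3), 5);
    assert_eq!(apply(|a: f64, b: f64| a * b, 2.0, 4.0), 8.0);
}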
7 changes: 7 additions & 0 deletions tensor/src/actions/iter/iterator.rs
@@ -17,3 +17,10 @@ impl Iter {
self.order
}
}

pub struct BaseIter<'a, T> {
iter: &'a Iter,
data: &'a [T],
index: usize,
}

3 changes: 1 addition & 2 deletions tensor/src/actions/iter/mod.rs
@@ -48,7 +48,7 @@

#[test]
fn test_strided() {
let shape = Shape::from_iter([2, 2]);
let shape = Shape::from_iter([2, 2, 2, 2]);
let n = shape.size();
let exp = Vec::linspace(0f64, n as f64, n);
let tensor = Tensor::linspace(0f64, n as f64, n).reshape(shape).unwrap();
@@ -58,7 +58,6 @@
}

#[test]
#[ignore = "not implemented"]
fn test_strided_rev() {
let shape = Shape::from_iter([2, 2]);
let n = shape.size();
41 changes: 12 additions & 29 deletions tensor/src/actions/iter/strides.rs
@@ -49,8 +49,8 @@ impl<'a, T> From<&'a TensorBase<T>> for StrideIter<'a, T> {
pub struct Strided<'a> {
next: Option<usize>,
position: Vec<usize>,
pub(crate) shape: &'a Shape,
pub(crate) stride: &'a Stride,
shape: &'a Shape,
stride: &'a Stride,
}

impl<'a> Strided<'a> {
@@ -70,8 +70,9 @@ impl<'a> Strided<'a> {
}
}

pub fn index(&self, index: &[usize]) -> usize {
pub(crate) fn index(&self, index: impl AsRef<[usize]>) -> usize {
index
.as_ref()
.iter()
.zip(self.stride.iter())
.map(|(i, s)| i * s)
@@ -81,33 +82,15 @@

impl<'a> DoubleEndedIterator for Strided<'a> {
fn next_back(&mut self) -> Option<Self::Item> {

let scope = match self.next {
None => return None,
Some(storage_index) => storage_index,
let (pos, _idx) = if let Some(item) = self.next() {
item
} else {
return None;
};
let mut updated = false;
let mut next = scope;
for ((pos, max_i), stride) in self
.position
.iter_mut()
.zip(self.shape.iter())
.zip(self.stride.iter())
{
let next_i = *pos - 1;
if next_i > *max_i {
*pos = next_i;
updated = true;
next -= stride;
break;
} else {
next += *pos * stride;
*pos = 0
}
}
self.next = if updated { Some(next) } else { None };
println!("{:?}", &self.position);
Some((self.position.clone(), scope))
let position = self.shape.iter().zip(pos.iter()).map(|(s, p)| s - p).collect();
let scope = self.index(&position);
println!("{:?}", &position);
Some((position, scope))
// unimplemented!()
}
}
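
For reference, the offset returned by Strided::index is the usual row-major dot product of a multi-index with the strides. A standalone sketch of that arithmetic (the function name and the example shape/strides are illustrative):

fn flat_offset(index: &[usize], stride: &[usize]) -> usize {
    // sum of index[i] * stride[i], exactly as in Strided::index
    index.iter().zip(stride.iter()).map(|(i, s)| i * s).sum()
}

fn main() {
    // a contiguous (2, 3) tensor has strides (3, 1)
    assert_eq!(flat_offset(&[0, 0], &[3, 1]), 0);
    assert_eq!(flat_offset(&[1, 0], &[3, 1]), 3);
    assert_eq!(flat_offset(&[1, 2], &[3, 1]), 5);
}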
12 changes: 6 additions & 6 deletions tensor/src/impls/create.rs
@@ -3,7 +3,7 @@
Contrib: FL03 <jo3mccain@icloud.com>
*/
use crate::prelude::IntoShape;
use crate::tensor::{from_vec, TensorBase};
use crate::tensor::{from_vec_with_kind, TensorBase};
use num::traits::real::Real;
use num::traits::{FromPrimitive, NumAssign, One, Zero};

@@ -30,7 +30,7 @@
pub fn fill(shape: impl IntoShape, value: T) -> Self {
let shape = shape.into_shape();
let store = vec![value; shape.size()];
from_vec(false, shape, store)
from_vec_with_kind(false, shape, store)
}
/// Create a tensor, filled with some value, from the current shape
pub fn fill_like(&self, value: T) -> Self {
@@ -53,7 +53,7 @@
store.push(value);
value += step;
}
from_vec(false, store.len(), store)
Self::from_vec(store)
}
/// Create an identity matrix of a certain size
pub fn eye(size: usize) -> Self {
@@ -63,7 +63,7 @@
store.push(if i == j { T::one() } else { T::zero() });
}
}
from_vec(false, (size, size), store)
Self::from_shape_vec((size, size), store)
}
/// Create a tensor with a certain number of elements, evenly spaced
/// between the provided start and end values
@@ -88,7 +88,7 @@
store.push(value.exp2());
value += step;
}
from_vec(false, (store.len(),), store)
from_vec_with_kind(false, (store.len(),), store)
}

pub fn geomspace(start: T, end: T, steps: usize) -> Self
@@ -104,7 +104,7 @@
store.push(value.exp());
value += step;
}
from_vec(false, (store.len(),), store)
from_vec_with_kind(false, (store.len(),), store)
}
}

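
These hunks only reroute the constructors through from_vec_with_kind, Self::from_vec, and Self::from_shape_vec, so the public constructors read as before. A brief usage sketch, assuming Tensor is the generic tensor type re-exported by acme::prelude (as in the benchmark above) and that tuples implement IntoShape:

use acme::prelude::Tensor;

fn main() {
    // fill: a (2, 3) tensor where every element is 2.0
    let twos = Tensor::fill((2, 3), 2f64);
    // fill_like: same shape as `twos`, different value
    let zeros = twos.fill_like(0f64);
    // eye: 3 x 3 identity matrix
    let id = Tensor::<f64>::eye(3);
    // linspace: 11 evenly spaced values between 0.0 and 1.0
    let line = Tensor::linspace(0f64, 1f64, 11);

    let _ = (zeros, id, line);
}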
