Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Joe McCain III <jo3mccain@icloud.com>
  • Loading branch information
FL03 committed Mar 22, 2024
1 parent 8363af2 commit bc16241
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 11 deletions.
7 changes: 4 additions & 3 deletions tensor/src/impls/grad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,14 @@ where
*store.entry(rhs.id()).or_insert(rhs.zeros_like()) += &grad;
}
BinaryOp::Mul => {
*store.entry(lhs.id()).or_insert(lhs.zeros_like()) += &grad * rhs.as_ref();
*store.entry(rhs.id()).or_insert(rhs.zeros_like()) += &grad * lhs.as_ref();
*store.entry(lhs.id()).or_insert(lhs.zeros_like()) +=
&grad * rhs.as_ref();
*store.entry(rhs.id()).or_insert(rhs.zeros_like()) +=
&grad * lhs.as_ref();
}
_ => todo!(),
},
TensorOp::Unary(_a, kind) => match kind {

_ => todo!(),
},
_ => {}
Expand Down
13 changes: 9 additions & 4 deletions tensor/src/impls/reshape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Appellation: reshape <impls>
Contrib: FL03 <jo3mccain@icloud.com>
*/
use crate::prelude::IntoShape;
use crate::prelude::{IntoShape, ShapeError, TensorResult};
use crate::tensor::TensorBase;

impl<T> TensorBase<T>
Expand All @@ -17,9 +17,14 @@ where
unimplemented!()
}

pub fn reshape(&self, shape: impl IntoShape) -> Self {
let _shape = shape.into_shape();
pub fn reshape(self, shape: impl IntoShape) -> TensorResult<Self> {
let mut tensor = self;
let shape = shape.into_shape();
if tensor.elements() != shape.elements() {
return Err(ShapeError::MismatchedElements.into());
}

unimplemented!()
tensor.layout.shape = shape;
Ok(tensor)
}
}
1 change: 1 addition & 0 deletions tensor/src/shape/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub type ShapeResult<T = ()> = std::result::Result<T, ShapeError>;
pub enum ShapeError {
IncompatibleShapes,
InvalidShape,
MismatchedElements,
}

unsafe impl Send for ShapeError {}
Expand Down
1 change: 1 addition & 0 deletions tensor/src/shape/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ where

pub(crate) mod prelude {
pub use super::dim::*;
pub use super::error::*;
pub use super::shape::*;
pub use super::stride::*;
pub use super::IntoShape;
Expand Down
3 changes: 3 additions & 0 deletions tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ impl<T> TensorBase<T> {
pub fn from_vec(kind: TensorMode, shape: impl IntoShape, store: Vec<T>) -> Self {
from_vec(kind, shape, store)
}
pub fn elements(&self) -> usize {
self.layout.elements()
}
/// Returns the unique identifier of the tensor.
pub fn id(&self) -> TensorId {
self.id
Expand Down
5 changes: 1 addition & 4 deletions tensor/tests/backward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@ fn test_backward() {
let c = &a + &b;
let grad = c.grad();

assert_eq!(
grad[&a.id()],
Tensor::ones(shape),
);
assert_eq!(grad[&a.id()], Tensor::ones(shape),);
assert_eq!(grad[&b.id()], Tensor::ones(shape));

let a = Tensor::<f64>::ones(shape).variable();
Expand Down
10 changes: 10 additions & 0 deletions tensor/tests/composition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ fn test_tensor() {
assert_ne!(&a, &b);
}

#[test]
fn test_reshape() {
let shape = (2, 2);
let a = Tensor::<f64>::ones(shape);
let b = a.clone().reshape((4,)).unwrap();

assert_ne!(&a.shape(), &b.shape());
assert_eq!(&a.elements(), &b.elements());
}

#[test]
fn test_arange() {
let exp = Shape::from(10);
Expand Down

0 comments on commit bc16241

Please sign in to comment.