-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from Rust-Scientific-Computing/1-implement-tensor
Implement Tensor struct
- Loading branch information
Showing
11 changed files
with
3,174 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pub type Axes = Vec<usize>; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
use std::fmt; | ||
use std::ops::{Index, IndexMut}; | ||
|
||
use crate::error::ShapeError; | ||
|
||
#[derive(Debug, Clone, PartialEq)] | ||
pub struct Coordinate { | ||
indices: Vec<usize>, | ||
} | ||
|
||
impl Coordinate { | ||
pub fn new(indices: Vec<usize>) -> Result<Self, ShapeError> { | ||
if indices.is_empty() { | ||
return Err(ShapeError::new("Coordinate cannot be empty")); | ||
} | ||
Ok(Self { indices }) | ||
} | ||
|
||
pub fn order(&self) -> usize { | ||
self.indices.len() | ||
} | ||
|
||
pub fn iter(&self) -> std::slice::Iter<'_, usize> { | ||
self.indices.iter() | ||
} | ||
|
||
pub fn insert(&self, index: usize, axis: usize) -> Self { | ||
let mut new_indices = self.indices.clone(); | ||
new_indices.insert(index, axis); | ||
Self { | ||
indices: new_indices, | ||
} | ||
} | ||
} | ||
|
||
impl Index<usize> for Coordinate { | ||
type Output = usize; | ||
|
||
fn index(&self, index: usize) -> &Self::Output { | ||
&self.indices[index] | ||
} | ||
} | ||
|
||
impl IndexMut<usize> for Coordinate { | ||
fn index_mut(&mut self, index: usize) -> &mut Self::Output { | ||
&mut self.indices[index] | ||
} | ||
} | ||
|
||
impl fmt::Display for Coordinate { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
use itertools::Itertools; | ||
let idxs = self.indices.iter().map(|&x| format!("{}", x)).join(", "); | ||
write!(f, "({})", idxs) | ||
} | ||
} | ||
|
||
#[macro_export] | ||
macro_rules! coord { | ||
($($index:expr),*) => { | ||
{ | ||
use $crate::coordinate::Coordinate; | ||
Coordinate::new(vec![$($index),*]) | ||
} | ||
}; | ||
|
||
($index:expr; $count:expr) => { | ||
{ | ||
use $crate::coordinate::Coordinate; | ||
Coordinate::new(vec![$index; $count]) | ||
} | ||
}; | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[test] | ||
fn test_order() { | ||
let coord = coord![1, 2, 3].unwrap(); | ||
assert_eq!(coord.order(), 3); | ||
} | ||
|
||
#[test] | ||
fn test_iter() { | ||
let coord = coord![1, 2, 3].unwrap(); | ||
let mut iter = coord.iter(); | ||
assert_eq!(iter.next(), Some(&1)); | ||
assert_eq!(iter.next(), Some(&2)); | ||
assert_eq!(iter.next(), Some(&3)); | ||
assert_eq!(iter.next(), None); | ||
} | ||
|
||
#[test] | ||
fn test_insert() { | ||
let coord = coord![1, 2, 3].unwrap(); | ||
let new_coord = coord.insert(1, 4); | ||
assert_eq!(new_coord, coord![1, 4, 2, 3].unwrap()); | ||
} | ||
|
||
#[test] | ||
fn test_index() { | ||
let coord = coord![1, 2, 3].unwrap(); | ||
assert_eq!(coord[0], 1); | ||
assert_eq!(coord[1], 2); | ||
assert_eq!(coord[2], 3); | ||
} | ||
|
||
#[test] | ||
fn test_index_mut() { | ||
let mut coord = coord![1, 2, 3].unwrap(); | ||
coord[1] = 4; | ||
assert_eq!(coord[1], 4); | ||
} | ||
|
||
#[test] | ||
fn test_display() { | ||
let coord = coord![1, 2, 3].unwrap(); | ||
assert_eq!(format!("{}", coord), "(1, 2, 3)"); | ||
} | ||
|
||
#[test] | ||
fn test_coord_macro() { | ||
let coord = coord![1, 2, 3].unwrap(); | ||
assert_eq!(coord, Coordinate::new(vec![1, 2, 3]).unwrap()); | ||
|
||
let coord_repeated = coord![1; 3].unwrap(); | ||
assert_eq!(coord_repeated, Coordinate::new(vec![1, 1, 1]).unwrap()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
use std::fmt; | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct ShapeError { | ||
reason: String, | ||
} | ||
|
||
impl ShapeError { | ||
pub fn new(reason: &str) -> Self { | ||
ShapeError { | ||
reason: reason.to_string(), | ||
} | ||
} | ||
} | ||
|
||
impl fmt::Display for ShapeError { | ||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
write!(f, "ShapeError: {}", self.reason) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
use crate::coord; | ||
use crate::coordinate::Coordinate; | ||
use crate::shape::Shape; | ||
use std::cmp::max; | ||
|
||
pub struct IndexIterator { | ||
shape: Shape, | ||
current: Coordinate, | ||
done: bool, | ||
} | ||
|
||
impl IndexIterator { | ||
pub fn new(shape: &Shape) -> Self { | ||
// (shape.order() == 0) => `next` returns None before `current` is used | ||
let current = coord![0; max(shape.order(), 1)].unwrap(); | ||
IndexIterator { | ||
shape: shape.clone(), | ||
current, | ||
done: false, | ||
} | ||
} | ||
} | ||
|
||
impl Iterator for IndexIterator { | ||
type Item = Coordinate; | ||
|
||
fn next(&mut self) -> Option<Self::Item> { | ||
if self.done || self.shape.order() == 0 { | ||
return None; | ||
} | ||
|
||
let result = self.current.clone(); | ||
|
||
for i in (0..self.shape.order()).rev() { | ||
if self.current[i] + 1 < self.shape[i] { | ||
self.current[i] += 1; | ||
break; | ||
} else { | ||
self.current[i] = 0; | ||
if i == 0 { | ||
self.done = true; | ||
} | ||
} | ||
} | ||
|
||
Some(result) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use crate::shape; | ||
|
||
#[test] | ||
fn test_index_iterator() { | ||
let shape = shape![2, 3].unwrap(); | ||
let mut iter = IndexIterator::new(&shape); | ||
|
||
assert_eq!(iter.next(), Some(coord![0, 0].unwrap())); | ||
assert_eq!(iter.next(), Some(coord![0, 1].unwrap())); | ||
assert_eq!(iter.next(), Some(coord![0, 2].unwrap())); | ||
assert_eq!(iter.next(), Some(coord![1, 0].unwrap())); | ||
assert_eq!(iter.next(), Some(coord![1, 1].unwrap())); | ||
assert_eq!(iter.next(), Some(coord![1, 2].unwrap())); | ||
assert_eq!(iter.next(), None); | ||
} | ||
|
||
#[test] | ||
fn test_index_iterator_single_dimension() { | ||
let shape = shape![4].unwrap(); | ||
let mut iter = IndexIterator::new(&shape); | ||
|
||
assert_eq!(iter.next(), Some(coord![0].unwrap())); | ||
assert_eq!(iter.next(), Some(coord![1].unwrap())); | ||
assert_eq!(iter.next(), Some(coord![2].unwrap())); | ||
assert_eq!(iter.next(), Some(coord![3].unwrap())); | ||
assert_eq!(iter.next(), None); | ||
} | ||
|
||
#[test] | ||
fn test_index_iterator_empty_tensor() { | ||
let shape = shape![].unwrap(); | ||
let mut iter = IndexIterator::new(&shape); | ||
|
||
assert_eq!(iter.next(), None); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,9 @@ | ||
pub fn add(left: usize, right: usize) -> usize { | ||
left + right | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[test] | ||
fn it_works() { | ||
let result = add(2, 2); | ||
assert_eq!(result, 4); | ||
} | ||
} | ||
pub mod axes; | ||
pub mod coordinate; | ||
pub mod error; | ||
pub mod iter; | ||
pub mod matrix; | ||
pub mod shape; | ||
pub mod storage; | ||
pub mod tensor; | ||
pub mod vector; |
Oops, something went wrong.