Skip to content

Commit

Permalink
Merge pull request #2 from Rust-Scientific-Computing/1-implement-tensor
Browse files Browse the repository at this point in the history
Implement Tensor struct
  • Loading branch information
siliconlad authored Jun 30, 2024
2 parents 15047f2 + 0df55fe commit 87c2d59
Show file tree
Hide file tree
Showing 11 changed files with 3,174 additions and 15 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ version = "0.1.0"
edition = "2021"
description = "A tensor library for scientific computing in Rust"
license = "MIT"
license-file = "LICENSE"
homepage = "https://github.com/Rust-Scientific-Computing/feotensor"
repository = "https://github.com/Rust-Scientific-Computing/feotensor"

[dependencies]
itertools = "0.13.0"
num = "0.4.3"
1 change: 1 addition & 0 deletions src/axes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub type Axes = Vec<usize>;
131 changes: 131 additions & 0 deletions src/coordinate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
use std::fmt;
use std::ops::{Index, IndexMut};

use crate::error::ShapeError;

#[derive(Debug, Clone, PartialEq)]
pub struct Coordinate {
indices: Vec<usize>,
}

impl Coordinate {
pub fn new(indices: Vec<usize>) -> Result<Self, ShapeError> {
if indices.is_empty() {
return Err(ShapeError::new("Coordinate cannot be empty"));
}
Ok(Self { indices })
}

pub fn order(&self) -> usize {
self.indices.len()
}

pub fn iter(&self) -> std::slice::Iter<'_, usize> {
self.indices.iter()
}

pub fn insert(&self, index: usize, axis: usize) -> Self {
let mut new_indices = self.indices.clone();
new_indices.insert(index, axis);
Self {
indices: new_indices,
}
}
}

impl Index<usize> for Coordinate {
type Output = usize;

fn index(&self, index: usize) -> &Self::Output {
&self.indices[index]
}
}

impl IndexMut<usize> for Coordinate {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.indices[index]
}
}

impl fmt::Display for Coordinate {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use itertools::Itertools;
let idxs = self.indices.iter().map(|&x| format!("{}", x)).join(", ");
write!(f, "({})", idxs)
}
}

#[macro_export]
macro_rules! coord {
($($index:expr),*) => {
{
use $crate::coordinate::Coordinate;
Coordinate::new(vec![$($index),*])
}
};

($index:expr; $count:expr) => {
{
use $crate::coordinate::Coordinate;
Coordinate::new(vec![$index; $count])
}
};
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_order() {
let coord = coord![1, 2, 3].unwrap();
assert_eq!(coord.order(), 3);
}

#[test]
fn test_iter() {
let coord = coord![1, 2, 3].unwrap();
let mut iter = coord.iter();
assert_eq!(iter.next(), Some(&1));
assert_eq!(iter.next(), Some(&2));
assert_eq!(iter.next(), Some(&3));
assert_eq!(iter.next(), None);
}

#[test]
fn test_insert() {
let coord = coord![1, 2, 3].unwrap();
let new_coord = coord.insert(1, 4);
assert_eq!(new_coord, coord![1, 4, 2, 3].unwrap());
}

#[test]
fn test_index() {
let coord = coord![1, 2, 3].unwrap();
assert_eq!(coord[0], 1);
assert_eq!(coord[1], 2);
assert_eq!(coord[2], 3);
}

#[test]
fn test_index_mut() {
let mut coord = coord![1, 2, 3].unwrap();
coord[1] = 4;
assert_eq!(coord[1], 4);
}

#[test]
fn test_display() {
let coord = coord![1, 2, 3].unwrap();
assert_eq!(format!("{}", coord), "(1, 2, 3)");
}

#[test]
fn test_coord_macro() {
let coord = coord![1, 2, 3].unwrap();
assert_eq!(coord, Coordinate::new(vec![1, 2, 3]).unwrap());

let coord_repeated = coord![1; 3].unwrap();
assert_eq!(coord_repeated, Coordinate::new(vec![1, 1, 1]).unwrap());
}
}
20 changes: 20 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use std::fmt;

#[derive(Debug, Clone)]
pub struct ShapeError {
reason: String,
}

impl ShapeError {
pub fn new(reason: &str) -> Self {
ShapeError {
reason: reason.to_string(),
}
}
}

impl fmt::Display for ShapeError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "ShapeError: {}", self.reason)
}
}
88 changes: 88 additions & 0 deletions src/iter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use crate::coord;
use crate::coordinate::Coordinate;
use crate::shape::Shape;
use std::cmp::max;

pub struct IndexIterator {
shape: Shape,
current: Coordinate,
done: bool,
}

impl IndexIterator {
pub fn new(shape: &Shape) -> Self {
// (shape.order() == 0) => `next` returns None before `current` is used
let current = coord![0; max(shape.order(), 1)].unwrap();
IndexIterator {
shape: shape.clone(),
current,
done: false,
}
}
}

impl Iterator for IndexIterator {
type Item = Coordinate;

fn next(&mut self) -> Option<Self::Item> {
if self.done || self.shape.order() == 0 {
return None;
}

let result = self.current.clone();

for i in (0..self.shape.order()).rev() {
if self.current[i] + 1 < self.shape[i] {
self.current[i] += 1;
break;
} else {
self.current[i] = 0;
if i == 0 {
self.done = true;
}
}
}

Some(result)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::shape;

#[test]
fn test_index_iterator() {
let shape = shape![2, 3].unwrap();
let mut iter = IndexIterator::new(&shape);

assert_eq!(iter.next(), Some(coord![0, 0].unwrap()));
assert_eq!(iter.next(), Some(coord![0, 1].unwrap()));
assert_eq!(iter.next(), Some(coord![0, 2].unwrap()));
assert_eq!(iter.next(), Some(coord![1, 0].unwrap()));
assert_eq!(iter.next(), Some(coord![1, 1].unwrap()));
assert_eq!(iter.next(), Some(coord![1, 2].unwrap()));
assert_eq!(iter.next(), None);
}

#[test]
fn test_index_iterator_single_dimension() {
let shape = shape![4].unwrap();
let mut iter = IndexIterator::new(&shape);

assert_eq!(iter.next(), Some(coord![0].unwrap()));
assert_eq!(iter.next(), Some(coord![1].unwrap()));
assert_eq!(iter.next(), Some(coord![2].unwrap()));
assert_eq!(iter.next(), Some(coord![3].unwrap()));
assert_eq!(iter.next(), None);
}

#[test]
fn test_index_iterator_empty_tensor() {
let shape = shape![].unwrap();
let mut iter = IndexIterator::new(&shape);

assert_eq!(iter.next(), None);
}
}
23 changes: 9 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
pub fn add(left: usize, right: usize) -> usize {
left + right
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn it_works() {
let result = add(2, 2);
assert_eq!(result, 4);
}
}
pub mod axes;
pub mod coordinate;
pub mod error;
pub mod iter;
pub mod matrix;
pub mod shape;
pub mod storage;
pub mod tensor;
pub mod vector;
Loading

0 comments on commit 87c2d59

Please sign in to comment.