Skip to content

Commit eba8f49

Browse files
committed
feat: add matrix multiplication
1 parent 918cf6d commit eba8f49

File tree

2 files changed

+144
-0
lines changed

2 files changed

+144
-0
lines changed

src/lib.rs

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,63 @@ impl<T: Num + PartialOrd + Copy> Mul<T> for Tensor<T> {
265265
}
266266
}
267267

268+
// Vector/Matrix Multiplication
269+
impl<T: Num + PartialOrd + Copy> Mul<Tensor<T>> for Tensor<T> {
270+
type Output = Tensor<T>;
271+
272+
fn mul(self, rhs: Tensor<T>) -> Tensor<T> {
273+
if self.shape.len() == 1 && rhs.shape.len() == 1 {
274+
// Vector-Vector multiplication (dot product)
275+
assert!(self.shape[0] == rhs.shape[0], "Vectors must be of the same length for dot product.");
276+
let mut result = T::zero();
277+
for i in 0..self.shape[0] {
278+
result = result + self.data[i] * rhs.data[i];
279+
}
280+
Tensor::new(&shape![1], &vec![result])
281+
} else if self.shape.len() == 1 && rhs.shape.len() == 2 {
282+
// Vector-Matrix multiplication
283+
assert!(self.shape[0] == rhs.shape[0], "The length of the vector must be equal to the number of rows in the matrix.");
284+
let mut result = Tensor::zeros(&shape![rhs.shape[1]]);
285+
for j in 0..rhs.shape[1] {
286+
let mut sum = T::zero();
287+
for i in 0..self.shape[0] {
288+
sum = sum + self.data[i] * rhs.data[i * rhs.shape[1] + j];
289+
}
290+
result.data[j] = sum;
291+
}
292+
result
293+
} else if self.shape.len() == 2 && rhs.shape.len() == 1 {
294+
// Matrix-Vector multiplication
295+
assert!(self.shape[1] == rhs.shape[0], "The number of columns in the matrix must be equal to the length of the vector.");
296+
let mut result = Tensor::zeros(&shape![self.shape[0]]);
297+
for i in 0..self.shape[0] {
298+
let mut sum = T::zero();
299+
for j in 0..self.shape[1] {
300+
sum = sum + self.data[i * self.shape[1] + j] * rhs.data[j];
301+
}
302+
result.data[i] = sum;
303+
}
304+
result
305+
} else if self.shape.len() == 2 && rhs.shape.len() == 2 {
306+
// Matrix-Matrix multiplication
307+
assert!(self.shape[1] == rhs.shape[0], "The number of columns in the first matrix must be equal to the number of rows in the second matrix.");
308+
let mut result = Tensor::zeros(&shape![self.shape[0], rhs.shape[1]]);
309+
for i in 0..self.shape[0] {
310+
for j in 0..rhs.shape[1] {
311+
let mut sum = T::zero();
312+
for k in 0..self.shape[1] {
313+
sum = sum + self.data[i * self.shape[1] + k] * rhs.data[k * rhs.shape[1] + j];
314+
}
315+
result.data[i * rhs.shape[1] + j] = sum;
316+
}
317+
}
318+
result
319+
} else {
320+
panic!("Unsupported shapes for multiplication.");
321+
}
322+
}
323+
}
324+
268325
// Element-wise Addition
269326
impl<T: Num + PartialOrd + Copy> Add<T> for Tensor<T> {
270327
type Output = Tensor<T>;
@@ -773,6 +830,84 @@ mod tests {
773830
assert_eq!(result.data, vec![2.0, 4.0, 6.0, 8.0]);
774831
}
775832

833+
#[test]
834+
fn test_vec_vec_mul_single() {
835+
let shape = shape![1];
836+
let data1 = vec![2.0];
837+
let data2 = vec![5.0];
838+
839+
let tensor1 = Tensor::new(&shape, &data1);
840+
let tensor2 = Tensor::new(&shape, &data2);
841+
842+
let result = tensor1 * tensor2;
843+
844+
assert_eq!(result.shape(), &shape![1]);
845+
assert_eq!(result.data, vec![10.0]);
846+
}
847+
848+
#[test]
849+
fn test_vec_vec_mul() {
850+
let shape = shape![4];
851+
let data1 = vec![1.0, 2.0, 3.0, 4.0];
852+
let data2 = vec![2.0, 3.0, 4.0, 5.0];
853+
854+
let tensor1 = Tensor::new(&shape, &data1);
855+
let tensor2 = Tensor::new(&shape, &data2);
856+
857+
let result = tensor1 * tensor2;
858+
859+
assert_eq!(result.shape(), &shape![1]);
860+
assert_eq!(result.data, vec![40.0]);
861+
}
862+
863+
#[test]
864+
fn test_vec_matrix_mul() {
865+
let shape_vec = shape![2];
866+
let shape_matrix = shape![2, 3];
867+
let data_vec = vec![1.0, 2.0];
868+
let data_matrix = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
869+
870+
let tensor_vec = Tensor::new(&shape_vec, &data_vec);
871+
let tensor_matrix = Tensor::new(&shape_matrix, &data_matrix);
872+
873+
let result = tensor_vec * tensor_matrix;
874+
875+
assert_eq!(result.shape(), &shape![3]);
876+
assert_eq!(result.data, vec![9.0, 12.0, 15.0]);
877+
}
878+
879+
#[test]
880+
fn test_matrix_vec_mul() {
881+
let shape_matrix = shape![2, 3];
882+
let shape_vec = shape![3];
883+
let data_matrix = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
884+
let data_vec = vec![1.0, 2.0, 3.0];
885+
886+
let tensor_matrix = Tensor::new(&shape_matrix, &data_matrix);
887+
let tensor_vec = Tensor::new(&shape_vec, &data_vec);
888+
889+
let result = tensor_matrix * tensor_vec;
890+
891+
assert_eq!(result.shape(), &shape![2]);
892+
assert_eq!(result.data, vec![14.0, 32.0]);
893+
}
894+
895+
#[test]
896+
fn test_matrix_matrix_mul() {
897+
let shape1 = shape![2, 3];
898+
let shape2 = shape![3, 2];
899+
let data1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
900+
let data2 = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
901+
902+
let tensor1 = Tensor::new(&shape1, &data1);
903+
let tensor2 = Tensor::new(&shape2, &data2);
904+
905+
let result = tensor1 * tensor2;
906+
907+
assert_eq!(result.shape(), &shape![2, 2]);
908+
assert_eq!(result.data, vec![58.0, 64.0, 139.0, 154.0]);
909+
}
910+
776911
#[test]
777912
fn test_div_tensor() {
778913
let shape = shape![4];

src/shape.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::fmt;
2+
use std::ops::Index;
23

34
#[derive(Debug, Clone, PartialEq)]
45
pub struct Shape {
@@ -19,6 +20,14 @@ impl Shape {
1920
}
2021
}
2122

23+
impl Index<usize> for Shape {
24+
type Output = usize;
25+
26+
fn index(&self, index: usize) -> &Self::Output {
27+
&self.dims[index]
28+
}
29+
}
30+
2231
impl fmt::Display for Shape {
2332
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2433
use itertools::Itertools;

0 commit comments

Comments
 (0)