@@ -265,6 +265,63 @@ impl<T: Num + PartialOrd + Copy> Mul<T> for Tensor<T> {
265265 }
266266}
267267
268+ // Vector/Matrix Multiplication
269+ impl < T : Num + PartialOrd + Copy > Mul < Tensor < T > > for Tensor < T > {
270+ type Output = Tensor < T > ;
271+
272+ fn mul ( self , rhs : Tensor < T > ) -> Tensor < T > {
273+ if self . shape . len ( ) == 1 && rhs. shape . len ( ) == 1 {
274+ // Vector-Vector multiplication (dot product)
275+ assert ! ( self . shape[ 0 ] == rhs. shape[ 0 ] , "Vectors must be of the same length for dot product." ) ;
276+ let mut result = T :: zero ( ) ;
277+ for i in 0 ..self . shape [ 0 ] {
278+ result = result + self . data [ i] * rhs. data [ i] ;
279+ }
280+ Tensor :: new ( & shape ! [ 1 ] , & vec ! [ result] )
281+ } else if self . shape . len ( ) == 1 && rhs. shape . len ( ) == 2 {
282+ // Vector-Matrix multiplication
283+ assert ! ( self . shape[ 0 ] == rhs. shape[ 0 ] , "The length of the vector must be equal to the number of rows in the matrix." ) ;
284+ let mut result = Tensor :: zeros ( & shape ! [ rhs. shape[ 1 ] ] ) ;
285+ for j in 0 ..rhs. shape [ 1 ] {
286+ let mut sum = T :: zero ( ) ;
287+ for i in 0 ..self . shape [ 0 ] {
288+ sum = sum + self . data [ i] * rhs. data [ i * rhs. shape [ 1 ] + j] ;
289+ }
290+ result. data [ j] = sum;
291+ }
292+ result
293+ } else if self . shape . len ( ) == 2 && rhs. shape . len ( ) == 1 {
294+ // Matrix-Vector multiplication
295+ assert ! ( self . shape[ 1 ] == rhs. shape[ 0 ] , "The number of columns in the matrix must be equal to the length of the vector." ) ;
296+ let mut result = Tensor :: zeros ( & shape ! [ self . shape[ 0 ] ] ) ;
297+ for i in 0 ..self . shape [ 0 ] {
298+ let mut sum = T :: zero ( ) ;
299+ for j in 0 ..self . shape [ 1 ] {
300+ sum = sum + self . data [ i * self . shape [ 1 ] + j] * rhs. data [ j] ;
301+ }
302+ result. data [ i] = sum;
303+ }
304+ result
305+ } else if self . shape . len ( ) == 2 && rhs. shape . len ( ) == 2 {
306+ // Matrix-Matrix multiplication
307+ assert ! ( self . shape[ 1 ] == rhs. shape[ 0 ] , "The number of columns in the first matrix must be equal to the number of rows in the second matrix." ) ;
308+ let mut result = Tensor :: zeros ( & shape ! [ self . shape[ 0 ] , rhs. shape[ 1 ] ] ) ;
309+ for i in 0 ..self . shape [ 0 ] {
310+ for j in 0 ..rhs. shape [ 1 ] {
311+ let mut sum = T :: zero ( ) ;
312+ for k in 0 ..self . shape [ 1 ] {
313+ sum = sum + self . data [ i * self . shape [ 1 ] + k] * rhs. data [ k * rhs. shape [ 1 ] + j] ;
314+ }
315+ result. data [ i * rhs. shape [ 1 ] + j] = sum;
316+ }
317+ }
318+ result
319+ } else {
320+ panic ! ( "Unsupported shapes for multiplication." ) ;
321+ }
322+ }
323+ }
324+
268325// Element-wise Addition
269326impl < T : Num + PartialOrd + Copy > Add < T > for Tensor < T > {
270327 type Output = Tensor < T > ;
@@ -773,6 +830,84 @@ mod tests {
773830 assert_eq ! ( result. data, vec![ 2.0 , 4.0 , 6.0 , 8.0 ] ) ;
774831 }
775832
833+ #[ test]
834+ fn test_vec_vec_mul_single ( ) {
835+ let shape = shape ! [ 1 ] ;
836+ let data1 = vec ! [ 2.0 ] ;
837+ let data2 = vec ! [ 5.0 ] ;
838+
839+ let tensor1 = Tensor :: new ( & shape, & data1) ;
840+ let tensor2 = Tensor :: new ( & shape, & data2) ;
841+
842+ let result = tensor1 * tensor2;
843+
844+ assert_eq ! ( result. shape( ) , & shape![ 1 ] ) ;
845+ assert_eq ! ( result. data, vec![ 10.0 ] ) ;
846+ }
847+
848+ #[ test]
849+ fn test_vec_vec_mul ( ) {
850+ let shape = shape ! [ 4 ] ;
851+ let data1 = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 ] ;
852+ let data2 = vec ! [ 2.0 , 3.0 , 4.0 , 5.0 ] ;
853+
854+ let tensor1 = Tensor :: new ( & shape, & data1) ;
855+ let tensor2 = Tensor :: new ( & shape, & data2) ;
856+
857+ let result = tensor1 * tensor2;
858+
859+ assert_eq ! ( result. shape( ) , & shape![ 1 ] ) ;
860+ assert_eq ! ( result. data, vec![ 40.0 ] ) ;
861+ }
862+
863+ #[ test]
864+ fn test_vec_matrix_mul ( ) {
865+ let shape_vec = shape ! [ 2 ] ;
866+ let shape_matrix = shape ! [ 2 , 3 ] ;
867+ let data_vec = vec ! [ 1.0 , 2.0 ] ;
868+ let data_matrix = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ;
869+
870+ let tensor_vec = Tensor :: new ( & shape_vec, & data_vec) ;
871+ let tensor_matrix = Tensor :: new ( & shape_matrix, & data_matrix) ;
872+
873+ let result = tensor_vec * tensor_matrix;
874+
875+ assert_eq ! ( result. shape( ) , & shape![ 3 ] ) ;
876+ assert_eq ! ( result. data, vec![ 9.0 , 12.0 , 15.0 ] ) ;
877+ }
878+
879+ #[ test]
880+ fn test_matrix_vec_mul ( ) {
881+ let shape_matrix = shape ! [ 2 , 3 ] ;
882+ let shape_vec = shape ! [ 3 ] ;
883+ let data_matrix = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ;
884+ let data_vec = vec ! [ 1.0 , 2.0 , 3.0 ] ;
885+
886+ let tensor_matrix = Tensor :: new ( & shape_matrix, & data_matrix) ;
887+ let tensor_vec = Tensor :: new ( & shape_vec, & data_vec) ;
888+
889+ let result = tensor_matrix * tensor_vec;
890+
891+ assert_eq ! ( result. shape( ) , & shape![ 2 ] ) ;
892+ assert_eq ! ( result. data, vec![ 14.0 , 32.0 ] ) ;
893+ }
894+
895+ #[ test]
896+ fn test_matrix_matrix_mul ( ) {
897+ let shape1 = shape ! [ 2 , 3 ] ;
898+ let shape2 = shape ! [ 3 , 2 ] ;
899+ let data1 = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ;
900+ let data2 = vec ! [ 7.0 , 8.0 , 9.0 , 10.0 , 11.0 , 12.0 ] ;
901+
902+ let tensor1 = Tensor :: new ( & shape1, & data1) ;
903+ let tensor2 = Tensor :: new ( & shape2, & data2) ;
904+
905+ let result = tensor1 * tensor2;
906+
907+ assert_eq ! ( result. shape( ) , & shape![ 2 , 2 ] ) ;
908+ assert_eq ! ( result. data, vec![ 58.0 , 64.0 , 139.0 , 154.0 ] ) ;
909+ }
910+
776911 #[ test]
777912 fn test_div_tensor ( ) {
778913 let shape = shape ! [ 4 ] ;
0 commit comments