@@ -11,7 +11,10 @@ use float_cmp::approx_eq;
11
11
use itertools:: izip;
12
12
use itertools:: Itertools ;
13
13
use ndarray:: linalg;
14
- use ndarray:: { s, Array1 , Array2 , Array3 , ArrayD , ArrayView1 , ArrayView4 , Axis , Ix1 , Ix2 } ;
14
+ use ndarray:: {
15
+ s, Array1 , Array2 , Array3 , ArrayD , ArrayView1 , ArrayView4 , ArrayViewD , ArrayViewMutD , Axis ,
16
+ Ix1 , Ix2 ,
17
+ } ;
15
18
use std:: iter;
16
19
17
20
/// This structure captures the information needed to create an evolution kernel operator (EKO) for
@@ -462,7 +465,7 @@ pub(crate) fn evolve_slice_with_many(
462
465
. map ( |ops| ( fk_table, ops) )
463
466
} )
464
467
{
465
- general_tensor_mul ( * factor, & array, & ops, fk_table) ;
468
+ general_tensor_mul ( * factor, array. view ( ) , & ops, & mut fk_table. view_mut ( ) ) ;
466
469
}
467
470
}
468
471
}
@@ -496,11 +499,12 @@ pub(crate) fn evolve_slice_with_many(
496
499
497
500
fn general_tensor_mul (
498
501
factor : f64 ,
499
- array : & ArrayD < f64 > ,
502
+ array : ArrayViewD < f64 > ,
500
503
ops : & [ & Array2 < f64 > ] ,
501
- fk_table : & mut ArrayD < f64 > ,
504
+ fk_table : & mut ArrayViewMutD < f64 > ,
502
505
) {
503
506
match array. shape ( ) . len ( ) {
507
+ 0 => unreachable ! ( ) ,
504
508
1 => {
505
509
let array = array. view ( ) . into_dimensionality :: < Ix1 > ( ) . unwrap ( ) ;
506
510
let mut fk_table = fk_table. view_mut ( ) . into_dimensionality :: < Ix1 > ( ) . unwrap ( ) ;
@@ -516,7 +520,17 @@ fn general_tensor_mul(
516
520
// fk_table += factor * ops[0] * tmp
517
521
linalg:: general_mat_mul ( factor, ops[ 0 ] , & tmp, 1.0 , & mut fk_table) ;
518
522
}
519
- // TODO: generalize this to n dimensions
520
- _ => unimplemented ! ( ) ,
523
+ _ => {
524
+ let ( ops_0, ops_dm1) = ops. split_first ( ) . unwrap ( ) ;
525
+
526
+ for ( mut fk_table_i, ops_0_i) in fk_table
527
+ . axis_iter_mut ( Axis ( 0 ) )
528
+ . zip ( ops_0. axis_iter ( Axis ( 0 ) ) )
529
+ {
530
+ for ( array_j, ops_0_ij) in array. axis_iter ( Axis ( 0 ) ) . zip ( ops_0_i. iter ( ) ) {
531
+ general_tensor_mul ( factor * ops_0_ij, array_j, & ops_dm1, & mut fk_table_i) ;
532
+ }
533
+ }
534
+ }
521
535
}
522
536
}
0 commit comments