Skip to content

Commit d13bb77

Browse files
committed
Generalize general_tensor_mul to arbitrary dimensions
1 parent 20c263a commit d13bb77

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

Cargo.lock

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pineappl/src/evolution.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ use float_cmp::approx_eq;
1111
use itertools::izip;
1212
use itertools::Itertools;
1313
use ndarray::linalg;
14-
use ndarray::{s, Array1, Array2, Array3, ArrayD, ArrayView1, ArrayView4, Axis, Ix1, Ix2};
14+
use ndarray::{
15+
s, Array1, Array2, Array3, ArrayD, ArrayView1, ArrayView4, ArrayViewD, ArrayViewMutD, Axis,
16+
Ix1, Ix2,
17+
};
1518
use std::iter;
1619

1720
/// This structure captures the information needed to create an evolution kernel operator (EKO) for
@@ -462,7 +465,7 @@ pub(crate) fn evolve_slice_with_many(
462465
.map(|ops| (fk_table, ops))
463466
})
464467
{
465-
general_tensor_mul(*factor, &array, &ops, fk_table);
468+
general_tensor_mul(*factor, array.view(), &ops, &mut fk_table.view_mut());
466469
}
467470
}
468471
}
@@ -496,11 +499,12 @@ pub(crate) fn evolve_slice_with_many(
496499

497500
fn general_tensor_mul(
498501
factor: f64,
499-
array: &ArrayD<f64>,
502+
array: ArrayViewD<f64>,
500503
ops: &[&Array2<f64>],
501-
fk_table: &mut ArrayD<f64>,
504+
fk_table: &mut ArrayViewMutD<f64>,
502505
) {
503506
match array.shape().len() {
507+
0 => unreachable!(),
504508
1 => {
505509
let array = array.view().into_dimensionality::<Ix1>().unwrap();
506510
let mut fk_table = fk_table.view_mut().into_dimensionality::<Ix1>().unwrap();
@@ -516,7 +520,17 @@ fn general_tensor_mul(
516520
// fk_table += factor * ops[0] * tmp
517521
linalg::general_mat_mul(factor, ops[0], &tmp, 1.0, &mut fk_table);
518522
}
519-
// TODO: generalize this to n dimensions
520-
_ => unimplemented!(),
523+
_ => {
524+
let (ops_0, ops_dm1) = ops.split_first().unwrap();
525+
526+
for (mut fk_table_i, ops_0_i) in fk_table
527+
.axis_iter_mut(Axis(0))
528+
.zip(ops_0.axis_iter(Axis(0)))
529+
{
530+
for (array_j, ops_0_ij) in array.axis_iter(Axis(0)).zip(ops_0_i.iter()) {
531+
general_tensor_mul(factor * ops_0_ij, array_j, &ops_dm1, &mut fk_table_i);
532+
}
533+
}
534+
}
521535
}
522536
}

0 commit comments

Comments
 (0)