Skip to content

Commit 493674d

Browse files
adamreicholdbluss
authored andcommitted
Add a par_fold method to Zip to improve the discoverability of Rayon's fold-reduce idiom.
1 parent b10b09c commit 493674d

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

src/parallel/impl_par_methods.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,56 @@ macro_rules! zip_impl {
177177
self.par_map_assign_into(into, f)
178178
}
179179

180+
/// Parallel version of `fold`.
181+
///
182+
/// Splits the producer in multiple tasks which each accumulate a single value
183+
/// using the `fold` closure. Those tasks are executed in parallel and their results
184+
/// are then combined to a single value using the `reduce` closure.
185+
///
186+
/// The `identity` closure provides the initial values for each of the tasks and
187+
/// for the final reduction.
188+
///
189+
/// This is a shorthand for calling `self.into_par_iter().fold(...).reduce(...)`.
190+
///
191+
/// Note that it is often more efficient to parallelize not per-element but rather
192+
/// based on larger chunks of an array like generalized rows and operating on each chunk
193+
/// using a sequential variant of the accumulation.
194+
/// For example, sum each row sequentially and in parallel, taking advatange of locality
195+
/// and vectorization within each task, and then reduce their sums to the sum of the matrix.
196+
///
197+
/// Also note that the splitting of the producer into multiple tasks is _not_ deterministic
198+
/// which needs to be considered when the accuracy of such an operation is analyzed.
199+
///
200+
/// ## Examples
201+
///
202+
/// ```rust
203+
/// use ndarray::{Array, Zip};
204+
///
205+
/// let a = Array::<usize, _>::ones((128, 1024));
206+
/// let b = Array::<usize, _>::ones(128);
207+
///
208+
/// let weighted_sum = Zip::from(a.rows()).and(&b).par_fold(
209+
/// || 0,
210+
/// |sum, row, factor| sum + row.sum() * factor,
211+
/// |sum, other_sum| sum + other_sum,
212+
/// );
213+
///
214+
/// assert_eq!(weighted_sum, a.len());
215+
/// ```
216+
pub fn par_fold<ID, F, R, T>(self, identity: ID, fold: F, reduce: R) -> T
217+
where
218+
ID: Fn() -> T + Send + Sync + Clone,
219+
F: Fn(T, $($p::Item),*) -> T + Send + Sync,
220+
R: Fn(T, T) -> T + Send + Sync,
221+
T: Send
222+
{
223+
self.into_par_iter()
224+
.fold(identity.clone(), move |accumulator, ($($p,)*)| {
225+
fold(accumulator, $($p),*)
226+
})
227+
.reduce(identity, reduce)
228+
}
229+
180230
);
181231
}
182232
)+

0 commit comments

Comments
 (0)