Skip to content

Commit

Permalink
Add diverging sampler stats
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Jul 1, 2023
1 parent 73cc0b3 commit 129b94c
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/nuts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ pub struct StatsBuilder<H: Hamiltonian, A: AdaptStrategy> {
gradient: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
hamiltonian: <H::Stats as ArrowRow>::Builder,
adapt: <A::Stats as ArrowRow>::Builder,
diverging: MutableBooleanArray,
}

#[cfg(feature = "arrow")]
Expand Down Expand Up @@ -555,6 +556,7 @@ impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
unconstrained,
hamiltonian: <H::Stats as ArrowRow>::new_builder(dim, settings),
adapt: <A::Stats as ArrowRow>::new_builder(dim, settings),
diverging: MutableBooleanArray::with_capacity(capacity),
}
}
}
Expand All @@ -571,6 +573,7 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
self.energy.push(Some(value.energy));
self.chain.push(Some(value.chain));
self.draw.push(Some(value.draw));
self.diverging.push(Some(value.divergence_info().is_some()));

if let Some(store) = self.gradient.as_mut() {
store
Expand Down Expand Up @@ -607,6 +610,7 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
Field::new("energy", DataType::Float64, false),
Field::new("chain", DataType::UInt64, false),
Field::new("draw", DataType::UInt64, false),
Field::new("diverging", DataType::Boolean, false),
];

let mut arrays = vec![
Expand All @@ -617,6 +621,7 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
self.energy.as_box(),
self.chain.as_box(),
self.draw.as_box(),
self.diverging.as_box(),
];

if let Some(hamiltonian) = self.hamiltonian.finalize() {
Expand Down

0 comments on commit 129b94c

Please sign in to comment.