Skip to content

Commit

Permalink
Merge pull request #32 from denehoffman/status-update
Browse files Browse the repository at this point in the history
Status update
  • Loading branch information
denehoffman authored Sep 12, 2024
2 parents 0743380 + 37d1cf4 commit ebfbc9e
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 140 deletions.
35 changes: 15 additions & 20 deletions src/algorithms/bfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ pub enum BFGSErrorMode {
#[allow(clippy::upper_case_acronyms)]
pub struct BFGS<T: Scalar, U, E> {
status: Status<T>,
x: DVector<T>,
g: DVector<T>,
h_inv: DMatrix<T>,
Expand Down Expand Up @@ -114,7 +113,6 @@ where
{
fn default() -> Self {
Self {
status: Default::default(),
x: Default::default(),
g: Default::default(),
h_inv: Default::default(),
Expand Down Expand Up @@ -158,18 +156,18 @@ where
x0: &[T],
bounds: Option<&Vec<Bound<T>>>,
user_data: &mut U,
status: &mut Status<T>,
) -> Result<(), E> {
self.status = Status::default();
self.f_previous = T::infinity();
self.h_inv = DMatrix::identity(x0.len(), x0.len());
self.x = Bound::to_unbounded(x0, bounds);
self.g = func.gradient_bounded(self.x.as_slice(), bounds, user_data)?;
self.status.inc_n_g_evals();
self.status.update_position((
status.inc_n_g_evals();
status.update_position((
Bound::to_bounded(self.x.as_slice(), bounds),
func.evaluate_bounded(self.x.as_slice(), bounds, user_data)?,
));
self.status.inc_n_f_evals();
status.inc_n_f_evals();
Ok(())
}

Expand All @@ -179,6 +177,7 @@ where
func: &dyn Function<T, U, E>,
bounds: Option<&Vec<Bound<T>>>,
user_data: &mut U,
status: &mut Status<T>,
) -> Result<(), E> {
let d = -&self.h_inv * &self.g;
let (valid, alpha, f_kp1, g_kp1) = self.line_search.search(
Expand All @@ -188,7 +187,7 @@ where
func,
bounds,
user_data,
&mut self.status,
status,
)?;
if valid {
let dx = d.scale(alpha);
Expand All @@ -198,10 +197,9 @@ where
self.update_h_inv(i_step, n, &dx, &dg);
self.x += dx;
self.g = grad_kp1_vec;
self.status
.update_position((Bound::to_bounded(self.x.as_slice(), bounds), f_kp1));
status.update_position((Bound::to_bounded(self.x.as_slice(), bounds), f_kp1));
} else {
self.status.set_converged();
status.set_converged();
}
Ok(())
}
Expand All @@ -211,33 +209,30 @@ where
func: &dyn Function<T, U, E>,
bounds: Option<&Vec<Bound<T>>>,
user_data: &mut U,
status: &mut Status<T>,
) -> Result<bool, E> {
let f_current = func.evaluate_bounded(self.x.as_slice(), bounds, user_data)?;
self.terminator_f
.update_convergence(f_current, self.f_previous, &mut self.status);
.update_convergence(f_current, self.f_previous, status);
self.f_previous = f_current;
self.terminator_g
.update_convergence(&self.g, &mut self.status);
Ok(self.status.converged)
}

fn get_status(&self) -> &Status<T> {
&self.status
self.terminator_g.update_convergence(&self.g, status);
Ok(status.converged)
}

fn postprocessing(
&mut self,
func: &dyn Function<T, U, E>,
bounds: Option<&Vec<Bound<T>>>,
user_data: &mut U,
status: &mut Status<T>,
) -> Result<(), E> {
match self.error_mode {
BFGSErrorMode::ExactHessian => {
let hessian = func.hessian_bounded(self.x.as_slice(), bounds, user_data)?;
self.status.set_hess(&hessian);
status.set_hess(&hessian);
}
BFGSErrorMode::ApproximateHessian => {
self.status.set_cov(Some(self.h_inv.clone()));
status.set_cov(Some(self.h_inv.clone()));
}
}
Ok(())
Expand Down
31 changes: 13 additions & 18 deletions src/algorithms/lbfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ pub enum LBFGSErrorMode {
#[allow(clippy::upper_case_acronyms)]
pub struct LBFGS<T: Scalar, U, E> {
status: Status<T>,
x: DVector<T>,
g: DVector<T>,
f_previous: T,
Expand Down Expand Up @@ -121,7 +120,6 @@ where
{
fn default() -> Self {
Self {
status: Default::default(),
x: Default::default(),
g: Default::default(),
f_previous: T::infinity(),
Expand Down Expand Up @@ -179,17 +177,17 @@ where
x0: &[T],
bounds: Option<&Vec<Bound<T>>>,
user_data: &mut U,
status: &mut Status<T>,
) -> Result<(), E> {
self.status = Status::default();
self.f_previous = T::infinity();
self.x = Bound::to_unbounded(x0, bounds);
self.g = func.gradient_bounded(self.x.as_slice(), bounds, user_data)?;
self.status.inc_n_g_evals();
self.status.update_position((
status.inc_n_g_evals();
status.update_position((
Bound::to_bounded(self.x.as_slice(), bounds),
func.evaluate_bounded(self.x.as_slice(), bounds, user_data)?,
));
self.status.inc_n_f_evals();
status.inc_n_f_evals();
Ok(())
}

Expand All @@ -199,6 +197,7 @@ where
func: &dyn Function<T, U, E>,
bounds: Option<&Vec<Bound<T>>>,
user_data: &mut U,
status: &mut Status<T>,
) -> Result<(), E> {
let d = -self.g_approx();
let (valid, alpha, f_kp1, g_kp1) = self.line_search.search(
Expand All @@ -208,7 +207,7 @@ where
func,
bounds,
user_data,
&mut self.status,
status,
)?;
if valid {
let dx = d.scale(alpha);
Expand All @@ -226,8 +225,7 @@ where
}
self.x += dx;
self.g = grad_kp1_vec;
self.status
.update_position((Bound::to_bounded(self.x.as_slice(), bounds), f_kp1));
status.update_position((Bound::to_bounded(self.x.as_slice(), bounds), f_kp1));
} else {
// reboot
self.s_store.clear();
Expand All @@ -241,30 +239,27 @@ where
func: &dyn Function<T, U, E>,
bounds: Option<&Vec<Bound<T>>>,
user_data: &mut U,
status: &mut Status<T>,
) -> Result<bool, E> {
let f_current = func.evaluate_bounded(self.x.as_slice(), bounds, user_data)?;
self.terminator_f
.update_convergence(f_current, self.f_previous, &mut self.status);
.update_convergence(f_current, self.f_previous, status);
self.f_previous = f_current;
self.terminator_g
.update_convergence(&self.g, &mut self.status);
Ok(self.status.converged)
}

fn get_status(&self) -> &Status<T> {
&self.status
self.terminator_g.update_convergence(&self.g, status);
Ok(status.converged)
}

fn postprocessing(
&mut self,
func: &dyn Function<T, U, E>,
bounds: Option<&Vec<Bound<T>>>,
user_data: &mut U,
status: &mut Status<T>,
) -> Result<(), E> {
match self.error_mode {
LBFGSErrorMode::ExactHessian => {
let hessian = func.hessian_bounded(self.x.as_slice(), bounds, user_data)?;
self.status.set_hess(&hessian);
status.set_hess(&hessian);
}
}
Ok(())
Expand Down
36 changes: 15 additions & 21 deletions src/algorithms/lbfgsb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ pub enum LBFGSBErrorMode {
/// [^1]: [Numerical Optimization. Springer New York, 2006. doi: 10.1007/978-0-387-40065-5.](https://doi.org/10.1007/978-0-387-40065-5)
#[allow(clippy::upper_case_acronyms)]
pub struct LBFGSB<T: Scalar, U, E> {
status: Status<T>,
x: DVector<T>,
g: DVector<T>,
l: DVector<T>,
Expand Down Expand Up @@ -132,7 +131,6 @@ where
{
fn default() -> Self {
Self {
status: Default::default(),
x: Default::default(),
g: Default::default(),
l: Default::default(),
Expand Down Expand Up @@ -376,8 +374,8 @@ where
x0: &[T],
bounds: Option<&Vec<Bound<T>>>,
user_data: &mut U,
status: &mut Status<T>,
) -> Result<(), E> {
self.status = Status::default();
self.f_previous = T::infinity();
self.theta = T::one();
self.l = DVector::from_element(x0.len(), T::neg_infinity());
Expand Down Expand Up @@ -405,10 +403,9 @@ where
}
});
self.g = func.gradient(self.x.as_slice(), user_data)?;
self.status.inc_n_g_evals();
self.status
.update_position((self.x.clone(), func.evaluate(self.x.as_slice(), user_data)?));
self.status.inc_n_f_evals();
status.inc_n_g_evals();
status.update_position((self.x.clone(), func.evaluate(self.x.as_slice(), user_data)?));
status.inc_n_f_evals();
self.w_mat = DMatrix::zeros(self.x.len(), 1);
self.m_mat = DMatrix::zeros(1, 1);
Ok(())
Expand All @@ -420,6 +417,7 @@ where
func: &dyn Function<T, U, E>,
bounds: Option<&Vec<Bound<T>>>,
user_data: &mut U,
status: &mut Status<T>,
) -> Result<(), E> {
let d = self.compute_step_direction();
let max_step = self.compute_max_step(&d);
Expand All @@ -430,7 +428,7 @@ where
func,
bounds,
user_data,
&mut self.status,
status,
)?;
if valid {
let dx = d.scale(alpha);
Expand All @@ -450,7 +448,7 @@ where
}
self.x += dx;
self.g = grad_kp1_vec;
self.status.update_position((self.x.clone(), f_kp1));
status.update_position((self.x.clone(), f_kp1));
} else {
// reboot
self.s_store.clear();
Expand All @@ -467,35 +465,31 @@ where
func: &dyn Function<T, U, E>,
_bounds: Option<&Vec<Bound<T>>>,
user_data: &mut U,
status: &mut Status<T>,
) -> Result<bool, E> {
let f_current = func.evaluate(self.x.as_slice(), user_data)?;
self.terminator_f
.update_convergence(f_current, self.f_previous, &mut self.status);
.update_convergence(f_current, self.f_previous, status);
self.f_previous = f_current;
self.terminator_g
.update_convergence(&self.g, &mut self.status);
self.terminator_g.update_convergence(&self.g, status);
if self.get_inf_norm_projected_gradient() < self.g_tolerance {
self.status.set_converged();
self.status
.update_message("PROJECTED GRADIENT WITHIN TOLERANCE");
status.set_converged();
status.update_message("PROJECTED GRADIENT WITHIN TOLERANCE");
}
Ok(self.status.converged)
}

fn get_status(&self) -> &Status<T> {
&self.status
Ok(status.converged)
}

fn postprocessing(
&mut self,
func: &dyn Function<T, U, E>,
_bounds: Option<&Vec<Bound<T>>>,
user_data: &mut U,
status: &mut Status<T>,
) -> Result<(), E> {
match self.error_mode {
LBFGSBErrorMode::ExactHessian => {
let hessian = func.hessian(self.x.as_slice(), user_data)?;
self.status.set_hess(&hessian);
status.set_hess(&hessian);
}
}
Ok(())
Expand Down
Loading

0 comments on commit ebfbc9e

Please sign in to comment.