diff --git a/rust/ommx/src/convert/instance.rs b/rust/ommx/src/convert/instance.rs index a3c993e3..00b443da 100644 --- a/rust/ommx/src/convert/instance.rs +++ b/rust/ommx/src/convert/instance.rs @@ -1,6 +1,6 @@ use crate::v1::{ instance::{Description, Sense}, - Function, Instance, + Function, Instance, Parameter, ParametricInstance, }; use anyhow::{bail, Result}; use approx::AbsDiffEq; @@ -52,6 +52,31 @@ impl Instance { .prop_flat_map(Self::arbitrary_with) .boxed() } + + pub fn penalty_method(self) -> ParametricInstance { + let id_base = self.defined_ids().last().map(|id| id + 1).unwrap_or(0); + let mut objective = self.objective().into_owned(); + let mut parameters = Vec::new(); + for (i, c) in self.constraints.into_iter().enumerate() { + let parameter = Parameter { + id: id_base + i as u64, + name: Some("penalty".to_string()), + subscripts: vec![c.id as i64], + ..Default::default() + }; + let f = c.function().into_owned(); + objective = objective + ¶meter * f.clone() * f; + parameters.push(parameter); + } + ParametricInstance { + description: self.description, + objective: Some(objective), + constraints: Vec::new(), + decision_variables: self.decision_variables.clone(), + sense: self.sense, + parameters, + } + } } impl Arbitrary for Instance { @@ -193,11 +218,44 @@ impl AbsDiffEq for Instance { #[cfg(test)] mod tests { use super::*; + use crate::v1::Parameters; proptest! { #[test] fn test_instance_arbitrary_any(instance in Instance::arbitrary()) { instance.check_decision_variables().unwrap(); } + + #[test] + fn test_penalty_method(instance in Instance::arbitrary()) { + let parametric_instance = instance.clone().penalty_method(); + let dv_ids = parametric_instance.defined_decision_variable_ids(); + let p_ids = parametric_instance.defined_parameter_ids(); + prop_assert!(dv_ids.is_disjoint(&p_ids)); + + let used_ids = parametric_instance.used_ids().unwrap(); + let all_ids = dv_ids.union(&p_ids).cloned().collect(); + prop_assert!(used_ids.is_subset(&all_ids)); + + // Put every penalty weights to zero + let parameters = Parameters { + entries: p_ids.iter().map(|&id| (id, 0.0)).collect(), + }; + let substituted = parametric_instance.clone().with_parameters(parameters).unwrap(); + prop_assert!(instance.objective().abs_diff_eq(&substituted.objective(), 1e-10)); + prop_assert_eq!(substituted.constraints.len(), 0); + + // Put every penalty weights to two + let parameters = Parameters { + entries: p_ids.iter().map(|&id| (id, 2.0)).collect(), + }; + let substituted = parametric_instance.with_parameters(parameters).unwrap(); + let mut objective = instance.objective().into_owned(); + for c in &instance.constraints { + let f = c.function().into_owned(); + objective = objective + 2.0 * f.clone() * f; + } + prop_assert!(objective.abs_diff_eq(&substituted.objective(), 1e-10)); + } } }