Skip to content

Commit

Permalink
Separate resolved/unresolved Pulse using newtypes. (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
molpopgen authored Jun 14, 2022
1 parent 6a25123 commit ad6b86b
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 40 deletions.
75 changes: 41 additions & 34 deletions src/specification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,12 @@ impl From<Migration> for UnresolvedMigration {

#[derive(Clone, Default, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct Pulse {
#[repr(transparent)]
pub struct Pulse(UnresolvedPulse);

#[derive(Clone, Default, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct UnresolvedPulse {
pub sources: Option<Vec<String>>,
pub dest: Option<String>,
pub time: Option<Time>,
Expand All @@ -556,7 +561,7 @@ impl Pulse {
match deme_map.get(deme) {
Some(d) => {
let t = d.time_interval();
let time = match self.time {
let time = match self.0.time {
Some(t) => t,
None => return Err(DemesError::PulseError("time is None".to_string())),
};
Expand All @@ -576,7 +581,7 @@ impl Pulse {
}

fn validate_pulse_time(&self, deme_map: &DemeMap) -> Result<(), DemesError> {
match self.time {
match self.0.time {
Some(time) => {
if !time.is_valid_pulse_time() {
return Err(DemesError::PulseError(format!(
Expand All @@ -588,7 +593,7 @@ impl Pulse {
None => return Err(DemesError::PulseError("time is None".to_string())),
}

for source_name in self.sources.as_ref().unwrap() {
for source_name in self.0.sources.as_ref().unwrap() {
let source = deme_map.get(source_name).unwrap();

let ti = source.time_interval();
Expand Down Expand Up @@ -616,15 +621,15 @@ impl Pulse {
}

fn validate_proportions(&self) -> Result<(), DemesError> {
if self.proportions.is_none() {
if self.0.proportions.is_none() {
return Err(DemesError::PulseError("proportions is None".to_string()));
}
if self.sources.is_none() {
if self.0.sources.is_none() {
return Err(DemesError::PulseError("sources is None".to_string()));
}

let proportions = self.proportions.as_ref().unwrap();
let sources = self.sources.as_ref().unwrap();
let proportions = self.0.proportions.as_ref().unwrap();
let sources = self.0.sources.as_ref().unwrap();
if proportions.len() != sources.len() {
return Err(DemesError::PulseError(format!("number of sources must equal number of proportions; got {} source and {} proportions", sources.len(), proportions.len())));
}
Expand All @@ -644,8 +649,8 @@ impl Pulse {
}

fn dest_is_not_source(&self) -> Result<(), DemesError> {
let dest = self.dest.as_ref().unwrap();
if self.sources.as_ref().unwrap().contains(dest) {
let dest = self.0.dest.as_ref().unwrap();
if self.0.sources.as_ref().unwrap().contains(dest) {
Err(DemesError::PulseError(format!(
"dest: {} is also listed as a source",
dest
Expand All @@ -657,7 +662,7 @@ impl Pulse {

fn sources_are_unique(&self) -> Result<(), DemesError> {
let mut sources = HashSet::<String>::default();
for source in self.sources.as_ref().unwrap() {
for source in self.0.sources.as_ref().unwrap() {
if sources.contains(source) {
return Err(DemesError::PulseError(format!(
"source: {} listed multiple times",
Expand All @@ -674,18 +679,19 @@ impl Pulse {

// NOTE: validate proportions is taking care of
// returning Err if this is not true
assert!(self.sources.is_some());
assert!(self.0.sources.is_some());

let sources = self.sources.as_ref().unwrap();
let sources = self.0.sources.as_ref().unwrap();
sources
.iter()
.try_for_each(|source| self.validate_deme_existence(source, deme_map))?;

self.dest
self.0
.dest
.as_ref()
.ok_or_else(|| DemesError::PulseError("dest is None".to_string()))?;

self.validate_deme_existence(self.dest.as_ref().unwrap(), deme_map)?;
self.validate_deme_existence(self.0.dest.as_ref().unwrap(), deme_map)?;
self.dest_is_not_source()?;
self.sources_are_unique()?;
self.validate_pulse_time(deme_map)
Expand All @@ -697,28 +703,28 @@ impl Pulse {
}

pub fn time(&self) -> Time {
match self.time {
match self.0.time {
Some(time) => time,
None => panic!("pulse time is None"),
}
}

pub fn sources(&self) -> &[String] {
match &self.sources {
match &self.0.sources {
Some(sources) => sources,
None => panic!("sources are None"),
}
}

pub fn dest(&self) -> &str {
match &self.dest {
match &self.0.dest {
Some(dest) => dest,
None => panic!("pulse dest is None"),
}
}

pub fn proportions(&self) -> &[Proportion] {
match &self.proportions {
match &self.0.proportions {
Some(proportions) => proportions,
None => panic!("proportions are None"),
}
Expand Down Expand Up @@ -1563,8 +1569,8 @@ pub struct GraphDefaults {
pub epoch: EpochData,
#[serde(default = "UnresolvedMigration::default")]
pub migration: UnresolvedMigration,
#[serde(default = "Pulse::default")]
pub pulse: Pulse,
#[serde(default = "UnresolvedPulse::default")]
pub pulse: UnresolvedPulse,
#[serde(default = "TopLevelDemeDefaults::default")]
pub deme: TopLevelDemeDefaults,
}
Expand Down Expand Up @@ -1629,17 +1635,17 @@ impl GraphDefaults {
}

fn apply_pulse_defaults(&self, other: &mut Pulse) {
if other.time.is_none() {
other.time = self.pulse.time;
if other.0.time.is_none() {
other.0.time = self.pulse.time;
}
if other.sources.is_none() {
other.sources = self.pulse.sources.clone();
if other.0.sources.is_none() {
other.0.sources = self.pulse.sources.clone();
}
if other.dest.is_none() {
other.dest = self.pulse.dest.clone();
if other.0.dest.is_none() {
other.0.dest = self.pulse.dest.clone();
}
if other.proportions.is_none() {
other.proportions = self.pulse.proportions.clone();
if other.0.proportions.is_none() {
other.0.proportions = self.pulse.proportions.clone();
}
}
}
Expand Down Expand Up @@ -1810,12 +1816,12 @@ impl Graph {
time: Option<Time>,
proportions: Option<Vec<Proportion>>,
) {
self.pulses.push(Pulse {
self.pulses.push(Pulse(UnresolvedPulse {
sources,
dest,
time,
proportions,
});
}));
}

fn build_deme_map(&self) -> Result<DemeMap, DemesError> {
Expand Down Expand Up @@ -2115,16 +2121,17 @@ impl Graph {
}

fn resolve_pulses(&mut self) -> Result<(), DemesError> {
if self.pulses.is_empty() && self.defaults.pulse != Pulse::default() {
self.pulses.push(self.defaults.pulse.clone());
if self.pulses.is_empty() && self.defaults.pulse != UnresolvedPulse::default() {
let c = self.defaults.pulse.clone();
self.pulses.push(Pulse(c));
}
self.pulses
.iter_mut()
.try_for_each(|pulse| pulse.resolve(&self.defaults))?;
// NOTE: the sort_by flips the order to b, a
// to put more ancient events at the front.
self.pulses
.sort_by(|a, b| b.time.partial_cmp(&a.time).unwrap());
.sort_by(|a, b| b.0.time.partial_cmp(&a.0.time).unwrap());
Ok(())
}

Expand Down
5 changes: 1 addition & 4 deletions tests/test_bad_yaml_inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,7 @@ demes:
- start_size: 1000
";
let g = demes::loads(yaml).unwrap();
assert!(matches!(
g.time_units(),
demes::TimeUnits::Generations
));
assert!(matches!(g.time_units(), demes::TimeUnits::Generations));
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use demes::GenerationTime;
use demes::GraphBuilder;
use demes::GraphDefaults;
use demes::Proportion;
use demes::Pulse;
use demes::UnresolvedPulse;
use demes::Time;
use demes::TimeUnits;
use demes::TopLevelDemeDefaults;
Expand All @@ -29,7 +29,7 @@ demes:
let graph_from_yaml = demes::loads(yaml).unwrap();

let toplevel_defaults = GraphDefaults {
pulse: Pulse {
pulse: UnresolvedPulse {
sources: Some(vec!["A".to_string()]),
dest: Some("B".to_string()),
proportions: Some(vec![Proportion::try_from(0.25).unwrap()]),
Expand Down

0 comments on commit ad6b86b

Please sign in to comment.