Skip to content

Commit

Permalink
feat: static predicates
Browse files Browse the repository at this point in the history
  • Loading branch information
jmdha committed Aug 19, 2024
1 parent 4ee098b commit af1cc8f
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 41 deletions.
20 changes: 5 additions & 15 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,8 @@ pub struct State {
}

impl State {
pub fn new(facts: Vec<Fact>) -> Self {
State {
facts: facts.into_iter().collect(),
}
pub fn new(facts: BTreeSet<Fact>) -> Self {
State { facts }
}
#[inline(always)]
pub fn fact_count(&self) -> usize {
Expand All @@ -55,19 +53,11 @@ impl State {
self.has_fact(&Fact::new(predicate, vec![]))
}
#[inline(always)]
pub fn has_unary(
&self,
predicate: usize,
arg: &usize,
) -> bool {
pub fn has_unary(&self, predicate: usize, arg: &usize) -> bool {
self.has_fact(&Fact::new(predicate, vec![*arg]))
}
#[inline(always)]
pub fn has_nary(
&self,
predicate: usize,
args: &Vec<usize>,
) -> bool {
pub fn has_nary(&self, predicate: usize, args: &Vec<usize>) -> bool {
self.has_fact(&Fact::new(predicate, args.to_owned()))
}
#[inline(always)]
Expand All @@ -90,7 +80,7 @@ impl State {
}
state
}
pub fn covers(&self, goal: &Vec<(Fact, bool)>) -> bool {
pub fn covers(&self, goal: &BTreeSet<(Fact, bool)>) -> bool {
goal.iter().all(|(f, v)| self.has_fact(f) == *v)
}
}
Expand Down
10 changes: 6 additions & 4 deletions src/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ pub mod action;
pub mod parameter;
pub mod predicate;

use std::collections::BTreeSet;
use indexmap::IndexSet;

use self::{
action::Action, predicate::Predicate
};
Expand All @@ -12,15 +12,17 @@ use crate::{
};

pub type Plan<'a> = Vec<Operator<'a>>;
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone)]
pub struct Task {
pub domain_name: Option<String>,
pub problem_name: Option<String>,
pub predicates: Vec<Predicate>,
pub actions: Vec<Action>,
pub objects: IndexSet<String>,
pub init: Vec<Fact>,
pub goal: Vec<(Fact, bool)>,
pub static_predicates: BTreeSet<usize>,
pub statics: BTreeSet<Fact>,
pub init: BTreeSet<Fact>,
pub goal: BTreeSet<(Fact, bool)>,
}

impl<'a> Task {
Expand Down
36 changes: 20 additions & 16 deletions src/translate/goal.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::BTreeSet;

use super::error::{Error, Field, Result};
use crate::{state::Fact, task::predicate::Predicate};
use indexmap::IndexSet;
Expand All @@ -6,25 +8,27 @@ pub fn translate(
predicates: &Vec<Predicate>,
objects: &IndexSet<String>,
goal: &pddlp::problem::Goal,
) -> Vec<(Fact, bool)> {
let mut goal_facts: Vec<(Fact, bool)> = Vec::new();
) -> BTreeSet<(Fact, bool)> {
let mut goal_facts: BTreeSet<(Fact, bool)> = BTreeSet::new();
let mut queue: Vec<(&pddlp::problem::Goal, bool)> = vec![(goal, true)];

while let Some((e, value)) = queue.pop() {
match e {
pddlp::problem::Goal::Fact(g) => goal_facts.push((
Fact::new(
predicates
.iter()
.position(|p| p.name == g.predicate)
.unwrap(),
g.objects
.iter()
.map(|o| objects.get_index_of(*o).unwrap())
.collect(),
),
value,
)),
pddlp::problem::Goal::Fact(g) => {
goal_facts.insert((
Fact::new(
predicates
.iter()
.position(|p| p.name == g.predicate)
.unwrap(),
g.objects
.iter()
.map(|o| objects.get_index_of(*o).unwrap())
.collect(),
),
value,
));
}
pddlp::problem::Goal::And(g) => {
queue.extend(g.iter().map(|g| (g, value)))
}
Expand All @@ -40,7 +44,7 @@ pub fn try_translate(
predicates: &Vec<Predicate>,
objects: &IndexSet<String>,
goal: &Option<pddlp::problem::Goal>,
) -> Result<Vec<(Fact, bool)>> {
) -> Result<BTreeSet<(Fact, bool)>> {
Ok(translate(
predicates,
objects,
Expand Down
11 changes: 6 additions & 5 deletions src/translate/init.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::BTreeSet;
use super::error::{Error, Field, Result};
use crate::{
state::Fact,
Expand All @@ -9,7 +10,7 @@ pub fn translate(
predicates: &Vec<Predicate>,
objects: &IndexSet<String>,
facts: &Vec<pddlp::problem::Fact>,
) -> Result<Vec<Fact>> {
) -> Result<BTreeSet<Fact>> {
Ok(facts
.iter()
.map(|fact| {
Expand All @@ -32,14 +33,14 @@ pub fn translate(
.collect::<Result<Vec<usize>>>()?,
))
})
.collect::<Result<Vec<Fact>>>()?)
.collect::<Result<BTreeSet<Fact>>>()?)
}

pub fn try_translate(
predicates: &Vec<Predicate>,
objects: &IndexSet<String>,
facts: &Option<Vec<pddlp::problem::Fact>>,
) -> Result<Vec<Fact>> {
) -> Result<BTreeSet<Fact>> {
translate(
predicates,
objects,
Expand All @@ -64,14 +65,14 @@ fn from_object_type(
pub fn from_object_types(
predicates: &Vec<Predicate>,
objects: &Option<Vec<pddlp::problem::Object>>,
) -> Vec<Fact> {
) -> BTreeSet<Fact> {
match objects {
Some(objects) => objects
.iter()
.enumerate()
.filter(|(_, object)| object.type_name.is_some())
.map(|(i, o)| from_object_type(predicates, i, o.type_name.unwrap()))
.collect(),
None => vec![],
None => BTreeSet::default(),
}
}
7 changes: 6 additions & 1 deletion src/translate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod goal;
mod init;
mod parameters;
mod predicates;
mod statics;
mod types;

use crate::task::Task;
Expand Down Expand Up @@ -42,14 +43,18 @@ pub fn translate_parsed(domain: &Domain, problem: &Problem) -> Result<Task> {
.collect();
let actions = actions::translate(&types, &predicates, &domain.actions);
let mut init = init::try_translate(&predicates, &objects, &problem.init)?;
init.append(&mut init::from_object_types(&predicates, &problem.objects));
init.extend(init::from_object_types(&predicates, &problem.objects));
let static_predicates = statics::find(&actions, &predicates);
let (statics, init) = statics::split(&static_predicates, init);
let goal = goal::try_translate(&predicates, &objects, &problem.goal)?;
Ok(Task {
domain_name,
problem_name,
predicates,
actions,
objects,
static_predicates,
statics,
init,
goal,
})
Expand Down
30 changes: 30 additions & 0 deletions src/translate/statics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use crate::{
state::Fact,
task::{action::Action, predicate::Predicate},
};
use std::collections::BTreeSet;

fn affects_predicate(action: &Action, predicate: usize) -> bool {
action
.effect
.iter()
.any(|effect| effect.predicate == predicate)
}

pub fn find(
actions: &Vec<Action>,
predicates: &Vec<Predicate>,
) -> BTreeSet<usize> {
(0..predicates.len())
.filter(|i| !actions.iter().any(|a| affects_predicate(a, *i)))
.collect()
}

pub fn split(
static_predicates: &BTreeSet<usize>,
facts: BTreeSet<Fact>,
) -> (BTreeSet<Fact>, BTreeSet<Fact>) {
facts
.into_iter()
.partition(|fact| static_predicates.contains(&fact.predicate()))
}

0 comments on commit af1cc8f

Please sign in to comment.