Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unnecessary write lock when rewriting a query term. #1592

Merged
merged 2 commits into from
Jun 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions polar-core/src/polar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ impl Polar {
pub fn new_query_from_term(&self, mut term: Term, trace: bool) -> Query {
use crate::vm::{Goal, PolarVirtualMachine};
{
let mut kb = self.kb.write().unwrap();
term = rewrite_term(term, &mut kb);
let kb = self.kb.read().unwrap();
term = rewrite_term(term, &kb);
}
let query = Goal::Query { term: term.clone() };
let vm =
Expand Down
66 changes: 33 additions & 33 deletions polar-core/src/rewrites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,13 @@ pub fn unwrap_and(term: &Term) -> TermList {
}

/// Rewrite a term.
pub fn rewrite_term(term: Term, kb: &mut KnowledgeBase) -> Term {
pub fn rewrite_term(term: Term, kb: &KnowledgeBase) -> Term {
let mut fld = Rewriter::new(kb);
fld.fold_term(term)
}

/// Rewrite a rule.
pub fn rewrite_rule(rule: Rule, kb: &mut KnowledgeBase) -> Rule {
pub fn rewrite_rule(rule: Rule, kb: &KnowledgeBase) -> Rule {
let mut fld = Rewriter::new(kb);
fld.fold_rule(rule)
}
Expand All @@ -289,23 +289,23 @@ mod tests {

#[test]
fn rewrite_anonymous_vars() {
let mut kb = KnowledgeBase::new();
let kb = KnowledgeBase::new();
let query = parse_query("[1, 2, 3] = [_, _, _]");
assert_eq!(
rewrite_term(query, &mut kb).to_string(),
rewrite_term(query, &kb).to_string(),
"[1, 2, 3] = [_1, _2, _3]"
);
}

#[test]
fn rewrite_rules() {
let mut kb = KnowledgeBase::new();
let kb = KnowledgeBase::new();
let rules = parse_rules("f(a.b);");
let rule = rules[0].clone();
assert_eq!(rule.to_string(), "f(a.b);");

// First rewrite
let rule = rewrite_rule(rule, &mut kb);
let rule = rewrite_rule(rule, &kb);
assert_eq!(rule.to_string(), "f(_value_1) if a.b = _value_1;");

// Check we can parse the rules back again
Expand All @@ -317,7 +317,7 @@ mod tests {
let rules = parse_rules("f(a.b.c);");
let rule = rules[0].clone();
assert_eq!(rule.to_string(), "f(a.b.c);");
let rule = rewrite_rule(rule, &mut kb);
let rule = rewrite_rule(rule, &kb);
assert_eq!(
rule.to_string(),
"f(_value_3) if a.b = _value_2 and _value_2.c = _value_3;"
Expand All @@ -326,15 +326,15 @@ mod tests {

#[test]
fn rewrite_forall_rhs_dots() {
let mut kb = KnowledgeBase::new();
let kb = KnowledgeBase::new();
let rules = parse_rules("foo(z, y) if forall(x in y, x.n < z);");
let rule = rewrite_rule(rules[0].clone(), &mut kb);
let rule = rewrite_rule(rules[0].clone(), &kb);
assert_eq!(
rule.to_string(),
"foo(z, y) if forall(x in y, _value_1 < z and x.n = _value_1);"
);

let query = rewrite_term(parse_query("forall(x in y, x.n < z)"), &mut kb);
let query = rewrite_term(parse_query("forall(x in y, x.n < z)"), &kb);
assert_eq!(
query.to_string(),
"forall(x in y, _value_2 < z and x.n = _value_2)"
Expand All @@ -343,13 +343,13 @@ mod tests {

#[test]
fn rewrite_nested_lookups() {
let mut kb = KnowledgeBase::new();
let kb = KnowledgeBase::new();

// Lookups with args
let rules = parse_rules("f(a, c) if a.b(c);");
let rule = rules[0].clone();
assert_eq!(rule.to_string(), "f(a, c) if a.b(c);");
let rule = rewrite_rule(rule, &mut kb);
let rule = rewrite_rule(rule, &kb);
assert_eq!(
rule.to_string(),
"f(a, c) if a.b(c) = _value_1 and _value_1;"
Expand All @@ -359,7 +359,7 @@ mod tests {
let rules = parse_rules("f(a, c, e) if a.b(c.d(e.f()));");
let rule = rules[0].clone();
assert_eq!(rule.to_string(), "f(a, c, e) if a.b(c.d(e.f()));");
let rule = rewrite_rule(rule, &mut kb);
let rule = rewrite_rule(rule, &kb);
assert_eq!(
rule.to_string(),
"f(a, c, e) if e.f() = _value_2 and c.d(_value_2) = _value_3 and a.b(_value_3) = _value_4 and _value_4;"
Expand All @@ -368,110 +368,110 @@ mod tests {

#[test]
fn rewrite_terms() {
let mut kb = KnowledgeBase::new();
let kb = KnowledgeBase::new();
let term = parse_query("x and a.b");
assert_eq!(term.to_string(), "x and a.b");
assert_eq!(
rewrite_term(term, &mut kb).to_string(),
rewrite_term(term, &kb).to_string(),
"x and a.b = _value_1 and _value_1"
);

let query = parse_query("f(a.b().c)");
assert_eq!(query.to_string(), "f(a.b().c)");
assert_eq!(
rewrite_term(query, &mut kb).to_string(),
rewrite_term(query, &kb).to_string(),
"a.b() = _value_2 and _value_2.c = _value_3 and f(_value_3)"
);

let term = parse_query("a.b = 1");
assert_eq!(
rewrite_term(term, &mut kb).to_string(),
rewrite_term(term, &kb).to_string(),
"a.b = _value_4 and _value_4 = 1"
);
let term = parse_query("{x: 1}.x = 1");
assert_eq!(term.to_string(), "{x: 1}.x = 1");
assert_eq!(
rewrite_term(term, &mut kb).to_string(),
rewrite_term(term, &kb).to_string(),
"{x: 1}.x = _value_5 and _value_5 = 1"
);
}

#[test]
fn rewrite_expressions() {
let mut kb = KnowledgeBase::new();
let kb = KnowledgeBase::new();

let term = parse_query("0 - 0 = 0");
assert_eq!(term.to_string(), "0 - 0 = 0");
assert_eq!(
rewrite_term(term, &mut kb).to_string(),
rewrite_term(term, &kb).to_string(),
"0 - 0 = _op_1 and _op_1 = 0"
);

let rules = parse_rules("sum(a, b, a + b);");
let rule = rules[0].clone();
assert_eq!(rule.to_string(), "sum(a, b, a + b);");
let rule = rewrite_rule(rule, &mut kb);
let rule = rewrite_rule(rule, &kb);
assert_eq!(rule.to_string(), "sum(a, b, _op_2) if a + b = _op_2;");

let rules = parse_rules("fib(n, a+b) if fib(n-1, a) and fib(n-2, b);");
let rule = rules[0].clone();
let rule = rewrite_rule(rule, &mut kb);
let rule = rewrite_rule(rule, &kb);
assert_eq!(rule.to_string(), "fib(n, _op_5) if n - 1 = _op_3 and fib(_op_3, a) and n - 2 = _op_4 and fib(_op_4, b) and a + b = _op_5;");
}

#[test]
fn rewrite_nested_literal() {
let mut kb = KnowledgeBase::new();
let kb = KnowledgeBase::new();
let term = parse_query("new Foo(x: bar.y)");
assert_eq!(term.to_string(), "new Foo(x: bar.y)");
assert_eq!(
rewrite_term(term, &mut kb).to_string(),
rewrite_term(term, &kb).to_string(),
"bar.y = _value_1 and new (Foo(x: _value_1), _instance_2) and _instance_2"
);

let term = parse_query("f(new Foo(x: bar.y))");
assert_eq!(term.to_string(), "f(new Foo(x: bar.y))");
assert_eq!(
rewrite_term(term, &mut kb).to_string(),
rewrite_term(term, &kb).to_string(),
"bar.y = _value_3 and new (Foo(x: _value_3), _instance_4) and f(_instance_4)"
);
}

#[test]
fn rewrite_class_constructor() {
let mut kb = KnowledgeBase::new();
let kb = KnowledgeBase::new();
let term = parse_query("new Foo(a: 1, b: 2)");
assert_eq!(term.to_string(), "new Foo(a: 1, b: 2)");

// @ means external constructor
assert_eq!(
rewrite_term(term, &mut kb).to_string(),
rewrite_term(term, &kb).to_string(),
"new (Foo(a: 1, b: 2), _instance_1) and _instance_1"
);
}

#[test]
fn rewrite_nested_class_constructor() {
let mut kb = KnowledgeBase::new();
let kb = KnowledgeBase::new();
let term = parse_query("new Foo(a: 1, b: new Foo(a: 2, b: 3))");
assert_eq!(term.to_string(), "new Foo(a: 1, b: new Foo(a: 2, b: 3))");

assert_eq!(
rewrite_term(term, &mut kb).to_string(),
rewrite_term(term, &kb).to_string(),
"new (Foo(a: 2, b: 3), _instance_1) and \
new (Foo(a: 1, b: _instance_1), _instance_2) and _instance_2"
);
}

#[test]
fn rewrite_rules_constructor() {
let mut kb = KnowledgeBase::new();
let kb = KnowledgeBase::new();
let mut rules = parse_rules("rule_test(new Foo(a: 1, b: 2));");
let rule = rules.pop().unwrap();
assert_eq!(rule.to_string(), "rule_test(new Foo(a: 1, b: 2));");
assert!(rules.is_empty());

let rule = rewrite_rule(rule, &mut kb);
let rule = rewrite_rule(rule, &kb);
assert_eq!(
rule.to_string(),
"rule_test(_instance_1) if new (Foo(a: 1, b: 2), _instance_1);"
Expand All @@ -480,12 +480,12 @@ mod tests {

#[test]
fn rewrite_not_with_lookup() {
let mut kb = KnowledgeBase::new();
let kb = KnowledgeBase::new();
let term = parse_query("not foo.x = 1");
assert_eq!(term.to_string(), "not foo.x = 1");

pretty_assertions::assert_eq!(
rewrite_term(term, &mut kb).to_string(),
rewrite_term(term, &kb).to_string(),
"not (_value_1 = 1 and foo.x = _value_1)"
)
}
Expand Down