diff --git a/src/main/java/at/ac/tuwien/kr/alpha/api/query/AnswerSetQuery.java b/src/main/java/at/ac/tuwien/kr/alpha/api/query/AnswerSetQuery.java new file mode 100644 index 000000000..74db5479b --- /dev/null +++ b/src/main/java/at/ac/tuwien/kr/alpha/api/query/AnswerSetQuery.java @@ -0,0 +1,183 @@ +package at.ac.tuwien.kr.alpha.api.query; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import at.ac.tuwien.kr.alpha.common.AnswerSet; +import at.ac.tuwien.kr.alpha.common.Predicate; +import at.ac.tuwien.kr.alpha.common.atoms.Atom; +import at.ac.tuwien.kr.alpha.common.terms.ConstantTerm; +import at.ac.tuwien.kr.alpha.common.terms.FunctionTerm; +import at.ac.tuwien.kr.alpha.common.terms.Term; + +/** + * A query for ASP atoms matching a set of filter predicates. + */ +public final class AnswerSetQuery implements java.util.function.Predicate { + + private final Predicate predicate; + private Map> filters = new HashMap<>(); + + private AnswerSetQuery(Predicate pred) { + this.predicate = pred; + } + + /** + * Creates a new AnswerSetQuery that will match all atoms that are instances of the given {@link Predicate}. + * + * @param predicate the predicate to match against + * @return a new AnswerSetQuery matching against the given predicate + */ + public static AnswerSetQuery forPredicate(Predicate predicate) { + return new AnswerSetQuery(predicate); + } + + /** + * Adds a new filter to this AnswerSetQuery. + * For an atom a(t1, ..., tn), the term at index termIdx will be tested against the given filter predicate. + * + * @param termIdx the index of the term to test + * @param filter the test predicate to use on terms + * @return this AnswerSetQuery with the additional filter added + */ + public AnswerSetQuery withFilter(int termIdx, java.util.function.Predicate filter) { + if (termIdx >= this.predicate.getArity()) { + throw new IndexOutOfBoundsException( + "Predicate " + this.predicate.getName() + " has arity " + this.predicate.getArity() + ", term index " + termIdx + " is invalid!"); + } + if (this.filters.containsKey(termIdx)) { + java.util.function.Predicate currFilter = this.filters.get(termIdx); + this.filters.put(termIdx, currFilter.and(filter)); + } else { + this.filters.put(termIdx, filter); + } + return this; + } + + /** + * Convenience method - adds a filter to match names of symbolic constants against a string. + * + * @param termIdx + * @param str + * @return + */ + public AnswerSetQuery withConstantEquals(int termIdx, String str) { + return this.withFilter(termIdx, AnswerSetQuery.constantTermEquals(str)); + } + + /** + * Convenience method - adds a filter to match values of constant terms against a string. + * + * @param termIdx + * @param str + * @return + */ + public AnswerSetQuery withStringEquals(int termIdx, String str) { + return this.withFilter(termIdx, (term) -> { + if (!(term instanceof ConstantTerm)) { + return false; + } + if (((ConstantTerm) term).isSymbolic()) { + return false; + } + return ((ConstantTerm) term).getObject().equals(str); + }); + } + + /** + * Convenience method - adds a filter to check for function terms with a given function symbol and arity. + * + * @param termIdx + * @param funcSymbol + * @param funcArity + * @return + */ + public AnswerSetQuery withFunctionTerm(int termIdx, String funcSymbol, int funcArity) { + java.util.function.Predicate isFunction = (term) -> { + if (!(term instanceof FunctionTerm)) { + return false; + } + FunctionTerm funcTerm = (FunctionTerm) term; + if (!funcTerm.getSymbol().equals(funcSymbol)) { + return false; + } + if (funcTerm.getTerms().size() != funcArity) { + return false; + } + return true; + }; + return this.withFilter(termIdx, isFunction); + } + + /** + * Convenience method - adds a filter to check whether a term is equal to a given term. + * + * @param termIdx + * @param otherTerm + * @return + */ + public AnswerSetQuery withTermEquals(int termIdx, Term otherTerm) { + java.util.function.Predicate isEqual = (term) -> { + return term.equals(otherTerm); + }; + return this.withFilter(termIdx, isEqual); + } + + /** + * Applies this query to an atom. Filters are worked off in + * order of ascending term index in a conjunctive fashion, i.e. for an atom + * to match the query, all of its terms must satisfy all filters on these + * terms + * + * @param atom the atom to which to apply the query + * @return true iff the atom satisfies the query + */ + @Override + public boolean test(Atom atom) { + if (!atom.getPredicate().equals(predicate)) { + return false; + } + for (int i = 0; i < atom.getTerms().size(); i++) { + Term ithTerm = atom.getTerms().get(i); + java.util.function.Predicate ithFilter = filters.get(i); + if (ithFilter != null && !ithFilter.test(ithTerm)) { + return false; + } + } + return true; + } + + /** + * Applies this query to an {@link AnswerSet}. + * + * @param as + * @return + */ + public List applyTo(AnswerSet as) { + if (!as.getPredicates().contains(this.predicate)) { + return Collections.emptyList(); + } + return as.getPredicateInstances(this.predicate).stream().filter(this).collect(Collectors.toList()); + } + + private static java.util.function.Predicate constantTermEquals(final String str) { + java.util.function.Predicate equalsGivenString = (t) -> { + return AnswerSetQuery.constantTermEquals(t, str); + }; + return equalsGivenString; + } + + private static boolean constantTermEquals(Term term, String str) { + if (!(term instanceof ConstantTerm)) { + return false; + } + if (!((ConstantTerm) term).isSymbolic()) { + return false; + } + return ((ConstantTerm) term).getObject().toString().equals(str); + } + +} diff --git a/src/main/java/at/ac/tuwien/kr/alpha/common/AnswerSet.java b/src/main/java/at/ac/tuwien/kr/alpha/common/AnswerSet.java index 40a316159..8fe890d89 100644 --- a/src/main/java/at/ac/tuwien/kr/alpha/common/AnswerSet.java +++ b/src/main/java/at/ac/tuwien/kr/alpha/common/AnswerSet.java @@ -1,10 +1,12 @@ package at.ac.tuwien.kr.alpha.common; +import java.util.List; +import java.util.SortedSet; + import at.ac.tuwien.kr.alpha.Util; +import at.ac.tuwien.kr.alpha.api.query.AnswerSetQuery; import at.ac.tuwien.kr.alpha.common.atoms.Atom; -import java.util.SortedSet; - public interface AnswerSet extends Comparable { SortedSet getPredicates(); @@ -31,4 +33,14 @@ default int compareTo(AnswerSet other) { return 0; } + + /** + * Applies a given {@link AnswerSetQuery} to this AnswerSet. + * + * @param query the query to apply + * @return all atoms that are instances of the predicate specified by the query and meet the filters of the query + */ + default List query(AnswerSetQuery query) { + return query.applyTo(this); + } } diff --git a/src/main/java/at/ac/tuwien/kr/alpha/common/terms/ConstantTerm.java b/src/main/java/at/ac/tuwien/kr/alpha/common/terms/ConstantTerm.java index c0fa1fe43..98e035e16 100644 --- a/src/main/java/at/ac/tuwien/kr/alpha/common/terms/ConstantTerm.java +++ b/src/main/java/at/ac/tuwien/kr/alpha/common/terms/ConstantTerm.java @@ -151,4 +151,8 @@ public Term normalizeVariables(String renamePrefix, RenameCounter counter) { public T getObject() { return object; } + + public boolean isSymbolic() { + return this.symbolic; + } } diff --git a/src/test/java/at/ac/tuwien/kr/alpha/api/query/AnswerSetQueryTest.java b/src/test/java/at/ac/tuwien/kr/alpha/api/query/AnswerSetQueryTest.java new file mode 100644 index 000000000..fe4dacf6f --- /dev/null +++ b/src/test/java/at/ac/tuwien/kr/alpha/api/query/AnswerSetQueryTest.java @@ -0,0 +1,134 @@ +package at.ac.tuwien.kr.alpha.api.query; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.SortedSet; +import java.util.TreeSet; + +import org.junit.Assert; +import org.junit.Test; + +import at.ac.tuwien.kr.alpha.common.AnswerSet; +import at.ac.tuwien.kr.alpha.common.AnswerSetBuilder; +import at.ac.tuwien.kr.alpha.common.BasicAnswerSet; +import at.ac.tuwien.kr.alpha.common.Predicate; +import at.ac.tuwien.kr.alpha.common.atoms.Atom; +import at.ac.tuwien.kr.alpha.common.atoms.BasicAtom; +import at.ac.tuwien.kr.alpha.common.terms.ConstantTerm; +import at.ac.tuwien.kr.alpha.common.terms.FunctionTerm; +import at.ac.tuwien.kr.alpha.common.terms.Term; + +public class AnswerSetQueryTest { + + @Test + public void matchPredicate() { + AnswerSetBuilder bld = new AnswerSetBuilder(); + //@formatter:off + bld.predicate("p") + .symbolicInstance("a") + .symbolicInstance("b") + .predicate("q") + .symbolicInstance("x"); + //@formatter:on + AnswerSet as = bld.build(); + List queryResult = as.query(AnswerSetQuery.forPredicate(Predicate.getInstance("p", 1))); + Assert.assertEquals(2, queryResult.size()); + for (Atom a : queryResult) { + Assert.assertTrue(a.getPredicate().equals(Predicate.getInstance("p", 1))); + } + } + + @Test + public void matchSymbolicConstant() { + AnswerSetBuilder bld = new AnswerSetBuilder(); + bld.predicate("p") + .symbolicInstance("a") + .instance("a"); + AnswerSet as = bld.build(); + AnswerSetQuery constantQuery = AnswerSetQuery + .forPredicate(Predicate.getInstance("p", 1)) + .withConstantEquals(0, "a"); + List queryResult = as.query(constantQuery); + Assert.assertEquals(1, queryResult.size()); + } + + @Test + public void matchString() { + AnswerSetBuilder bld = new AnswerSetBuilder(); + bld.predicate("p") + .symbolicInstance("a") + .instance("a"); + AnswerSet as = bld.build(); + AnswerSetQuery stringQuery = AnswerSetQuery + .forPredicate(Predicate.getInstance("p", 1)) + .withStringEquals(0, "a"); + List queryResult = as.query(stringQuery); + Assert.assertEquals(1, queryResult.size()); + } + + @Test + public void matchEvenIntegers() { + AnswerSetBuilder bld = new AnswerSetBuilder(); + bld.predicate("p") + .instance(1).instance(2).instance(3).instance(4).instance(5) + .instance("bla").symbolicInstance("blubb"); + AnswerSet as = bld.build(); + java.util.function.Predicate isInteger = (term) -> { + if (!(term instanceof ConstantTerm)) { + return false; + } + String strValue = ((ConstantTerm) term).getObject().toString(); + return strValue.matches("[0-9]+"); + }; + AnswerSetQuery evenIntegers = AnswerSetQuery.forPredicate(Predicate.getInstance("p", 1)) + .withFilter(0, isInteger.and( + (term) -> Integer.valueOf(((ConstantTerm) term).getObject().toString()) % 2 == 0)); + List queryResult = as.query(evenIntegers); + Assert.assertEquals(2, queryResult.size()); + for (Atom atom : queryResult) { + ConstantTerm term = (ConstantTerm) atom.getTerms().get(0); + Assert.assertTrue(Integer.valueOf(term.getObject().toString()) % 2 == 0); + } + } + + @Test + public void matchXWithFuncTerm() { + Predicate p = Predicate.getInstance("p", 2); + Atom a1 = new BasicAtom(p, ConstantTerm.getSymbolicInstance("x"), FunctionTerm.getInstance("f", ConstantTerm.getSymbolicInstance("x"))); + Atom a2 = new BasicAtom(p, ConstantTerm.getSymbolicInstance("y"), FunctionTerm.getInstance("f", ConstantTerm.getSymbolicInstance("y"))); + Atom a3 = new BasicAtom(p, ConstantTerm.getSymbolicInstance("y"), FunctionTerm.getInstance("f", ConstantTerm.getSymbolicInstance("x"))); + Atom a4 = new BasicAtom(p, ConstantTerm.getSymbolicInstance("x"), FunctionTerm.getInstance("f", ConstantTerm.getSymbolicInstance("y"))); + Atom a5 = new BasicAtom(p, ConstantTerm.getSymbolicInstance("x"), ConstantTerm.getSymbolicInstance("f")); + SortedSet predicates = new TreeSet<>(); + predicates.add(p); + Map> instances = new HashMap<>(); + SortedSet ps = new TreeSet<>(); + ps.add(a1); + ps.add(a2); + ps.add(a3); + ps.add(a4); + ps.add(a5); + instances.put(p, ps); + AnswerSet as = new BasicAnswerSet(predicates, instances); + AnswerSetQuery query = AnswerSetQuery.forPredicate(Predicate.getInstance("p", 2)).withConstantEquals(0, "x").withFunctionTerm(1, "f", 1); + List queryResult = as.query(query); + Assert.assertEquals(2, queryResult.size()); + } + + @Test + public void matchTerm() { + AnswerSetBuilder bld = new AnswerSetBuilder(); + bld.predicate("p") + .instance(1).instance(2).instance(3).instance(4).instance(5) + .instance("bla").symbolicInstance("blubb"); + AnswerSet as = bld.build(); + + AnswerSetQuery equalTerm = AnswerSetQuery.forPredicate(Predicate.getInstance("p", 1)).withTermEquals(0, ConstantTerm.getInstance(1)); + List queryResult = as.query(equalTerm); + Assert.assertEquals(1, queryResult.size()); + Atom retrievedAtom = queryResult.get(0); + Assert.assertTrue(retrievedAtom.getTerms().get(0).equals(ConstantTerm.getInstance(1))); + } + +}