Skip to content

Commit

Permalink
Make anti-unification more sort-aware (#598)
Browse files Browse the repository at this point in the history
- Changed `anti_unify` to accept a `KDefinition` (the optional `kdef`
argument), which it uses to grab the tightest sort for the terms it is
abstracting.
- Replaced `anti_unify_with_constraints` with a `CTerm.anti_unify` method,
and moved both out of `pyk/kast/manip.py` into `pyk/cterm.py`.
- Adds `KDefinition.least_common_supersort` which returns the most
specific sort that will cover two terms (as long as one sort is a
subsort of the other).
- Uses this function in `KDefinition.sort` to determine the sort of a
`KRewrite`.
- Rewrote the tests for `anti_unify` and `anti_unify_with_constraints`
to use `KDefinition` and test that the abstracted variables are now more
tightly typed based on the original terms.

This is an attempt to fix an issue in
runtimeverification/evm-semantics#1934 where the
looser sort bound is preventing simplification after the merging
(anti-unification) of two nodes.

---------

Co-authored-by: devops <devops@runtimeverification.com>
  • Loading branch information
nwatson22 and devops authored Aug 16, 2023
1 parent 9d506c9 commit c34b3a4
Show file tree
Hide file tree
Showing 8 changed files with 307 additions and 158 deletions.
2 changes: 1 addition & 1 deletion pyk/package/version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.415
0.1.416
2 changes: 1 addition & 1 deletion pyk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "pyk"
version = "0.1.415"
version = "0.1.416"
description = ""
authors = [
"Runtime Verification, Inc. <contact@runtimeverification.com>",
Expand Down
63 changes: 60 additions & 3 deletions pyk/src/pyk/cterm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from itertools import chain
from typing import TYPE_CHECKING

from .kast.inner import KApply, KInner, KRewrite, KVariable, Subst
from .kast.inner import KApply, KInner, KRewrite, KToken, KVariable, Subst, bottom_up
from .kast.kast import KAtt
from .kast.manip import (
abstract_term_safely,
apply_existential_substitutions,
count_vars,
flatten_label,
Expand All @@ -22,13 +23,16 @@
)
from .kast.outer import KClaim, KRule
from .prelude.k import GENERATED_TOP_CELL
from .prelude.ml import is_top, mlAnd, mlImplies, mlTop
from .prelude.kbool import orBool
from .prelude.ml import is_top, mlAnd, mlEqualsTrue, mlImplies, mlTop
from .utils import unique

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
from typing import Any

from .kast.outer import KDefinition


@dataclass(frozen=True, order=True)
class CTerm:
Expand Down Expand Up @@ -56,7 +60,7 @@ def from_dict(dct: dict[str, Any]) -> CTerm:
@staticmethod
def _check_config(config: KInner) -> None:
if not isinstance(config, KApply) or not config.is_cell:
raise ValueError('Expected cell label, found: {config.label.name}')
raise ValueError(f'Expected cell label, found: {config}')

@staticmethod
def _normalize_constraints(constraints: Iterable[KInner]) -> tuple[KInner, ...]:
Expand Down Expand Up @@ -138,6 +142,59 @@ def _ml_impl(antecedents: Iterable[KInner], consequents: Iterable[KInner]) -> KI
def add_constraint(self, new_constraint: KInner) -> CTerm:
    """Build a new ``CTerm`` over the same config with ``new_constraint`` prepended to the constraint list."""
    extended = [new_constraint]
    extended.extend(self.constraints)
    return CTerm(self.config, extended)

def anti_unify(
    self, other: CTerm, keep_values: bool = False, kdef: KDefinition | None = None
) -> tuple[CTerm, CSubst, CSubst]:
    """Generalize this term and ``other`` into a single ``CTerm`` that matches both.

    The configurations are anti-unified with the module-level ``anti_unify`` (``kdef`` is
    forwarded to it). Of the constraints present on both inputs, only those that are
    transitively connected to the free variables of the new term are carried over.

    :param other: The term to generalize together with ``self``.
    :param keep_values: If true, the new term starts with a constraint built from the two
        substitution predicates: their ``orBool`` disjunction wrapped in ``mlEqualsTrue``,
        or ``mlTop()`` when either predicate is the ``true`` token.
    :param kdef: Optional definition passed through to the configuration anti-unification.
    :return: The generalized term, and the results of matching it (with constraints)
        against ``self`` and against ``other``.
    :raises ValueError: If the generalized term fails to match either input.
    """

    def disjunction_from_substs(subst1: Subst, subst2: Subst) -> KInner:
        preds = [subst1.pred, subst2.pred]
        if KToken('true', 'Bool') in preds:
            return mlTop()
        return mlEqualsTrue(orBool(preds))

    new_config, self_subst, other_subst = anti_unify(self.config, other.config, kdef=kdef)
    shared = [c for c in self.constraints if c in other.constraints]

    initial_constraints = [disjunction_from_substs(self_subst, other_subst)] if keep_values else []
    new_cterm = CTerm(config=new_config, constraints=initial_constraints)

    # Fixpoint: keep pulling in shared constraints that mention an already-known variable,
    # since each one pulled in may make further constraints relevant.
    selected = []
    known_vars = free_vars(new_cterm.kast)
    grew = True
    while grew:
        grew = False
        for candidate in shared:
            if candidate in selected:
                continue
            candidate_vars = free_vars(candidate)
            if any(var in known_vars for var in candidate_vars):
                selected.append(candidate)
                known_vars.extend(candidate_vars)
                grew = True

    for candidate in selected:
        new_cterm = new_cterm.add_constraint(candidate)

    self_csubst = new_cterm.match_with_constraint(self)
    other_csubst = new_cterm.match_with_constraint(other)
    if self_csubst is None or other_csubst is None:
        raise ValueError(
            f'Anti-unification failed to produce a more general state: {(new_cterm, (self, self_csubst), (other, other_csubst))}'
        )
    return (new_cterm, self_csubst, other_csubst)


def anti_unify(state1: KInner, state2: KInner, kdef: KDefinition | None = None) -> tuple[KInner, Subst, Subst]:
    """Compute a term that generalizes both ``state1`` and ``state2``.

    The two states are wrapped in a ``KRewrite``, rewrites are pushed down, and every
    ``KRewrite`` node that remains is replaced (bottom-up) by the result of
    ``abstract_term_safely``. When ``kdef`` is given, the sort it computes for the rewrite
    is passed along as the ``sort`` argument; otherwise ``sort`` is ``None``.

    :return: The abstracted state and the results of matching it against each input state.
    :raises ValueError: If the abstracted state does not match one of the inputs.
    """

    def _rewrites_to_abstractions(_kast: KInner) -> KInner:
        if type(_kast) is not KRewrite:
            return _kast
        sort = None
        if kdef:
            sort = kdef.sort(_kast)
        return abstract_term_safely(_kast, sort=sort)

    abstracted_state = bottom_up(_rewrites_to_abstractions, push_down_rewrites(KRewrite(state1, state2)))
    subst1 = abstracted_state.match(state1)
    subst2 = abstracted_state.match(state2)
    if subst1 is not None and subst2 is not None:
        return (abstracted_state, subst1, subst2)
    raise ValueError('Anti-unification failed to produce a more general state!')


@dataclass(frozen=True, order=True)
class CSubst:
Expand Down
54 changes: 1 addition & 53 deletions pyk/src/pyk/kast/manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..prelude.k import DOTS, GENERATED_TOP_CELL
from ..prelude.kbool import FALSE, TRUE, andBool, impliesBool, notBool, orBool
from ..prelude.ml import mlAnd, mlEqualsTrue, mlImplies, mlOr, mlTop
from ..prelude.ml import mlAnd, mlEqualsTrue, mlOr
from ..utils import find_common_items, hash_str
from .inner import KApply, KRewrite, KSequence, KToken, KVariable, Subst, bottom_up, top_down, var_occurrences
from .kast import EMPTY_ATT, KAtt, WithKAtt
Expand Down Expand Up @@ -582,58 +582,6 @@ def _abstract(k: KInner) -> KVariable:
return new_var


def anti_unify(state1: KInner, state2: KInner) -> tuple[KInner, Subst, Subst]:
def _rewrites_to_abstractions(_kast: KInner) -> KInner:
if type(_kast) is KRewrite:
return abstract_term_safely(_kast)
return _kast

minimized_rewrite = push_down_rewrites(KRewrite(state1, state2))
abstracted_state = bottom_up(_rewrites_to_abstractions, minimized_rewrite)
subst1 = abstracted_state.match(state1)
subst2 = abstracted_state.match(state2)
if subst1 is None or subst2 is None:
raise ValueError('Anti-unification failed to produce a more general state!')
return (abstracted_state, subst1, subst2)


def anti_unify_with_constraints(
constrained_term_1: KInner,
constrained_term_2: KInner,
implications: bool = False,
constraint_disjunct: bool = False,
abstracted_disjunct: bool = False,
) -> KInner:
def disjunction_from_substs(subst1: Subst, subst2: Subst) -> KInner:
if KToken('true', 'Bool') in [subst1.pred, subst2.pred]:
return mlTop()
return mlEqualsTrue(orBool([subst1.pred, subst2.pred]))

state1, constraint1 = split_config_and_constraints(constrained_term_1)
state2, constraint2 = split_config_and_constraints(constrained_term_2)
constraints1 = flatten_label('#And', constraint1)
constraints2 = flatten_label('#And', constraint2)
state, subst1, subst2 = anti_unify(state1, state2)

constraints = [c for c in constraints1 if c in constraints2]
constraint1 = mlAnd([c for c in constraints1 if c not in constraints])
constraint2 = mlAnd([c for c in constraints2 if c not in constraints])
implication1 = mlImplies(constraint1, subst1.ml_pred)
implication2 = mlImplies(constraint2, subst2.ml_pred)

if abstracted_disjunct:
constraints.append(disjunction_from_substs(subst1, subst2))

if implications:
constraints.append(implication1)
constraints.append(implication2)

if constraint_disjunct:
constraints.append(mlOr([constraint1, constraint2]))

return mlAnd([state] + constraints)


def apply_existential_substitutions(constrained_term: KInner) -> KInner:
state, constraint = split_config_and_constraints(constrained_term)
constraints = flatten_label('#And', constraint)
Expand Down
20 changes: 18 additions & 2 deletions pyk/src/pyk/kast/outer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,8 +1101,11 @@ def sort(self, kast: KInner) -> KSort | None:
case KToken(_, sort) | KVariable(_, sort):
return sort
case KRewrite(lhs, rhs):
sort = self.sort(lhs)
return sort if sort == self.sort(rhs) else None
lhs_sort = self.sort(lhs)
rhs_sort = self.sort(rhs)
if lhs_sort and rhs_sort:
return self.least_common_supersort(lhs_sort, rhs_sort)
return None
case KSequence(_):
return KSort('K')
case KApply(label, _):
Expand All @@ -1128,13 +1131,26 @@ def sort_strict(self, kast: KInner) -> KSort:
raise ValueError(f'Could not determine sort of term: {kast}')
return sort

def least_common_supersort(self, sort1: KSort, sort2: KSort) -> KSort | None:
    """Return the one of ``sort1``/``sort2`` that has the other among its subsorts.

    Equal sorts are returned as-is. ``None`` is returned when neither sort is listed by
    ``self.subsorts`` for the other.
    """
    if sort1 == sort2:
        return sort1
    for lower, upper in ((sort1, sort2), (sort2, sort1)):
        if lower in self.subsorts(upper):
            return upper
    # Computing least common supersort is not currently supported if sort1 is not a subsort of sort2 or
    # vice versa. In that case there may be more than one LCS.
    return None

def greatest_common_subsort(self, sort1: KSort, sort2: KSort) -> KSort | None:
    """Return the one of ``sort1``/``sort2`` that appears among the other's subsorts.

    Equal sorts are returned as-is. ``None`` is returned when neither sort is listed by
    ``self.subsorts`` for the other.
    """
    if sort1 == sort2:
        return sort1
    for lower, upper in ((sort1, sort2), (sort2, sort1)):
        if lower in self.subsorts(upper):
            return lower
    # Computing greatest common subsort is not currently supported if sort1 is not a subsort of sort2 or
    # vice versa. In that case there may be more than one GCS.
    return None

# Sorts like Int cannot be injected directly into sort K so they are embedded in a KSequence.
Expand Down
138 changes: 136 additions & 2 deletions pyk/src/tests/integration/kcfg/test_imp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

from pyk.cterm import CSubst, CTerm
from pyk.kast.inner import KApply, KSequence, KSort, KToken, KVariable, Subst
from pyk.kast.manip import minimize_term
from pyk.kast.manip import get_cell, minimize_term
from pyk.kcfg.semantics import KCFGSemantics
from pyk.kcfg.show import KCFGShow
from pyk.prelude.kbool import BOOL, notBool
from pyk.prelude.kbool import BOOL, notBool, orBool
from pyk.prelude.kint import intToken
from pyk.prelude.ml import mlAnd, mlBottom, mlEqualsFalse, mlEqualsTrue, mlTop
from pyk.proof import APRBMCProof, APRBMCProver, APRProof, APRProver, ProofStatus
Expand Down Expand Up @@ -1147,3 +1147,137 @@ def test_fail_fast(
assert len(proof.pending) == 1
assert len(proof.terminal) == 1
assert len(proof.failing) == 1

def test_anti_unify_forget_values(
    self,
    kcfg_explore: KCFGExplore,
    kprint: KPrint,
) -> None:
    # Two states that differ only in the value stored for N, carrying identical constraints.
    shared_constraint = mlAnd(
        [
            mlEqualsTrue(KApply('_>Int_', [KVariable(name, 'Int'), KToken('1', 'Int')]))
            for name in ('K', 'N', 'X', 'Y')
        ]
    )
    cterm1 = self.config(
        kprint=kprint,
        k='int $n ; { }',
        state='N |-> X:Int',
        constraint=shared_constraint,
    )
    cterm2 = self.config(
        kprint=kprint,
        k='int $n ; { }',
        state='N |-> Y:Int',
        constraint=shared_constraint,
    )

    anti_unifier, _subst1, _subst2 = cterm1.anti_unify(cterm2, keep_values=False, kdef=kprint.definition)

    # The differing map value must have been replaced by a fresh variable.
    state_cell = get_cell(anti_unifier.kast, 'STATE_CELL')
    assert type(state_cell) is KApply
    assert state_cell.label.name == '_|->_'
    assert type(state_cell.args[1]) is KVariable
    abstracted_var: KVariable = state_cell.args[1]

    # Only the constraint on N is expected to survive; the abstracted variable is sorted Int.
    expected_anti_unifier = self.config(
        kprint=kprint,
        k='int $n ; { }',
        state=f'N |-> {abstracted_var.name}:Int',
        constraint=mlEqualsTrue(KApply('_>Int_', [KVariable('N', 'Int'), KToken('1', 'Int')])),
    )

    assert anti_unifier.kast == expected_anti_unifier.kast

def test_anti_unify_keep_values(
    self,
    kcfg_explore: KCFGExplore,
    kprint: KPrint,
) -> None:
    # Two states that differ only in the value stored for N, carrying identical constraints.
    shared_constraint = mlAnd(
        [
            mlEqualsTrue(KApply('_>Int_', [KVariable(name, 'Int'), KToken('1', 'Int')]))
            for name in ('K', 'N', 'X', 'Y')
        ]
    )
    cterm1 = self.config(
        kprint=kprint,
        k='int $n ; { }',
        state='N |-> X:Int',
        constraint=shared_constraint,
    )
    cterm2 = self.config(
        kprint=kprint,
        k='int $n ; { }',
        state='N |-> Y:Int',
        constraint=shared_constraint,
    )

    anti_unifier, _subst1, _subst2 = cterm1.anti_unify(cterm2, keep_values=True, kdef=kprint.definition)

    # The differing map value must have been replaced by a fresh variable.
    state_cell = get_cell(anti_unifier.kast, 'STATE_CELL')
    assert type(state_cell) is KApply
    assert state_cell.label.name == '_|->_'
    assert type(state_cell.args[1]) is KVariable
    abstracted_var: KVariable = state_cell.args[1]

    # With keep_values=True the expected term also records that the abstracted variable
    # equals X or Y, and keeps the constraints on X and Y alongside the one on N.
    kept_bounds = [
        mlEqualsTrue(KApply('_>Int_', [KVariable(name, 'Int'), KToken('1', 'Int')])) for name in ('N', 'X', 'Y')
    ]
    value_disjunction = mlEqualsTrue(
        orBool(
            [
                KApply('_==K_', [KVariable(name=abstracted_var.name), KVariable(original, 'Int')])
                for original in ('X', 'Y')
            ]
        )
    )
    expected_anti_unifier = self.config(
        kprint=kprint,
        k='int $n ; { }',
        state=f'N |-> {abstracted_var.name}:Int',
        constraint=mlAnd(kept_bounds + [value_disjunction]),
    )

    assert anti_unifier.kast == expected_anti_unifier.kast

def test_anti_unify_subst_true(
    self,
    kcfg_explore: KCFGExplore,
    kprint: KPrint,
) -> None:
    # Both inputs are the same state with the same constraint.
    n_is_one = mlEqualsTrue(KApply('_==K_', [KVariable('N', 'Int'), KToken('1', 'Int')]))
    cterm1 = self.config(
        kprint=kprint,
        k='int $n ; { }',
        state='N |-> 0',
        constraint=n_is_one,
    )
    cterm2 = self.config(
        kprint=kprint,
        k='int $n ; { }',
        state='N |-> 0',
        constraint=n_is_one,
    )

    anti_unifier, _, _ = cterm1.anti_unify(cterm2, keep_values=True, kdef=kprint.definition)

    # Even with keep_values=True, anti-unifying identical states must give back the input unchanged.
    assert anti_unifier.kast == cterm1.kast
Loading

0 comments on commit c34b3a4

Please sign in to comment.