From e23033bd407e702719933ba547e65e1436cbfdac Mon Sep 17 00:00:00 2001 From: Tamas Toth Date: Tue, 8 Mar 2022 14:03:05 +0100 Subject: [PATCH] Refactor KCFG * Remove field split * Remove field loop * Remove field terminal * Remove field frontier * Remove field subsumptions * Remove method getTermHash * Remove method getNodeAttributes * Refactor and rename method getStateIdByShortHash * Refactor and rename insertNode * Refactor and rename removeNode * Remove class attribute _NODE_ATTRS * Refactor methods to_dict and from_dict * Add type annotations for fields * Change fields from List to Set where possible * Rename field states to _nodes * Create class KCFG.Node * Rename method add_node to create_node * Remove methods _encode and _decode * Create class KCFG.Edge * Rename field graph to _edges * Change edge type * Remove methods _assign and invalidateStates * Add method get_edge * Remove method getEdges * Rename get_node to node * Rename get_edge to edge * Add optional parameters to method edges * Remove methods getSuccessors and getPredecessors * Make return type of method edge Optional * Remove method getEdgeCondition * Remove methods markEdge* * Refactor abstractions * Refactor method getEdgeSentence to Edge.to_rule * Remove method getModule * Add class KCFG.Cover * Add class KCFG.EdgeLike * Simplify to_dot * Reimplement path algorithm * Reimplement transitive closure * Implement property frontier * Rename init to _init * Rename target to _target * Rename stuck to _stuck * Refactor getPathCondition * Add method prune * Add class FrozenDict * Add class Subst * Change KAtt backing data type from frozenset to FrozenDict * Change return type of KCFG.Cover.subst to Subst * Rename kastManip.removeGeneratedCells * Refactor and rename countVarOccurances * Add function if_ktype * Add function count_rhs_vars * Add further utils * Add class KRuleLike * Add class CTerm * Make kast.to_hash a cached property * Add properties for sets of nodes * Change type of KCFG.Node.term to 
CTerm * Change CTerm.constraints from FrozenSet to Tuple * Rename inline_generated_top * Move class Subst to module subst * Add tests for count_vars * Move AST traversal functions to module kast * Change return type of kastManip.match to Subst * Move function match from kastManip to subst * Add tests for matching property --- .../src/main/scripts/lib/pyk/pyk/cterm.py | 43 + .../src/main/scripts/lib/pyk/pyk/kast.py | 71 +- .../src/main/scripts/lib/pyk/pyk/kastManip.py | 195 ++-- .../src/main/scripts/lib/pyk/pyk/kcfg.py | 863 +++++++++--------- .../src/main/scripts/lib/pyk/pyk/subst.py | 113 +++ .../lib/pyk/pyk/tests/test_count_vars.py | 39 + .../scripts/lib/pyk/pyk/tests/test_kcfg.py | 185 ++-- .../scripts/lib/pyk/pyk/tests/test_match.py | 36 + .../scripts/lib/pyk/pyk/tests/test_subst.py | 84 ++ .../src/main/scripts/lib/pyk/pyk/utils.py | 64 +- .../tests/pyk/configuration_test.py | 8 +- 11 files changed, 1041 insertions(+), 660 deletions(-) create mode 100644 k-distribution/src/main/scripts/lib/pyk/pyk/cterm.py create mode 100644 k-distribution/src/main/scripts/lib/pyk/pyk/subst.py create mode 100644 k-distribution/src/main/scripts/lib/pyk/pyk/tests/test_count_vars.py create mode 100644 k-distribution/src/main/scripts/lib/pyk/pyk/tests/test_match.py create mode 100644 k-distribution/src/main/scripts/lib/pyk/pyk/tests/test_subst.py diff --git a/k-distribution/src/main/scripts/lib/pyk/pyk/cterm.py b/k-distribution/src/main/scripts/lib/pyk/pyk/cterm.py new file mode 100644 index 00000000000..1b1078b2a86 --- /dev/null +++ b/k-distribution/src/main/scripts/lib/pyk/pyk/cterm.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass +from functools import cached_property +from itertools import chain +from typing import List, Optional, Tuple + +from .kast import KInner, flattenLabel +from .kastManip import match, splitConfigAndConstraints +from .prelude import mlAnd +from .subst import Subst + + +@dataclass(frozen=True) +class CTerm: + config: KInner # TODO Optional? 
+ constraints: Tuple[KInner, ...] + + def __init__(self, cterm: KInner) -> None: + config, constraint = splitConfigAndConstraints(cterm) + constraints = tuple(flattenLabel('#And', constraint)) + object.__setattr__(self, 'config', config) + object.__setattr__(self, 'constraints', constraints) + + def __iter__(self): + return chain([self.config], self.constraints) + + @cached_property + def cterm(self) -> KInner: + return mlAnd(self) + + @property + def hash(self) -> str: + return self.cterm.hash + + def match(self, pattern: 'CTerm') -> Optional[Tuple[Subst, List[KInner]]]: + subst = match(pattern=pattern.config, term=self.config) + + if subst is None: + return None + + assumptions = set(self.constraints) + obligations = [constraint for constraint in map(subst, pattern.constraints) if constraint not in assumptions] + + return subst, obligations diff --git a/k-distribution/src/main/scripts/lib/pyk/pyk/kast.py b/k-distribution/src/main/scripts/lib/pyk/pyk/kast.py index e8951ed284c..b399f67ac82 100644 --- a/k-distribution/src/main/scripts/lib/pyk/pyk/kast.py +++ b/k-distribution/src/main/scripts/lib/pyk/pyk/kast.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from dataclasses import InitVar, dataclass from enum import Enum +from functools import cached_property from itertools import chain from typing import ( Any, @@ -26,8 +27,11 @@ from typing_extensions import TypeAlias from .cli_utils import fatal, warning +from .utils import FrozenDict, hash_str T = TypeVar('T', bound='KAst') +W = TypeVar('W', bound='WithKAtt') +KI = TypeVar('KI', bound='KInner') class KAst(ABC): @@ -50,6 +54,11 @@ def from_json(s: str) -> 'KAst': def to_json(self) -> str: return json.dumps(self.to_dict(), sort_keys=True) + @final + @cached_property + def hash(self) -> str: + return hash_str(self.to_json()) + @classmethod def _check_node(cls: Type[T], d: Dict[str, Any], expected: Optional[str] = None) -> None: expected = expected if expected is not None else cls.__name__ @@ -61,23 +70,19 @@ 
def _check_node(cls: Type[T], d: Dict[str, Any], expected: Optional[str] = None) @final @dataclass(frozen=True) class KAtt(KAst, Mapping[str, Any]): - _atts: FrozenSet[Tuple[str, Any]] + atts: FrozenDict[str, Any] def __init__(self, atts: Mapping[str, Any] = {}): - object.__setattr__(self, '_atts', frozenset(atts.items())) - - def __getitem__(self, key: str) -> Any: - return self.atts[key] + object.__setattr__(self, 'atts', FrozenDict(atts)) def __iter__(self) -> Iterator[str]: - return (k for k, _ in self._atts) + return iter(self.atts) def __len__(self) -> int: - return len(self._atts) + return len(self.atts) - @property - def atts(self) -> Dict[str, Any]: - return dict(self._atts) + def __getitem__(self, key: str) -> Any: + return self.atts[key] @staticmethod def of(**atts: Any) -> 'KAtt': @@ -89,7 +94,7 @@ def from_dict(cls: Type['KAtt'], d: Dict[str, Any]) -> 'KAtt': return KAtt(atts=d['att']) def to_dict(self) -> Dict[str, Any]: - return {'node': 'KAtt', 'att': self.atts} + return {'node': 'KAtt', 'att': dict(self.atts)} def let(self, *, atts: Optional[Mapping[str, Any]] = None) -> 'KAtt': atts = atts if atts is not None else self.atts @@ -102,10 +107,6 @@ def update(self, atts: Mapping[str, Any]) -> 'KAtt': EMPTY_ATT: Final = KAtt() -W = TypeVar('W', bound='WithKAtt') -KI = TypeVar('KI', bound='KInner') - - class WithKAtt(KAst, ABC): att: KAtt @@ -784,9 +785,26 @@ def let_att(self, att: KAtt) -> 'KBubble': return self.let(att=att) +class KRuleLike(KSentence, ABC): + body: KInner + requires: KInner + ensures: KInner + + _RULE_LIKE_NODES: Final = {'KRule', 'KClaim'} + + @classmethod + @abstractmethod + def from_dict(cls: Type['KRuleLike'], d: Dict[str, Any]) -> 'KRuleLike': + node = d['node'] + if node in KRuleLike._RULE_LIKE_NODES: + return globals()[node].from_dict(d) + + raise ValueError(f"Expected KRuleLike label as 'node' value, found: '{node}'") + + @final @dataclass(frozen=True) -class KRule(KSentence): +class KRule(KRuleLike): body: KInner requires: 
KInner ensures: KInner @@ -830,7 +848,7 @@ def let_att(self, att: KAtt) -> 'KRule': @final @dataclass(frozen=True) -class KClaim(KSentence): +class KClaim(KRuleLike): body: KInner requires: KInner ensures: KInner @@ -1076,6 +1094,25 @@ def let_att(self, att: KAtt) -> 'KDefinition': return self.let(att=att) +# TODO make method of KInner +def bottom_up(f: Callable[[KInner], KInner], kinner: KInner) -> KInner: + return f(kinner.map_inner(lambda _kinner: bottom_up(f, _kinner))) + + +# TODO make method of KInner +def top_down(f: Callable[[KInner], KInner], kinner: KInner) -> KInner: + return f(kinner).map_inner(lambda _kinner: top_down(f, _kinner)) + + +# TODO replace by method that does not reconstruct the AST +def collect(callback: Callable[[KInner], None], kinner: KInner) -> None: + def f(kinner: KInner) -> KInner: + callback(kinner) + return kinner + + bottom_up(f, kinner) + + def flattenLabel(label, kast): """Given a cons list, return a flat Python list of the elements. diff --git a/k-distribution/src/main/scripts/lib/pyk/pyk/kastManip.py b/k-distribution/src/main/scripts/lib/pyk/pyk/kastManip.py index c42482b7cf8..ecc8c241379 100644 --- a/k-distribution/src/main/scripts/lib/pyk/pyk/kastManip.py +++ b/k-distribution/src/main/scripts/lib/pyk/pyk/kastManip.py @@ -1,9 +1,9 @@ -from typing import Callable, TypeVar +from collections import Counter +from typing import Callable, Mapping, Type, TypeVar from .cli_utils import fatal from .kast import ( KApply, - KAst, KAtt, KClaim, KDefinition, @@ -15,9 +15,12 @@ KToken, KVariable, WithKAtt, + bottom_up, + collect, flattenLabel, klabelEmptyK, ktokenDots, + top_down, ) from .prelude import ( buildAssoc, @@ -28,76 +31,26 @@ mlOr, mlTop, ) -from .utils import combine_dicts, dedupe, find_common_items, hash_str +from .subst import Subst, match +from .utils import dedupe, find_common_items, hash_str -T = TypeVar('T', bound=KAst) KI = TypeVar('KI', bound=KInner) W = TypeVar('W', bound=WithKAtt) -def match(pattern, kast): - 
"""Perform syntactic pattern matching and return the substitution. - - - Input: a pattern and a kast term. - - Output: substitution instantiating the pattern to the kast term. - """ - subst = {} - if type(pattern) is KVariable: - return {pattern.name: kast} - if type(pattern) is KToken and type(kast) is KToken: - return {} if pattern.token == kast.token else None - if type(pattern) is KApply and type(kast) is KApply \ - and pattern.label == kast.label and pattern.arity == kast.arity: - for patternArg, kastArg in zip(pattern.args, kast.args): - argSubst = match(patternArg, kastArg) - subst = combine_dicts(subst, argSubst) - if subst is None: - return None - return subst - if type(pattern) is KRewrite and type(kast) is KRewrite: - lhsSubst = match(pattern.lhs, kast.lhs) - rhsSubst = match(pattern.rhs, kast.rhs) - return combine_dicts(lhsSubst, rhsSubst) - if type(pattern) is KSequence and type(kast) is KSequence and pattern.arity == kast.arity: - for (patternItem, substItem) in zip(pattern.items, kast.items): - itemSubst = match(patternItem, substItem) - subst = combine_dicts(subst, itemSubst) - if subst is None: - return None - return subst - return None - - -# TODO make method of KInner -def traverseBottomUp(kinner: KInner, f: Callable[[KInner], KInner]) -> KInner: - return f(kinner.map_inner(lambda _kinner: traverseBottomUp(_kinner, f))) +def if_ktype(ktype: Type[KI], then: Callable[[KI], KInner]) -> Callable[[KInner], KInner]: + def fun(term: KInner): + if isinstance(term, ktype): + return then(term) + return term + return fun -# TODO make method of KInner -def traverseTopDown(kinner: KInner, f: Callable[[KInner], KInner]) -> KInner: - return f(kinner).map_inner(lambda _kinner: traverseTopDown(_kinner, f)) - - -# TODO replace by method that does not reconstruct the AST -def collectBottomUp(kinner: KInner, callback: Callable[[KInner], None]) -> None: - def f(kinner: KInner) -> KInner: - callback(kinner) - return kinner - - traverseBottomUp(kinner, f) - - -def 
substitute(pattern, substitution): - """Apply a substitution to a pattern. - - - Input: a pattern with free variables and a substitution. - - Output: the pattern with the substitution applied. - """ - def replace(k): - if type(k) is KVariable and k.name in substitution: - return substitution[k.name] - return k - return traverseBottomUp(pattern, replace) +# TODO remove +def substitute(pattern: KInner, subst: Mapping[str, KInner]) -> KInner: + if not isinstance(subst, Subst): + subst = Subst(subst) + return subst(pattern) def whereMatchingBottomUp(effect, matchPattern, pattern): @@ -107,7 +60,7 @@ def _effect(k): if matchingSubst is not None: newK = effect(matchingSubst) return newK - return traverseBottomUp(_effect, pattern) + return bottom_up(_effect, pattern) def replaceKLabels(pattern, klabelMap): @@ -115,7 +68,7 @@ def replace(k): if type(k) is KApply and k.label in klabelMap: return k.let(label=klabelMap[k.label]) return k - return traverseBottomUp(pattern, replace) + return bottom_up(replace, pattern) def rewriteWith(rule, pattern): @@ -137,7 +90,7 @@ def rewriteAnywhereWith(rule, pattern): - Input: A rule to rewrite with, and a pattern to rewrite. - Output: The pattern with rewrites applied at every node once starting from the bottom. """ - return traverseBottomUp(pattern, lambda p: rewriteWith(rule, p)) + return bottom_up(lambda p: rewriteWith(rule, p), pattern) def replaceWith(rule, pattern): @@ -148,7 +101,7 @@ def replaceWith(rule, pattern): def replaceAnywhereWith(rule, pattern): - return traverseBottomUp(pattern, lambda p: replaceWith(rule, p)) + return bottom_up(lambda p: replaceWith(rule, p), pattern) def unsafeMlPredToBool(k): @@ -208,32 +161,63 @@ def addOccurance(k): if match(pattern, k): occurances.append(k) - collectBottomUp(kast, addOccurance) + collect(addOccurance, kast) return occurances -def countVarOccurances(kast, numOccurances=None): - """Count the number of occurances of each variable in a proof. 
+def extract_lhs(term: KInner) -> KInner: + return top_down(if_ktype(KRewrite, lambda rw: rw.lhs), term) + + +def extract_rhs(term: KInner) -> KInner: + return top_down(if_ktype(KRewrite, lambda rw: rw.rhs), term) - - Input: Kast term. - - Output: Map of variable names to their number of occurances. - """ - numOccurances = {} if numOccurances is None else numOccurances - def _getNumOccurances(_kast): - if type(_kast) is KVariable: - vName = _kast.name - if vName in numOccurances: - numOccurances[vName] += 1 - else: - numOccurances[vName] = 1 +def count_vars(term: KInner) -> Counter: + counter: Counter = Counter() - collectBottomUp(kast, _getNumOccurances) - return numOccurances + def count(term: KInner) -> None: + if type(term) is KVariable: + counter[term.name] += 1 + + collect(count, term) + return counter + + +def count_rhs_vars(term: KInner) -> Counter: + def recur(term: KInner, *, rhs=False) -> Counter: + if type(term) is KVariable: + return Counter(term.name) if rhs else Counter() + if type(term) is KRewrite: + return recur(term.rhs, rhs=True) + if type(term) is KApply: + return sum((recur(t, rhs=rhs) for t in term.args), Counter()) + if type(term) is KSequence: + return sum((recur(t, rhs=rhs) for t in term.items), Counter()) + return Counter() + return recur(term) def collectFreeVars(kast): - return list(countVarOccurances(kast).keys()) + return list(count_vars(kast).keys()) + + +def drop_var_prefixes(term: KInner) -> KInner: + term = top_down(if_ktype(KVariable, drop_ques), term) + term = top_down(if_ktype(KVariable, drop_unds), term) + return term + + +def drop_ques(variable: KVariable) -> KVariable: + if variable.name.startswith('?'): + return variable.let(name=variable.name[1:]) + return variable + + +def drop_unds(variable: KVariable) -> KVariable: + if variable.name.startswith('_'): + return variable.let(name=variable.name[1:]) + return variable def splitConfigAndConstraints(kast): @@ -274,7 +258,7 @@ def _propagateUpConstraints(_k): conjunct2 = 
buildAssoc(mlTop, '#And', r2) disjunct = KApply('#Or', [conjunct1, conjunct2]) return buildAssoc(mlTop(), '#And', [disjunct] + common) - return traverseBottomUp(k, _propagateUpConstraints) + return bottom_up(_propagateUpConstraints, k) def splitConfigFrom(configuration): @@ -299,7 +283,7 @@ def _replaceWithVar(k): return KApply(k.label, [KVariable(config_var)]) return k - symbolic_config = traverseTopDown(configuration, _replaceWithVar) + symbolic_config = top_down(_replaceWithVar, configuration) return (symbolic_config, initial_substitution) @@ -323,7 +307,7 @@ def _collapseDots(_kast): if _kast.lhs == ktokenDots: return ktokenDots return _kast - return traverseBottomUp(kast, _collapseDots) + return bottom_up(_collapseDots, kast) def pushDownRewrites(kast): @@ -353,7 +337,7 @@ def _pushDownRewrites(_kast): if type(lhs) is KSequence and lhs.arity > 0 and type(lhs.items[-1]) is KVariable and type(rhs) is KVariable and lhs.items[-1] == rhs: return KSequence([KRewrite(KSequence(lhs.items[0:-1]), KApply(klabelEmptyK)), rhs]) return _kast - return traverseTopDown(kast, _pushDownRewrites) + return top_down(_pushDownRewrites, kast) def inlineCellMaps(kast): @@ -368,7 +352,7 @@ def _inlineCellMaps(_kast): if type(mapKey) is KApply and mapKey.is_cell: return _kast.args[1] return _kast - return traverseBottomUp(kast, _inlineCellMaps) + return bottom_up(_inlineCellMaps, kast) def removeSemanticCasts(kast): @@ -381,7 +365,7 @@ def _removeSemanticCasts(_kast): if type(_kast) is KApply and _kast.arity == 1 and _kast.label.startswith('#SemanticCast'): return _kast.args[0] return _kast - return traverseBottomUp(kast, _removeSemanticCasts) + return bottom_up(_removeSemanticCasts, kast) def markUselessVars(kast): @@ -390,7 +374,7 @@ def markUselessVars(kast): - Input: A Kast term. - Output: Kast term with variables appropriately named. 
""" - occurances = countVarOccurances(kast) + occurances = count_vars(kast) subst = {} for v in occurances: if v.startswith('_') and occurances[v] > 1: @@ -406,14 +390,7 @@ def uselessVarsToDots(kast, keepVars=None): - Input: kast term, and a requires clause and ensures clause. - Output: kast term with the useless vars structurally abstracted. """ - initList = {} - if keepVars is not None: - for v in keepVars: - if v not in initList: - initList[v] = 1 - else: - initList[v] += 1 - numOccurances = countVarOccurances(kast, numOccurances=initList) + numOccurances = count_vars(kast) + Counter(keepVars) def _collapseUselessVars(_kast): if type(_kast) is KApply and _kast.is_cell: @@ -426,7 +403,7 @@ def _collapseUselessVars(_kast): return _kast.let(args=newArgs) return _kast - return traverseBottomUp(kast, _collapseUselessVars) + return bottom_up(_collapseUselessVars, kast) def labelsToDots(kast, labels): @@ -439,7 +416,7 @@ def _labelstoDots(k): if type(k) is KApply and k.is_cell and k.label in labels: return ktokenDots return k - return traverseBottomUp(kast, _labelstoDots) + return bottom_up(_labelstoDots, kast) def onAttributes(kast: W, f: Callable[[KAtt], KAtt]) -> W: @@ -522,14 +499,14 @@ def _removeSourceMap(att): return onAttributes(k, _removeSourceMap) -def removeGeneratedCells(constrainedTerm): +def remove_generated_cells(term: KInner) -> KInner: """Remove and from a configuration. - - Input: Constrained term which contains and . + - Input: Constrained term. - Output: Constrained term with those cells removed. 
""" rule = KApply('', [KVariable('CONFIG'), KVariable('_')]), KVariable('CONFIG') - return rewriteAnywhereWith(rule, constrainedTerm) + return rewriteAnywhereWith(rule, term) def isAnonVariable(kast): @@ -541,7 +518,7 @@ def _largeTokensToDots(_k): if type(_k) is KToken and len(_k.token) > maxLen: return KToken('...', _k.sort) return _k - return traverseBottomUp(kast, _largeTokensToDots) + return bottom_up(_largeTokensToDots, kast) def getCell(constrainedTerm, cellVariable): @@ -620,7 +597,7 @@ def buildRule(ruleId, initConstrainedTerm, finalConstrainedTerm, claim=False, pr lhsVars = collectFreeVars(initConstrainedTerm) rhsVars = collectFreeVars(finalConstrainedTerm) - varOccurances = countVarOccurances(mlAnd([initConstrainedTerm, finalConstrainedTerm])) + varOccurances = count_vars(mlAnd([initConstrainedTerm, finalConstrainedTerm])) vSubst = {} vremapSubst = {} for v in varOccurances: @@ -687,7 +664,7 @@ def _rewritesToAbstractions(_kast): return _kast minimizedRewrite = pushDownRewrites(KRewrite(state1, state2)) - abstractedState = traverseBottomUp(minimizedRewrite, _rewritesToAbstractions) + abstractedState = bottom_up(_rewritesToAbstractions, minimizedRewrite) subst1 = match(abstractedState, state1) subst2 = match(abstractedState, state2) if subst1 is None or subst2 is None: diff --git a/k-distribution/src/main/scripts/lib/pyk/pyk/kcfg.py b/k-distribution/src/main/scripts/lib/pyk/pyk/kcfg.py index b976452a3eb..310b39ec57c 100644 --- a/k-distribution/src/main/scripts/lib/pyk/pyk/kcfg.py +++ b/k-distribution/src/main/scripts/lib/pyk/pyk/kcfg.py @@ -1,59 +1,103 @@ import json -import threading -from typing import Final +from abc import ABC +from dataclasses import dataclass +from functools import reduce +from itertools import chain +from threading import RLock +from typing import ( + Any, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, +) from graphviz import Digraph -from .cli_utils import fatal, notif -from .kast import KClaim, 
KFlatModule, KImport, KInner, KRule -from .kastManip import ( - buildRule, - matchWithConstraint, - minimizeSubst, - mlAnd, - simplifyBool, - substitute, - substToMap, - unsafeMlPredToBool, -) -from .utils import compare_short_hashes, hash_str, shorten_hashes +from .cterm import CTerm +from .kast import TRUE, KInner, KRuleLike +from .kastManip import buildRule, mlAnd, simplifyBool, unsafeMlPredToBool +from .ktool import KPrint +from .subst import Subst +from .utils import compare_short_hashes, shorten_hashes class KCFG: - _FIELDS: Final = ( - 'states', - 'graph', - 'abstractions', - 'loop', - 'split', - 'init', - 'target', - 'terminal', - 'stuck', - 'subsumptions', - 'frontier', - ) - - _NODE_ATTRS: Final = ('init', 'target', 'stuck', 'terminal', 'frontier', 'loop', 'split') + + @dataclass(frozen=True) + class Node: + term: CTerm + + def __init__(self, term: CTerm): + object.__setattr__(self, 'term', term) + + @property + def id(self) -> str: + return self.term.hash + + def to_dict(self) -> Dict[str, Any]: + return {'id': self.id, 'term': self.term.cterm.to_dict()} + + class EdgeLike(ABC): + source: 'KCFG.Node' + target: 'KCFG.Node' + + @dataclass(frozen=True) + class Edge(EdgeLike): + source: 'KCFG.Node' + target: 'KCFG.Node' + condition: KInner = TRUE + depth: int = 1 + + def to_dict(self) -> Dict[str, Any]: + return {'source': self.source.id, 'target': self.target.id, 'condition': self.condition.to_dict(), 'depth': self.depth} + + def to_rule(self, claim=False, priority=50) -> KRuleLike: + sentence_id = f'BASIC-BLOCK-{self.source.id}-TO-{self.target.id}' + init_term = mlAnd([self.source.term, self.condition]) + final_term = self.target.term + return buildRule(sentence_id, init_term, final_term, claim=claim, priority=priority) + + @dataclass(frozen=True) + class Cover(EdgeLike): + source: 'KCFG.Node' + target: 'KCFG.Node' + subst: Subst + + def __init__(self, source: 'KCFG.Node', target: 'KCFG.Node'): + object.__setattr__(self, 'source', source) + 
object.__setattr__(self, 'target', target) + + match_res = source.term.match(target.term) + if not match_res: + raise ValueError(f'No matching between: {source.id} and {target.id}') + + subst, _ = match_res + object.__setattr__(self, 'subst', subst) + + def to_dict(self) -> Dict[str, Any]: + return {'source': self.source.id, 'target': self.target.id} + + _nodes: Dict[str, Node] + _edges: Dict[str, Dict[str, Edge]] + _covers: Dict[str, Cover] + _init: Set[str] + _target: Set[str] + _stuck: Set[str] + _lock: RLock def __init__(self): - # todo: switch to sets instead of lists everywhere - # where it is possible - # and make encoders/decoders for that - # it is preparation for immutable data structures - # for parallel execution - self.states = {} - self.graph = {} - self.abstractions = {} - self.loop = [] - self.split = [] - self.init = [] - self.target = [] - self.terminal = [] - self.stuck = [] - self.subsumptions = {} - self.frontier = [] - self._lock = threading.RLock() + self._nodes = {} + self._edges = {} + self._covers = {} + self._init = set() + self._target = set() + self._stuck = set() + self._lock = RLock() def __enter__(self): self._lock.acquire() @@ -65,407 +109,378 @@ def __exit__(self, exc_type, exc_value, traceback): return True return False - @staticmethod - def _encode(value): - result = None - if type(value) is dict: - result = {} - for (key, val) in value.items(): - result[key] = KCFG._encode(val) - elif type(value) is set: - result = {'__kcfg_type__': 'set'} - val = [] - for item in value: - val.append(KCFG._encode(item)) - result['__kcfg_set_items__'] = val - elif isinstance(value, KInner): - result = value.to_dict() - result['__kcfg_type__'] = 'k' - else: - result = value - return result - - @staticmethod - def _decode(value): - result = None - if type(value) is dict: - if '__kcfg_type__' in value and value['__kcfg_type__'] == 'set': - items = value['__kcfg_set_items__'] - result = set() - for item in items: - result.add(KCFG._decode(item)) - 
elif '__kcfg_type__' in value and value['__kcfg_type__'] == 'k': - return KInner.from_dict(value) - else: - result = {} - for (key, val) in value.items(): - result[key] = KCFG._decode(val) - else: - result = value - return result - - def _assign(self, other): - self.states = other.states - self.graph = other.graph - self.abstractions = other.abstractions - self.loop = other.loop - self.split = other.split - self.init = other.init - self.target = other.target - self.terminal = other.terminal - self.stuck = other.stuck - self.subsumptions = other.subsumptions - self.frontier = other.frontier - - def to_dict(self): - dct = {} - for field in KCFG._FIELDS: - dct[field] = getattr(self, field) - return KCFG._encode(dct) + @property + def nodes(self) -> Set[Node]: + return set(self._nodes.values()) + + @property + def init(self) -> Set[Node]: + return {node for node in self.nodes if self.is_init(node.id)} + + @property + def target(self) -> Set[Node]: + return {node for node in self.nodes if self.is_target(node.id)} + + @property + def stuck(self) -> Set[Node]: + return {node for node in self.nodes if self.is_stuck(node.id)} + + @property + def leaves(self) -> Set[Node]: + return {node for node in self.nodes if self.is_leaf(node.id)} + + @property + def covered(self) -> Set[Node]: + return {node for node in self.nodes if self.is_covered(node.id)} + + @property + def uncovered(self) -> Set[Node]: + return {node for node in self.nodes if not self.is_covered(node.id)} + + @property + def frontier(self) -> Set[Node]: + return {node for node in self._nodes.values() if self.is_frontier(node.id)} + + @property + def covers(self) -> Set[Cover]: + return set(self._covers.values()) + + def to_dict(self) -> Dict[str, Any]: + nodes = [node.to_dict() for node in self.nodes] + edges = [edge.to_dict() for edge in self.edges()] + covers = [cover.to_dict() for cover in self.covers] + + init = list(self._init) + target = list(self._target) + stuck = list(self._stuck) + + res = { + 'nodes': 
nodes, + 'edges': edges, + 'covers': covers, + 'init': init, + 'target': target, + 'stuck': stuck, + } + return {k: v for k, v in res.items() if v} @staticmethod - def from_dict(dct): - dct = KCFG._decode(dct) + def from_dict(dct: Mapping[str, Any]) -> 'KCFG': cfg = KCFG() - for field in KCFG._FIELDS: - if field in dct: - setattr(cfg, field, dct[field]) + + nodes: Dict[str, str] = {} + + def resolve(node_id: str) -> str: + if node_id not in nodes: + raise ValueError(f'Undeclared node: {node_id}') + return nodes[node_id] + + for node_dict in dct.get('nodes') or []: + term = CTerm(KInner.from_dict(node_dict['term'])) + node = cfg.create_node(term) + + node_key = node_dict['id'] + if node_key in nodes: + raise ValueError(f'Multiple declarations of node: {node_key}') + nodes[node_key] = node.id + + for edge_dict in dct.get('edges') or []: + source_id = resolve(edge_dict['source']) + target_id = resolve(edge_dict['target']) + condition = KInner.from_dict(edge_dict['condition']) + depth = edge_dict['depth'] + cfg.create_edge(source_id, target_id, condition, depth) + + for cover_dict in dct.get('covers') or []: + source_id = resolve(cover_dict['source']) + target_id = resolve(cover_dict['target']) + cfg.create_cover(source_id, target_id) + + for init_id in dct.get('init') or []: + cfg.add_init(resolve(init_id)) + + for target_id in dct.get('target') or []: + cfg.add_target(resolve(target_id)) + + for stuck_id in dct.get('stuck') or []: + cfg.add_stuck(resolve(stuck_id)) + return cfg - def to_json(self): + def to_json(self) -> str: return json.dumps(self.to_dict(), sort_keys=True) @staticmethod def from_json(s: str) -> 'KCFG': return KCFG.from_dict(json.loads(s)) - def to_dot(self, kprint): - graph = Digraph() + def to_dot(self, kprint: KPrint) -> str: + def _node_attrs(node_id: str) -> List[str]: + atts = [] + if node_id in self._init: + atts.append('init') + if node_id in self._target: + atts.append('target') + if node_id in self._stuck: + atts.append('stuck') + return 
atts + + def _short_label(label): + return '\n'.join([label_line if len(label_line) < 100 else (label_line[0:100] + ' ...') for label_line in label.split('\n')]) - def _short_label(_label): - return '\n'.join([label_line if len(label_line) < 100 else (label_line[0:100] + ' ...') for label_line in _label.split('\n')]) + graph = Digraph() - for state in self.states: - classAttrs = ' '.join(self.getNodeAttributes(state)) - label = shorten_hashes(state) + (classAttrs and ' ' + classAttrs) + for node in self.nodes: + classAttrs = ' '.join(_node_attrs(node.id)) + label = shorten_hashes(node.id) + (classAttrs and ' ' + classAttrs) attrs = {'class': classAttrs} if classAttrs else {} - graph.node(name=state, label=label, **attrs) - for source in self.graph: - for target in self.graph[source]: - edge = self.graph[source][target] - display_condition = simplifyBool(unsafeMlPredToBool(edge['condition'])) - depth = edge['depth'] - classes = edge['classes'] - label = '\nandBool'.join(kprint.prettyPrint(display_condition).split(' andBool')) - label = f'{label}\n{depth} steps' - label = _short_label(label) - classAttrs = ' '.join(classes) - attrs = {'class': classAttrs} if classAttrs else {} - graph.edge(tail_name=source, head_name=target, label=f' {label} ', **attrs) - for state in self.abstractions: - for abstractId in self.abstractions[state]: - subst = substToMap(minimizeSubst(self.subsumptions[state][abstractId])) - label = kprint.prettyPrint(subst) - label = _short_label(label) - attrs = {'class': 'abstraction', 'style': 'dashed'} - graph.edge(tail_name=state, head_name=abstractId, label=f' {label} ', **attrs) - for target in self.target: + graph.node(name=node.id, label=label, **attrs) + + for edge in self.edges(): + display_condition = simplifyBool(unsafeMlPredToBool(edge.condition)) + depth = edge.depth + label = '\nandBool'.join(kprint.prettyPrint(display_condition).split(' andBool')) + label = f'{label}\n{depth} steps' + label = _short_label(label) + 
graph.edge(tail_name=edge.source.id, head_name=edge.target.id, label=f' {label} ') + + for cover in self.covers: + label = ', '.join(f'{k} |-> {kprint.prettyPrint(v)}' for k, v in cover.subst.items()) + label = _short_label(label) + attrs = {'class': 'abstraction', 'style': 'dashed'} + graph.edge(tail_name=cover.source.id, head_name=cover.target.id, label=f' {label} ', **attrs) + + for target in self._target: for node in self.frontier: attrs = {'class': 'target', 'style': 'solid'} - graph.edge(tail_name=node, head_name=target, label=' ???', **attrs) - for node in self.terminal: - attrs = {'class': 'target', 'style': 'dashed'} - graph.edge(tail_name=node, head_name=target, label=' ???', **attrs) + graph.edge(tail_name=node.id, head_name=target, label=' ???', **attrs) + return graph.source - def getNodeAttributes(self, nodeId): - atts = [] - for att in KCFG._NODE_ATTRS: - if nodeId in self.__getattribute__(att): - atts.append(att) - return atts - - def getTermHash(self, term): - return hash_str(term.to_json()) - - def getStateIdByShortHash(self, shortHash): - for h in self.states: - if compare_short_hashes(shortHash, h): - return h - return None - - def getNodesByHashes(self, shortHashes): - return [self.getStateIdByShortHash(h) for h in shortHashes] - - def insertNode(self, newConstrainedTerm): - subsumes = [] - subsumedBy = [] - for constrainedTermId in self.states: - constrainedTerm = self.states[constrainedTermId] - subsumedWith = matchWithConstraint(constrainedTerm, newConstrainedTerm) - subsumesWith = matchWithConstraint(newConstrainedTerm, constrainedTerm) - if subsumedWith is not None and subsumesWith is not None: - return (False, constrainedTermId) - elif subsumedWith is not None: - subsumedBy.append((constrainedTermId, subsumedWith)) - elif subsumesWith is not None: - subsumes.append((constrainedTermId, subsumesWith)) - newConstrainedTermId = self.getTermHash(newConstrainedTerm) - self.states[newConstrainedTermId] = newConstrainedTerm - 
self.graph[newConstrainedTermId] = {} - self.subsumptions[newConstrainedTermId] = {} - self.abstractions[newConstrainedTermId] = [] - - for (ctid, subst) in subsumes: - self.subsumptions[ctid][newConstrainedTermId] = subst - for (ctid, subst) in subsumedBy: - self.subsumptions[newConstrainedTermId][ctid] = subst - - return (True, newConstrainedTermId) - - def insertAbstraction(self, concreteId, abstractId): - if concreteId == abstractId: - return self - if abstractId not in self.getMoreGeneralNodes(concreteId): - fatal('Node ' + str(abstractId) + ' does not abstract node ' + str(concreteId) + ' as claimed.') - if abstractId not in self.abstractions[concreteId]: - self.abstractions[concreteId].append(abstractId) - return self + def _resolve(self, short_id: str) -> str: + matches = [node_id for node_id in self._nodes if compare_short_hashes(short_id, node_id)] + if not matches: + raise ValueError(f'Unknown node: {short_id}') + if len(matches) > 1: + raise ValueError(f'Multiple nodes for pattern: {short_id} (matches e.g. 
{matches[0]} and {matches[1]})') + return matches[0] - # heavy - def removeNode(self, nodeId): - if nodeId not in self.states: - raise ValueError(f'Unknown node: {nodeId}') - - self.states.pop(nodeId) - if nodeId in self.frontier: - for nid in (self.getConcretizations(nodeId) + self.getPredecessors(nodeId)): - if nid not in self.frontier: - self.frontier.append(nid) - for k in ['subsumptions', 'graph']: - self.__getattribute__(k).pop(nodeId, None) - for initNode in self.__getattribute__(k): - if nodeId in self.__getattribute__(k)[initNode]: - self.__getattribute__(k)[initNode].pop(nodeId) - for k in ['abstractions']: - self.__getattribute__(k).pop(nodeId, None) - for initNode in self.__getattribute__(k): - self.__getattribute__(k)[initNode] = [n for n in self.__getattribute__(k)[initNode] if n != nodeId] - for k in KCFG._NODE_ATTRS: - self.__setattr__(k, [n for n in self.__getattribute__(k) if n != nodeId]) - - def getMoreGeneralNodes(self, nodeId): - if nodeId not in self.subsumptions: - return [] - return list(self.subsumptions[nodeId].keys()) - - # heavy - def getLessGeneralNodes(self, nodeId): - return [nid for nid in self.subsumptions if nodeId in self.getMoreGeneralNodes(nid)] - - def getAbstractions(self, nodeId): - if nodeId not in self.abstractions: - return [] - return self.abstractions[nodeId] - - # heavy - def getConcretizations(self, nodeId): - return [nid for nid in self.abstractions if nodeId in self.getAbstractions(nid)] - - def getSuccessors(self, nodeId): - if nodeId not in self.graph: - return [] - return list(self.graph[nodeId].keys()) - - # heavy - def getPredecessors(self, nodeId): - return [nid for nid in self.graph if nodeId in self.getSuccessors(nid)] - - def getEdges(self): - return [(s, f) for s in self.graph for f in self.graph[s]] - - def insertEdge(self, initConstrainedTermId, condition, finalConstrainedTermId, depth, classes=[], priority=50): - edgeLabel = {'depth': depth, 'condition': condition, 'classes': [c for c in classes], 
'priority': priority} - if finalConstrainedTermId in self.graph[initConstrainedTermId] and self.graph[initConstrainedTermId][finalConstrainedTermId] == edgeLabel: - return self - self.graph[initConstrainedTermId][finalConstrainedTermId] = edgeLabel - - predNodes = self.transitiveClosureFromState(initConstrainedTermId, reverse=True) - moreGeneralNodes = self.getMoreGeneralNodes(finalConstrainedTermId) - for nid in predNodes: - if nid in moreGeneralNodes: - if nid not in self.getAbstractions(finalConstrainedTermId): - self.abstractions[finalConstrainedTermId].append(nid) - if depth == 0 and nid == initConstrainedTermId and finalConstrainedTermId not in self.split: - self.split.append(finalConstrainedTermId) - elif nid not in self.loop: - self.loop.append(nid) + def node(self, node_id: str) -> Node: + node_id = self._resolve(node_id) + return self._nodes[node_id] - return self + def create_node(self, term: CTerm) -> Node: + node = KCFG.Node(term) + + if node.id in self._nodes: + raise ValueError(f'Node already exists: {node.id}') + + self._nodes[node.id] = node + return node + + def remove_node(self, node_id: str) -> None: + node_id = self._resolve(node_id) + + self._nodes.pop(node_id) + + self._edges.pop(node_id, None) + for source_id in self._edges: + self._edges[source_id].pop(node_id, None) + + self._covers.pop(node_id, None) + for source_id, cover in self._covers.items(): + if cover.target.id == node_id: + self._covers.pop(source_id) + + self._init.discard(node_id) + self._target.discard(node_id) + self._stuck.discard(node_id) + + def create_edge(self, source_id: str, target_id: str, condition: KInner = TRUE, depth=1) -> Edge: + source = self.node(source_id) + target = self.node(target_id) + + if target.id in self._edges.get(source.id, {}): + raise ValueError(f'Edge already exists: {source.id} -> {target.id}') + + if source.id not in self._edges: + self._edges[source.id] = {} + + edge = KCFG.Edge(source, target, condition, depth) + 
self._edges[source.id][target.id] = edge + return edge + + def edge(self, source_id: str, target_id: str) -> Optional[Edge]: + source_id = self._resolve(source_id) + target_id = self._resolve(target_id) + return self._edges.get(source_id, {}).get(target_id) + + def edges(self, *, source_id: Optional[str] = None, target_id: Optional[str] = None) -> Set[Edge]: + source_id = self._resolve(source_id) if source_id is not None else None + target_id = self._resolve(target_id) if target_id is not None else None + + res: Iterable[KCFG.Edge] + if source_id: + res = self._edges.get(source_id, {}).values() + else: + res = (edge for _, targets in self._edges.items() for _, edge in targets.items()) + + return {edge for edge in res if not target_id or target_id == edge.target.id} + + def create_cover(self, source_id: str, target_id: str) -> Cover: + source = self.node(source_id) + target = self.node(target_id) + + if source.id in self._covers: + raise ValueError(f'Cover already exists: {source.id} -> {self._covers[source.id].target.id}') + + cover = KCFG.Cover(source, target) + self._covers[source.id] = cover + return cover + + def cover_of(self, node_id) -> Optional[Cover]: + node_id = self._resolve(node_id) + return self._covers.get(node_id) + + def covers_by(self, node_id) -> Set[Cover]: + node_id = self._resolve(node_id) + return {cover for cover in self.covers if cover.target.id == node_id} + + def add_init(self, node_id: str) -> None: + node_id = self._resolve(node_id) + self._init.add(node_id) + + def add_target(self, node_id: str) -> None: + node_id = self._resolve(node_id) + self._target.add(node_id) + + def add_stuck(self, node_id: str) -> None: + node_id = self._resolve(node_id) + self._stuck.add(node_id) + + def is_init(self, node_id: str) -> bool: + node_id = self._resolve(node_id) + return node_id in self._init + + def is_target(self, node_id: str) -> bool: + node_id = self._resolve(node_id) + return node_id in self._target + + def is_stuck(self, node_id: str) -> 
bool: + node_id = self._resolve(node_id) + return node_id in self._stuck + + def is_leaf(self, node_id: str) -> bool: + node_id = self._resolve(node_id) + return node_id not in self._edges + + def is_covered(self, node_id: str) -> bool: + node_id = self._resolve(node_id) + return node_id in self._covers + + def is_frontier(self, node_id: str) -> bool: + node_id = self._resolve(node_id) + return self.is_leaf(node_id) and not self.is_target(node_id) and not self.is_stuck(node_id) and not self.is_covered(node_id) + + def prune(self, node_id: str) -> None: + nodes = self.reachable_nodes(node_id) + for node in nodes: + self.remove_node(node.id) + + def paths_between(self, source_id: str, target_id: str, *, traverse_covers=False) -> List[Tuple[EdgeLike, ...]]: + source_id = self._resolve(source_id) + target_id = self._resolve(target_id) + + INIT = 1 + POP_PATH = 2 + + visited: Set[str] = set() + path: List[KCFG.EdgeLike] = [] + paths: List[Tuple[KCFG.EdgeLike, ...]] = [] + + worklist: List[Any] = [INIT] + + while worklist: + item = worklist.pop() + + if type(item) == str: + visited.remove(item) + continue + + if item == POP_PATH: + path.pop() + continue + + node_id: str + + if item == INIT: + node_id = source_id - def getEdgeCondition(self, initNodeId, finalNodeId): - return self.graph[initNodeId][finalNodeId]['condition'] - - def getEdgeSentence(self, initNodeId, finalNodeId, priority=50): - sentenceId = 'BASIC-BLOCK-' + str(initNodeId) + '-TO-' + str(finalNodeId) - initConstrainedTerm = self.states[initNodeId] - finalConstrainedTerm = self.states[finalNodeId] - edge = self.graph[initNodeId][finalNodeId] - verified = 'verified' in edge['classes'] - edgeConstraint = edge['condition'] - initConstrainedTerm = mlAnd([initConstrainedTerm, edgeConstraint]) - return buildRule(sentenceId, initConstrainedTerm, finalConstrainedTerm, claim=not verified, priority=priority) - - def getModule(self, moduleName, mainModuleName, rules=False, priority=50): - newSentences = [] - for i in 
self.graph: - for j in self.graph[i]: - (newSentence, _) = self.getEdgeSentence(i, j, priority=priority) - if (rules and type(newSentence) is KRule) or (not rules and type(newSentence) is KClaim): - newSentences.append(newSentence) - return KFlatModule(moduleName, [KImport(mainModuleName)], newSentences) - - def markEdgeVerified(self, initConstrainedTermId, finalConstrainedTermId): - self.graph[initConstrainedTermId][finalConstrainedTermId]['classes'].append('verified') - - def markEdgeAsyncProcess(self, initConstrainedTermId, finalConstrainedTermId): - self.graph[initConstrainedTermId][finalConstrainedTermId]['classes'].append('async_processed') - - def clearEdgeMarkAsyncProcess(self, initConstrainedTermId, finalConstrainedTermId): - self.graph[initConstrainedTermId][finalConstrainedTermId]['classes'].remove('async_processed') - - def transitiveClosureFromState(self, constrainedTermId, reverse=False, stopAtLoops=False, stopAtNodes=None): - constrainedTermIds = [] - newConstrainedTermIds = [constrainedTermId] - stopNodes = [] if not stopAtLoops else self.loop - if stopAtNodes is not None: - stopNodes.extend(stopAtNodes) - while len(newConstrainedTermIds) > 0: - constrainedTermId = newConstrainedTermIds.pop(0) - if constrainedTermId not in constrainedTermIds: - constrainedTermIds.append(constrainedTermId) - if constrainedTermId not in stopNodes: - if not reverse: - newConstrainedTermIds.extend(self.getSuccessors(constrainedTermId)) - newConstrainedTermIds.extend(self.getAbstractions(constrainedTermId)) - else: - newConstrainedTermIds.extend(self.getPredecessors(constrainedTermId)) - newConstrainedTermIds.extend(self.getConcretizations(constrainedTermId)) - return constrainedTermIds - - def nonLoopingPathsBetweenStates(self, initConstrainedTermId, finalConstrainedTermId): - paths = [] - worklistPaths = [[initConstrainedTermId]] - while len(worklistPaths) > 0: - nextPath = worklistPaths.pop(0) - if nextPath[-1] == finalConstrainedTermId: - paths.append(nextPath) else: 
- initState = nextPath[-1] - for s in (self.getSuccessors(initState) + self.getAbstractions(initState)): - if s not in nextPath: - worklistPaths.append(nextPath + [s]) - return paths + assert isinstance(item, KCFG.EdgeLike) - def invalidateStates(self, stateIds): - invalidNodes = [] - for s in stateIds: - invalidNodes.extend(self.transitiveClosureFromState(s)) - invalidNodes = sorted(list(set(invalidNodes))) - - newCfg = KCFG() - nodeMap = {} - for sid in self.init: - if sid not in invalidNodes: - (_, newSid) = newCfg.insertNode(self.states[sid]) - nodeMap[newSid] = sid - newCfg.init.append(newSid) - - workList = list(newCfg.states.keys()) - while len(workList) > 0: - newInitId = workList.pop(0) - oldInitId = nodeMap[newInitId] - - for oldSuccessorId in [nid for nid in self.getSuccessors(oldInitId) if nid not in invalidNodes]: - (_, newSuccessorId) = newCfg.insertNode(self.states[oldSuccessorId]) - if newSuccessorId in nodeMap: - continue - nodeMap[newSuccessorId] = oldSuccessorId - oldEdge = self.graph[oldInitId][oldSuccessorId] - newCfg.insertEdge(newInitId, oldEdge['condition'], newSuccessorId, oldEdge['depth'], classes=oldEdge['classes'], priority=oldEdge['priority']) - workList.append(newSuccessorId) - - for oldMoreGeneralId in [nid for nid in self.getMoreGeneralNodes(oldInitId) if nid not in invalidNodes]: - (_, newMoreGeneralId) = newCfg.insertNode(self.states[oldMoreGeneralId]) - if newMoreGeneralId in nodeMap: + node_id = item.target.id + if node_id in visited: continue - nodeMap[newMoreGeneralId] = oldMoreGeneralId - workList.append(newMoreGeneralId) - - reverseNodeMap = {v: k for (k, v) in nodeMap.items()} - for newNodeId in newCfg.states: - if nodeMap[newNodeId] in self.abstractions: - for oldAbstractId in self.abstractions[nodeMap[newNodeId]]: - if oldAbstractId in reverseNodeMap and reverseNodeMap[oldAbstractId] not in newCfg.abstractions[newNodeId]: - newCfg.abstractions[newNodeId].append(reverseNodeMap[oldAbstractId]) - - for newNodeId in 
newCfg.states: - if nodeMap[newNodeId] in self.loop: - newCfg.loop.append(newNodeId) - if nodeMap[newNodeId] in self.terminal: - newCfg.terminal.append(newNodeId) - if nodeMap[newNodeId] in self.stuck: - newCfg.stuck.append(newNodeId) - - for target in self.target: - (newState, newTargetId) = newCfg.insertNode(self.states[target]) - if newTargetId not in newCfg.target: - newCfg.target.append(newTargetId) - if newState: - nodeMap[newTargetId] = target - - newCfg.frontier = [nid for nid in newCfg.states if nodeMap[nid] in self.frontier and nodeMap[nid] not in invalidNodes] - for nodeId in newCfg.states: - if len(newCfg.getSuccessors(nodeId)) < len(self.getSuccessors(nodeMap[nodeId])): - newCfg.frontier.append(nodeId) - if len(newCfg.getAbstractions(nodeId)) < len(self.getAbstractions(nodeMap[nodeId])): - newCfg.frontier.append(nodeId) - - newCfg.frontier = list(sorted(list(set(newCfg.frontier)))) - - notif('Invalidated nodes ' + str(shorten_hashes(invalidNodes)) + '.') - notif('New frontier ' + str(shorten_hashes(newCfg.frontier)) + '.') - self._assign(newCfg) - return self - def getPathsBetween(self, initNodeId, finalNodeId, seenNodes=None): - if initNodeId == finalNodeId: - return [[finalNodeId]] - seen = [] if seenNodes is None else [nid for nid in seenNodes] - succs = self.getSuccessors(initNodeId) - abstr = self.getAbstractions(initNodeId) - paths = [] - if initNodeId in seen: - return [] - for nid in (succs + abstr): - paths.extend([[initNodeId] + p for p in self.getPathsBetween(nid, finalNodeId, seenNodes=seen + [initNodeId])]) + path.append(item) + + if node_id == target_id: + paths.append(tuple(path)) + continue + + visited.add(node_id) + worklist.append(node_id) + + edges: List[KCFG.EdgeLike] = list(self.edges(source_id=node_id)) + if traverse_covers and (cover := self.cover_of(node_id)): + edges.append(cover) + + for edge in edges: + worklist.append(POP_PATH) + worklist.append(edge) + return paths - def getPathCondition(self, path): - constraints = [] - 
substitutions = [] - depth = 0 - for (init, fin) in zip(path, path[1:]): - if init in self.graph and fin in self.graph[init]: - constraints.append(self.getEdgeCondition(init, fin)) - depth += self.graph[init][fin]['depth'] - if init in self.abstractions and fin in self.abstractions[init]: - substitutions.append(self.subsumptions[init][fin]) - substitutions = list(reversed(substitutions)) - substitution = {} - if len(substitutions) > 0: - substitution = {k: substitutions[0][k] for k in substitutions[0]} - for subst in substitutions[1:]: - for k in substitution: - substitution[k] = substitute(substitution[k], subst) - return (mlAnd(constraints), substitution, depth) + def reachable_nodes(self, node_id: str, *, reverse=False, traverse_covers=False) -> Set[Node]: + node = self.node(node_id) + + visited: Set[KCFG.Node] = set() + worklist: List[KCFG.Node] = [node] + + while worklist: + node = worklist.pop() + + if node in visited: + continue + + visited.add(node) + + edges: Iterable[KCFG.EdgeLike] + if not reverse: + cover = self.cover_of(node.id) + edges = chain(self.edges(source_id=node.id), [cover] if cover and traverse_covers else []) + worklist.extend(edge.target for edge in edges) + else: + edges = chain(self.edges(target_id=node.id), self.covers_by(node.id) if traverse_covers else []) + worklist.extend(edge.source for edge in edges) + + return visited + + +def path_condition(path: Sequence[KCFG.EdgeLike]) -> Tuple[KInner, Subst, int]: + constraints: List[KInner] = [] + substitutions: List[Subst] = [] + depth = 0 + + for edge in path: + if type(edge) == KCFG.Edge: + constraints.append(edge.condition) + depth += edge.depth + elif type(edge) == KCFG.Cover: + substitutions.append(edge.subst) + else: + assert False + + substitution = reduce(Subst.compose, reversed(substitutions), Subst()) + return mlAnd(constraints), substitution, depth diff --git a/k-distribution/src/main/scripts/lib/pyk/pyk/subst.py b/k-distribution/src/main/scripts/lib/pyk/pyk/subst.py new file mode 
100644 index 00000000000..d75f9308c69 --- /dev/null +++ b/k-distribution/src/main/scripts/lib/pyk/pyk/subst.py @@ -0,0 +1,113 @@ +from dataclasses import dataclass +from itertools import chain +from typing import Dict, Iterator, Mapping, Optional, TypeVar + +from .kast import KApply, KInner, KRewrite, KSequence, KToken, KVariable +from .kastManip import bottom_up +from .utils import FrozenDict + +K = TypeVar('K') +V = TypeVar('V') + + +@dataclass(frozen=True) +class Subst(Mapping[str, KInner]): + subst: FrozenDict[str, KInner] + + def __init__(self, subst: Mapping[str, KInner] = {}): + object.__setattr__(self, 'subst', FrozenDict({k: v for k, v in subst.items() if type(v) is not KVariable or v.name != k})) + + def __iter__(self) -> Iterator[str]: + return iter(self.subst) + + def __len__(self) -> int: + return len(self.subst) + + def __getitem__(self, key: str) -> KInner: + return self.subst[key] + + def __mul__(self, other: 'Subst') -> 'Subst': + return self.compose(other) + + def __call__(self, term: KInner) -> KInner: + return self.apply(term) + + def compose(self, other: 'Subst') -> 'Subst': + from_other = ((k, self(v)) for k, v in other.items()) + from_self = ((k, v) for k, v in self.items() if k not in other) + return Subst(dict(chain(from_other, from_self))) + + def apply(self, term: KInner) -> KInner: + def replace(term): + if type(term) is KVariable and term.name in self: + return self[term.name] + return term + + return bottom_up(replace, term) + + +def match(pattern: KInner, term: KInner) -> Optional[Subst]: + """Perform syntactic pattern matching and return the substitution. + + - Input: a pattern and a term. + - Output: substitution instantiating the pattern to the term. 
+ """ + + # TODO simplify + def merge(*dicts: Optional[Mapping[K, V]]) -> Optional[Dict[K, V]]: + if len(dicts) == 0: + return {} + + dict1 = dicts[0] + if dict1 is None: + return None + + if len(dicts) == 1: + return dict(dict1) + + dict2 = dicts[1] + if dict2 is None: + return None + + intersecting_keys = set(dict1.keys()).intersection(set(dict2.keys())) + for key in intersecting_keys: + if dict1[key] != dict2[key]: + return None + + newDict = {key: dict1[key] for key in dict1} + for key in dict2.keys(): + newDict[key] = dict2[key] + + restDicts = dicts[2:] + return merge(newDict, *restDicts) + + # TODO simplify + def _match(pattern: KInner, term: KInner) -> Optional[Dict[str, KInner]]: + subst: Optional[Dict[str, KInner]] = {} + if type(pattern) is KVariable: + return {pattern.name: term} + if type(pattern) is KToken and type(term) is KToken: + return {} if pattern.token == term.token else None + if type(pattern) is KApply and type(term) is KApply \ + and pattern.label == term.label and pattern.arity == term.arity: + for patternArg, kastArg in zip(pattern.args, term.args): + argSubst = match(patternArg, kastArg) + subst = merge(subst, argSubst) + if subst is None: + return None + return subst + if type(pattern) is KRewrite and type(term) is KRewrite: + lhsSubst = match(pattern.lhs, term.lhs) + rhsSubst = match(pattern.rhs, term.rhs) + return merge(lhsSubst, rhsSubst) + if type(pattern) is KSequence and type(term) is KSequence and pattern.arity == term.arity: + for (patternItem, substItem) in zip(pattern.items, term.items): + itemSubst = match(patternItem, substItem) + subst = merge(subst, itemSubst) + if subst is None: + return None + return subst + return None + + subst = _match(pattern, term) + return Subst(subst) if subst is not None else None diff --git a/k-distribution/src/main/scripts/lib/pyk/pyk/tests/test_count_vars.py b/k-distribution/src/main/scripts/lib/pyk/pyk/tests/test_count_vars.py new file mode 100644 index 00000000000..325caa51e75 --- /dev/null 
+++ b/k-distribution/src/main/scripts/lib/pyk/pyk/tests/test_count_vars.py @@ -0,0 +1,39 @@ +from functools import partial +from typing import Final, Mapping, Tuple +from unittest import TestCase + +from ..kast import KApply, KInner, KVariable +from ..kastManip import count_vars + +a, b, c = (KApply(label) for label in ['a', 'b', 'c']) +x, y, z = (KVariable(name) for name in ['x', 'y', 'z']) +f, g, h = (partial(KApply.of, label) for label in ['f', 'g', 'h']) + + +class CountVarTest(TestCase): + TEST_DATA: Final[Tuple[Tuple[KInner, Mapping[str, int]], ...]] = ( + (a, {}), + (x, {'x': 1}), + (f(a), {}), + (f(a, b, c), {}), + (f(x), {'x': 1}), + (f(f(f(x))), {'x': 1}), + (f(x, a), {'x': 1}), + (f(x, x), {'x': 2}), + (f(x, y), {'x': 1, 'y': 1}), + (f(x, y, z), {'x': 1, 'y': 1, 'z': 1}), + (f(x, g(y), h(z)), {'x': 1, 'y': 1, 'z': 1}), + (f(x, a, g(y, b), h(z, c)), {'x': 1, 'y': 1, 'z': 1}), + (f(x, g(x, y), h(x, z)), {'x': 3, 'y': 1, 'z': 1}), + (f(x, g(x, h(x, y, z))), {'x': 3, 'y': 1, 'z': 1}), + ) + + def test(self): + # Given + for i, [term, expected] in enumerate(self.TEST_DATA): + with self.subTest(i=i): + # When + actual = count_vars(term) + + # Then + self.assertDictEqual(actual, expected) diff --git a/k-distribution/src/main/scripts/lib/pyk/pyk/tests/test_kcfg.py b/k-distribution/src/main/scripts/lib/pyk/pyk/tests/test_kcfg.py index 8dcb5611b24..139da845898 100644 --- a/k-distribution/src/main/scripts/lib/pyk/pyk/tests/test_kcfg.py +++ b/k-distribution/src/main/scripts/lib/pyk/pyk/tests/test_kcfg.py @@ -1,76 +1,97 @@ -from typing import Dict, Hashable, Iterable, Tuple, Union, cast +from typing import Any, Dict, List, Tuple from unittest import TestCase -from ..kast import TOP, KInner +from ..cterm import CTerm +from ..kast import TRUE, KApply from ..kcfg import KCFG from ..prelude import token +def nid(i: int) -> str: + return node(i).id + + +def term(i: int) -> CTerm: + return CTerm(KApply('', [token(i)])) + + +def node(i: int) -> KCFG.Node: + return 
KCFG.Node(term(i)) + + +def edge(i: int, j: int) -> KCFG.Edge: + return KCFG.Edge(node(i), node(j), TRUE, 1) + + +def node_dicts(n: int) -> List[Dict[str, Any]]: + return [node(i).to_dict() for i in range(n)] + + +def edge_dicts(*edges: Tuple[int, int]) -> List[Dict[str, Any]]: + return [ + {'source': nid(i), 'target': nid(j), 'condition': TRUE.to_dict(), 'depth': 1} + for i, j in edges + ] + + class KCFGTestCase(TestCase): - def test_from_dict_single_state(self): + def test_from_dict_single_node(self): # Given - d = cfg_dict([(0, token(0))]) + d = {'nodes': node_dicts(1)} # When cfg = KCFG.from_dict(d) # Then - self.assertEqual(len(cfg.states), 1) - self.assertEqual(cfg.states[0], token(0)) + self.assertSetEqual(cfg.nodes, {node(0)}) + self.assertDictEqual(cfg.to_dict(), d) - def test_from_dict_two_states(self): + def test_from_dict_two_nodes(self): # Given - d = cfg_dict([(0, token(0)), (1, token(1))]) + d = {'nodes': node_dicts(2)} # When cfg = KCFG.from_dict(d) # Then - self.assertEqual(len(cfg.states), 2) - self.assertEqual(cfg.states[0], token(0)) - self.assertEqual(cfg.states[1], token(1)) + self.assertSetEqual(cfg.nodes, {node(0), node(1)}) def test_from_dict_loop_edge(self): # Given - d = cfg_dict(states=[(0, token(0))], edges=[(0, token(True), 0)]) + d = {'nodes': node_dicts(1), 'edges': edge_dicts((0, 0))} # When cfg = KCFG.from_dict(d) # Then - self.assertEqual(len(cfg.states), 1) - self.assertEqual(cfg.states[0], token(0)) - self.assertEqual(len(cfg.graph), 1) - self.assertEqual(len(cfg.graph[0]), 1) - self.assertEqual(cfg.graph[0][0]['condition'], token(True)) + self.assertSetEqual(cfg.nodes, {node(0)}) + self.assertSetEqual(cfg.edges(), {edge(0, 0)}) + self.assertEqual(cfg.edge(nid(0), nid(0)), edge(0, 0)) + self.assertDictEqual(cfg.to_dict(), d) def test_from_dict_simple_edge(self): # Given - d = cfg_dict(states=[(0, token(0)), (1, token(1))], edges=[(0, token(True), 1)]) + d = {'nodes': node_dicts(2), 'edges': edge_dicts((0, 1))} # When cfg = 
KCFG.from_dict(d) # Then - self.assertEqual(len(cfg.states), 2) - self.assertEqual(cfg.states[0], token(0)) - self.assertEqual(cfg.states[1], token(1)) - self.assertEqual(len(cfg.graph), 2) - self.assertEqual(len(cfg.graph[0]), 1) - self.assertEqual(cfg.graph[0][1]['condition'], token(True)) - self.assertDictEqual(cfg.graph[1], {}) - - def test_insert_node(self): + self.assertSetEqual(cfg.nodes, {node(0), node(1)}) + self.assertSetEqual(cfg.edges(), {edge(0, 1)}) + self.assertEqual(cfg.edge(nid(0), nid(1)), edge(0, 1)) + + def test_create_node(self): # Given cfg = KCFG() # When - _, node_id = cfg.insertNode(token(True)) + new_node = cfg.create_node(term(0)) # Then - self.assertEqual(len(cfg.states), 1) - self.assertEqual(cfg.states[node_id], token(True)) + self.assertEqual(new_node, node(0)) + self.assertSetEqual(cfg.nodes, {node(0)}) def test_remove_unknown_node(self): # Given @@ -79,112 +100,102 @@ def test_remove_unknown_node(self): # Then with self.assertRaises(ValueError): # When - cfg.removeNode(0) + cfg.remove_node(nid(0)) def test_remove_node(self): # Given - d = cfg_dict([0]) + d = {'nodes': node_dicts(1), 'edges': edge_dicts((0, 0))} cfg = KCFG.from_dict(d) # When - cfg.removeNode(0) + cfg.remove_node(nid(0)) # Then - self.assertEqual(len(cfg.states), 0) + self.assertSetEqual(cfg.nodes, set()) + self.assertSetEqual(cfg.edges(), set()) + with self.assertRaises(ValueError): + cfg.node(nid(0)) + with self.assertRaises(ValueError): + cfg.edge(nid(0), nid(0)) def test_insert_loop_edge(self): # Given - d = cfg_dict([0]) + d = {'nodes': node_dicts(1)} cfg = KCFG.from_dict(d) # When - cfg.insertEdge(0, token(True), 0, depth=1) + new_edge = cfg.create_edge(nid(0), nid(0)) # Then - self.assertEqual(len(cfg.graph), 1) + self.assertEqual(new_edge, edge(0, 0)) + self.assertSetEqual(cfg.nodes, {node(0)}) + self.assertSetEqual(cfg.edges(), {edge(0, 0)}) + self.assertEqual(cfg.edge(nid(0), nid(0)), edge(0, 0)) def test_insert_simple_edge(self): # Given - d = cfg_dict([0, 
1]) + d = {'nodes': node_dicts(2)} cfg = KCFG.from_dict(d) # When - cfg.insertEdge(0, token(True), 1, depth=1) + new_edge = cfg.create_edge(nid(0), nid(1)) # Then - self.assertEqual(len(cfg.graph), 2) - self.assertEqual(len(cfg.graph[0]), 1) - self.assertEqual(cfg.graph[0][1]['condition'], token(True)) - self.assertDictEqual(cfg.graph[1], {}) + self.assertEqual(new_edge, edge(0, 1)) + self.assertSetEqual(cfg.nodes, {node(0), node(1)}) + self.assertSetEqual(cfg.edges(), {edge(0, 1)}) def test_get_successors(self): - d = cfg_dict(states=[0, 1, 2], edges=[(0, 1), (0, 2)]) + d = {'nodes': node_dicts(3), 'edges': edge_dicts((0, 1), (0, 2))} cfg = KCFG.from_dict(d) # When - succs = set(cfg.getSuccessors(0)) + succs = set(cfg.edges(source_id=nid(0))) # Then - self.assertSetEqual(succs, {1, 2}) + self.assertSetEqual(succs, {edge(0, 1), edge(0, 2)}) def test_get_predecessors(self): - d = cfg_dict(states=[0, 1, 2], edges=[(0, 2), (1, 2)]) - cfg = KCFG.from_dict(d) - - # When - preds = set(cfg.getPredecessors(2)) - - # Then - self.assertSetEqual(preds, {0, 1}) - - def test_get_edges(self): - d = cfg_dict(states=[0, 1, 2], edges=[(0, 1), (0, 2), (1, 2)]) + d = {'nodes': node_dicts(3), 'edges': edge_dicts((0, 2), (1, 2))} cfg = KCFG.from_dict(d) # When - edges = set(cfg.getEdges()) + preds = set(cfg.edges(target_id=nid(2))) # Then - self.assertSetEqual(edges, {(0, 1), (0, 2), (1, 2)}) + self.assertSetEqual(preds, {edge(0, 2), edge(1, 2)}) - def test_transitive_closure(self): + def test_reachable_nodes(self): # Given - d = cfg_dict(states=[0, 1, 2, 3, 4, 5], edges=[(0, 1), (0, 5), (1, 2), (1, 3), (2, 4), (3, 4), (4, 1)]) + d = { + 'nodes': node_dicts(6), + 'edges': edge_dicts((0, 1), (0, 5), (1, 2), (1, 3), (2, 4), (3, 4), (4, 1)), + } cfg = KCFG.from_dict(d) # When - node_ids = set(cfg.transitiveClosureFromState(1)) + nodes = set(cfg.reachable_nodes(nid(1))) # Then - self.assertSetEqual(node_ids, {1, 2, 3, 4}) + self.assertSetEqual(nodes, {node(1), node(2), node(3), node(4)}) - 
def test_non_looping_paths_between_states(self): + def test_paths_between(self): # Given - d = cfg_dict(states=[0, 1, 2, 3], edges=[(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (3, 0)]) + d = { + 'nodes': node_dicts(4), + 'edges': edge_dicts((0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (3, 0)), + } cfg = KCFG.from_dict(d) # When - paths = set(tuple(path) for path in cfg.nonLoopingPathsBetweenStates(0, 3)) + paths = set(cfg.paths_between(nid(0), nid(3))) # Then - self.assertSetEqual(paths, {(0, 1, 3), (0, 1, 2, 3), (0, 2, 3)}) - - -def cfg_dict( - states: Iterable[Union[Hashable, Tuple[Hashable, KInner]]], - edges: Iterable[Union[Tuple[Hashable, Hashable], Tuple[Hashable, KInner, Hashable]]] = (), -) -> Dict: - cfg_states = {} - for state in states: - state_key = state[0] if type(state) is tuple else state - state_term = state[1] if type(state) is tuple else TOP - cfg_states[state_key] = {'__kcfg_type__': 'k', **state_term.to_dict()} - - cfg_edges: Dict = {state: {} for state in cfg_states} - for edge in edges: - edge_src = edge[0] - edge_term = cast(KInner, edge[1]) if len(edge) == 3 else TOP - edge_trg = edge[-1] - cfg_edges[edge_src][edge_trg] = {'depth': 1, 'condition': {'__kcfg_type__': 'k', **edge_term.to_dict()}, 'classes': [], 'priority': 50} - - return {'states': cfg_states, 'graph': cfg_edges} + self.assertSetEqual( + paths, + { + (edge(0, 1), edge(1, 3)), + (edge(0, 2), edge(2, 3)), + (edge(0, 1), edge(1, 2), edge(2, 3)), + }, + ) diff --git a/k-distribution/src/main/scripts/lib/pyk/pyk/tests/test_match.py b/k-distribution/src/main/scripts/lib/pyk/pyk/tests/test_match.py new file mode 100644 index 00000000000..ec6282d3d95 --- /dev/null +++ b/k-distribution/src/main/scripts/lib/pyk/pyk/tests/test_match.py @@ -0,0 +1,36 @@ +from functools import partial +from typing import Final, Tuple +from unittest import TestCase + +from ..kast import KApply, KInner, KVariable +from ..subst import match + +a, b, c = (KApply(label) for label in ['a', 'b', 'c']) +x, y, z = 
class MatchTest(TestCase):
    # Pairs (term, pattern): in each case `pattern` is expected to match
    # `term`, i.e. some substitution maps the pattern onto the term.
    TEST_DATA: Final[Tuple[Tuple[KInner, KInner], ...]] = (
        (a, a),
        (a, x),
        (f(a), x),
        (f(a), f(a)),
        (f(a), f(x)),
        (f(a, b), f(x, y)),
        (f(a, b, c), f(x, y, z)),
        (f(g(h(a))), f(x)),
        (f(g(h(x))), f(x)),
        (f(a, g(b, h(c))), f(x, y)),
    )

    def test_match_and_subst(self):
        # Property under test: a successful match yields a substitution
        # that, applied to the pattern, reproduces the original term.
        for i, (term, pattern) in enumerate(self.TEST_DATA):
            with self.subTest(i=i):
                # When
                subst = match(pattern, term)

                # Then
                self.assertIsNotNone(subst)
                self.assertEqual(subst(pattern), term)
({'y': x}, {'x': y}, {'y': x}), + ({'x': z}, {'x': y}, {'x': y}), + ({'y': z}, {'x': y}, {'x': z, 'y': z}), + ({'x': y}, {'x': f(x)}, {'x': f(y)}), + ({'x': f(x)}, {'x': f(x)}, {'x': f(f(x))}), + ({'y': f(z)}, {'x': f(y)}, {'x': f(f(z)), 'y': f(z)}), + ) + + for i, [subst1, subst2, expected] in enumerate(test_data): + with self.subTest(i=i): + # When + actual = dict(Subst(subst1) * Subst(subst2)) + + # Then + self.assertDictEqual(actual, expected) + + def test_apply(self): + # Given + test_data = ( + (a, {}, a), + (x, {}, x), + (a, {'x': b}, a), + (x, {'x': a}, a), + (f(x), {'x': f(x)}, f(f(x))), + (f(a, g(x, a)), {'x': b}, f(a, g(b, a))), + (f(g(h(x, y, z))), {'x': a, 'y': b, 'z': c}, f(g(h(a, b, c)))) + ) + + for i, [pattern, subst, expected] in enumerate(test_data): + with self.subTest(i=i): + # When + actual = Subst(subst)(pattern) + + # Then + self.assertEqual(actual, expected) diff --git a/k-distribution/src/main/scripts/lib/pyk/pyk/utils.py b/k-distribution/src/main/scripts/lib/pyk/pyk/utils.py index 0ea83099684..9fa6081ae72 100644 --- a/k-distribution/src/main/scripts/lib/pyk/pyk/utils.py +++ b/k-distribution/src/main/scripts/lib/pyk/pyk/utils.py @@ -1,28 +1,54 @@ import hashlib import string -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, TypeVar +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Tuple, + TypeVar, +) T = TypeVar('T') +K = TypeVar('K') +V = TypeVar('V') -def combine_dicts(*dicts: Mapping) -> Optional[Dict]: - if len(dicts) == 0: - return {} - if len(dicts) == 1: - return dict(dicts[0]) - dict1 = dicts[0] - dict2 = dicts[1] - restDicts = dicts[2:] - if dict1 is None or dict2 is None: - return None - intersecting_keys = set(dict1.keys()).intersection(set(dict2.keys())) - for key in intersecting_keys: - if dict1[key] != dict2[key]: - return None - newDict = {key: dict1[key] for key in dict1} - for key in dict2.keys(): - newDict[key] = dict2[key] - return combine_dicts(newDict, 
K = TypeVar('K')
V = TypeVar('V')


# Based on: https://stackoverflow.com/a/2704866
# Perhaps one day: https://peps.python.org/pep-0603/
class FrozenDict(Mapping[K, V]):
    """An immutable, hashable mapping.

    Accepts the same constructor arguments as ``dict``. Equality comes from
    the ``Mapping`` mixin; the hash is independent of insertion order and is
    cached after its first computation.
    """

    _dict: Dict[K, V]
    _hash: Optional[int]  # cached hash value; None until first __hash__ call

    def __init__(self, *args, **kwargs):
        self._dict = dict(*args, **kwargs)
        self._hash = None

    def __getitem__(self, key: K) -> V:
        return self._dict[key]

    def __iter__(self) -> Iterator[K]:
        return iter(self._dict)

    def __len__(self) -> int:
        return len(self._dict)

    def __hash__(self) -> int:
        cached = self._hash
        if cached is None:
            # XOR of the (key, value) pair hashes makes the result
            # order-independent, consistent with Mapping equality.
            cached = 0
            for item in self._dict.items():
                cached ^= hash(item)
            self._hash = cached
        return cached

    def __str__(self) -> str:
        return f'FrozenDict({str(self._dict)})'

    def __repr__(self) -> str:
        return f'FrozenDict({repr(self._dict)})'