diff --git a/parser/README.md b/parser/README.md index e3dbc83..b3bab3b 100644 --- a/parser/README.md +++ b/parser/README.md @@ -1,17 +1,27 @@ ### Extended Backus-Naur Form (EBNF) Grammar +An EBNF context-free grammar (CFG) representing a formal language for boolean expressions. + ``` -Expr ::= Var Expr' - | ! Expr - | ( Expr ) - -Expr' ::= & Expr - | \| Expr +Expr ::= Term Expr' + +Expr' ::= | Expr' + | & Expr' + | ^ Expr' | ε +Term ::= ( Expr ) + | ! Term + | Var + Var ::= [A-Z]+ + | t + | f ``` -### LL(1) Parsing Table +### LL(1) Parse Table +The parse table applying all of the above production rules in the EBNF grammar showing that it is LL(1). + +TODO: Update this table. | | [A-Z] | ( | ) | ! | & | \| | $ | |-------|:-----------------:|:-------------:|:-:|:---------------:|:----------------:|:---------------:|:---------:| @@ -19,10 +29,11 @@ Var ::= [A-Z]+ | Expr' | | | | | Expr' → & Expr | Expr' → \| Expr | Expr' → ϵ | | Var | Var → [A-Z]+ | | | | | | | -Var = Variable \ -Expr = Expression - ### First and Follow Function Table +A table showing the FIRST and FOLLOW sets of each non-terminal in the EBNF grammar along with if the non-terminal is NULLABLE. + +TODO: Update this table. + | | FIRST | FOLLOW | NULLABLE | |-------|:------------:|:------:| :-------: | | Expr | [A-Z]+, !, ( | $, ) | | diff --git a/parser/ast.py b/parser/ast.py index cc204ba..23229b4 100644 --- a/parser/ast.py +++ b/parser/ast.py @@ -2,93 +2,93 @@ Abstract Syntax Tree (AST) for boolean expressions. """ -from typing import TypeVar +from typing import TypeVar, Optional from parser.visitor import Visitor, ParamVisitor, RetVisitor, RetParamVisitor T = TypeVar("T") R = TypeVar("R") -class Var: +class Node: """ - A class to represent a variable from A-Z. - - Attributes - ---------- - name : str - The name of the variable. + An interface class to represent a node in the AST. """ - def __init__(self, name: str): - self.name = name + pass - def __str__(self) -> str: - return self.name - def accept(self, v: Visitor) -> None: - v.visitVar(self) +# ============================================================================= - def acceptParam(self, v: ParamVisitor, param: T) -> None: - v.visitVar(self, param) - def acceptRet(self, v: RetVisitor) -> R: - return v.visitVar(self) +class Expr(Node): + """ + An interface class to represent a full expression. + """ - def acceptParamRet(self, v: RetParamVisitor, param: T) -> R: - return v.visitVar(self, param) + pass -class Node: +class ExprPrime(Node): """ - An interface class to represent a node in the AST. + An interface class to represent an expr prime expression. Used to eliminate left-recursion. """ pass -class Expr(Node): +class Term(Node): """ - An interface class to represent a full expression. + An interface class to represent a term. Used to add precedence to the grammar. """ pass -class ExprPrime(Node): +class Var(Term): """ - An interface class to represent a prime expression. Used to eliminate left-recursion. + An interface class to represent a var. Used for constants and variables. """ - pass +# ============================================================================= -class AndExpr(ExprPrime): + +class TermExpr(Expr): """ - A class to represent an AND expression. + A class to represent an expression containing only a Term node. Attributes ---------- - first : Expr - The first part of the AND expression. + first : Term + The variable contained by the TermExpr. + + second : Optional[ExprPrime] """ - def __init__(self, first: Expr): + def __init__(self, first: Term, second: Optional[ExprPrime] = None): self.first = first + self.second = second def __str__(self) -> str: - return f"& {self.first}" + if self.second is not None: + return f"{self.first} {self.second}" + + return str(self.first) def accept(self, v: Visitor) -> None: - v.visitAndExpr(self) + v.visitTermExpr(self) - def acceptParam(self, v: ParamVisitor, param: T) -> None: - v.visitAndExpr(self, param) + def acceptParam(self, v: ParamVisitor[T], param: T) -> None: + v.visitTermExpr(self, param) - def acceptRet(self, v: RetVisitor) -> R: - return v.visitAndExpr(self) + def acceptRet(self, v: RetVisitor[R]) -> R: + return v.visitTermExpr(self) - def acceptParamRet(self, v: RetParamVisitor, param: T) -> R: - return v.visitAndExpr(self, param) + def acceptParamRet(self, v: RetParamVisitor[T, R], param: T) -> R: + return v.visitTermExpr(self, param) + + +# ============================================================================= class OrExpr(ExprPrime): @@ -97,102 +97,119 @@ class OrExpr(ExprPrime): Attributes ---------- - first : Expr + first : Term The first part of the OR expression. + + second : Optional[ExprPrime] + The second (or following) part of the OR expression. """ - def __init__(self, first: Expr): + def __init__(self, first: Term, second: Optional[ExprPrime] = None): self.first = first + self.second = second def __str__(self) -> str: + if self.second is not None: + return f"| {self.first} {self.second}" + return f"| {self.first}" def accept(self, v: Visitor) -> None: v.visitOrExpr(self) - def acceptParam(self, v: ParamVisitor, param: T) -> None: + def acceptParam(self, v: ParamVisitor[T], param: T) -> None: v.visitOrExpr(self, param) - def acceptRet(self, v: RetVisitor) -> R: + def acceptRet(self, v: RetVisitor[R]) -> R: return v.visitOrExpr(self) - def acceptParamRet(self, v: RetParamVisitor, param: T) -> R: + def acceptParamRet(self, v: RetParamVisitor[T, R], param: T) -> R: return v.visitOrExpr(self, param) -class VarExpr(Expr): +class AndExpr(ExprPrime): """ - A class to represent an expression containing only a Var node. + A class to represent an AND expression. Attributes ---------- - first : Var - The variable contained by the VarExpr. + first : Term + The first part of the AND expression. - second : ExprPrime, Optional - Optional attribute for containing the "AndExpr", "OrExpr", or "None" (for epsilon case). + second : Optional[ExprPrime] + The second (or following) part of the AND expression. """ - def __init__(self, first: Var, second: ExprPrime = None): + def __init__(self, first: Term, second: Optional[ExprPrime] = None): self.first = first self.second = second def __str__(self) -> str: - if not self.second: - return str(self.first) + if self.second is not None: + return f"& {self.first} {self.second}" - return f"{self.first} {self.second}" + return f"& {self.first}" def accept(self, v: Visitor) -> None: - v.visitVarExpr(self) + v.visitAndExpr(self) - def acceptParam(self, v: ParamVisitor, param: T) -> None: - v.visitVarExpr(self, param) + def acceptParam(self, v: ParamVisitor[T], param: T) -> None: + v.visitAndExpr(self, param) - def acceptRet(self, v: RetVisitor) -> R: - return v.visitVarExpr(self) + def acceptRet(self, v: RetVisitor[R]) -> R: + return v.visitAndExpr(self) - def acceptParamRet(self, v: RetParamVisitor, param: T) -> R: - return v.visitVarExpr(self, param) + def acceptParamRet(self, v: RetParamVisitor[T, R], param: T) -> R: + return v.visitAndExpr(self, param) -class NotExpr(Expr): +class XorExpr(ExprPrime): """ - A class to represent a NOT expression. + A class to represent an XOR expression. Attributes ---------- - first : Expr - The first part of the NOT expression. + first : Term + The first part of the XOR expression. + + second : Optional[ExprPrime] + The second (or following) part of the XOR expression. """ - def __init__(self, first: Expr): + def __init__(self, first: Term, second: Optional[ExprPrime] = None): self.first = first + self.second = second def __str__(self) -> str: - return f"!{self.first}" + if self.second is not None: + return f"^ {self.first} {self.second}" + + return f"^ {self.first}" def accept(self, v: Visitor) -> None: - v.visitNotExpr(self) + v.visitXorExpr(self) + + def acceptParam(self, v: ParamVisitor[T], param: T) -> None: + v.visitXorExpr(self, param) + + def acceptRet(self, v: RetVisitor[R]) -> R: + return v.visitXorExpr(self) - def acceptParam(self, v: ParamVisitor, param: T) -> None: - v.visitNotExpr(self, param) + def acceptParamRet(self, v: RetParamVisitor[T, R], param: T) -> R: + return v.visitXorExpr(self, param) - def acceptRet(self, v: RetVisitor) -> R: - return v.visitNotExpr(self) - def acceptParamRet(self, v: RetParamVisitor, param: T) -> R: - return v.visitNotExpr(self, param) +# ============================================================================= -class ParenExpr(Expr): +class ParenTerm(Term): """ - A class to represent a parenthesized expression. + A class to represent a parenthesized term. Attributes ---------- first : Expr - The first part of the parenthesized expression. + The first part of the parenthesized term. """ def __init__(self, first: Expr): @@ -202,13 +219,116 @@ def __str__(self) -> str: return f"({self.first})" def accept(self, v: Visitor) -> None: - v.visitParenExpr(self) + v.visitParenTerm(self) + + def acceptParam(self, v: ParamVisitor[T], param: T) -> None: + v.visitParenTerm(self, param) + + def acceptRet(self, v: RetVisitor[R]) -> R: + return v.visitParenTerm(self) + + def acceptParamRet(self, v: RetParamVisitor[T, R], param: T) -> R: + return v.visitParenTerm(self, param) + + +class NotTerm(Term): + """ + A class to represent a NOT term. + + Attributes + ---------- + first : Term + The first part of the NOT term. + """ + + def __init__(self, first: Term): + self.first = first + + def __str__(self) -> str: + return f"!{self.first}" + + def accept(self, v: Visitor) -> None: + v.visitNotTerm(self) + + def acceptParam(self, v: ParamVisitor[T], param: T) -> None: + v.visitNotTerm(self, param) + + def acceptRet(self, v: RetVisitor[R]) -> R: + return v.visitNotTerm(self) + + def acceptParamRet(self, v: RetParamVisitor[T, R], param: T) -> R: + return v.visitNotTerm(self, param) + + +# ============================================================================= + + +class VarVar(Term): + """ + A class to represent a variable from A-Z. + + Attributes + ---------- + name : str + The name of the variable. + """ + + def __init__(self, name: str) -> None: + self.name = name + + def __str__(self) -> str: + return self.name + + def accept(self, v: Visitor) -> None: + v.visitVarVar(self) + + def acceptParam(self, v: ParamVisitor[T], param: T) -> None: + v.visitVarVar(self, param) + + def acceptRet(self, v: RetVisitor[R]) -> R: + return v.visitVarVar(self) + + def acceptParamRet(self, v: RetParamVisitor[T, R], param: T) -> R: + return v.visitVarVar(self, param) + + +class TrueConst(Var): + """ + Represents the constant TRUE + """ + + def __str__(self) -> str: + return "t" + + def accept(self, v: Visitor) -> None: + v.visitTrueConst(self) + + def acceptParam(self, v: ParamVisitor[T], param: T) -> None: + v.visitTrueConst(self, param) + + def acceptRet(self, v: RetVisitor[R]) -> R: + return v.visitTrueConst(self) + + def acceptParamRet(self, v: RetParamVisitor[T, R], param: T) -> R: + return v.visitTrueConst(self, param) + + +class FalseConst(Var): + """ + Represents the constant FALSE. + """ + + def __str__(self) -> str: + return "f" + + def accept(self, v: Visitor) -> None: + v.visitFalseConst(self) - def acceptParam(self, v: ParamVisitor, param: T) -> None: - v.visitParenExpr(self, param) + def acceptParam(self, v: ParamVisitor[T], param: T) -> None: + v.visitFalseConst(self, param) - def acceptRet(self, v: RetVisitor) -> R: - return v.visitParenExpr(self) + def acceptRet(self, v: RetVisitor[R]) -> R: + return v.visitFalseConst(self) - def acceptParamRet(self, v: RetParamVisitor, param: T) -> R: - return v.visitParenExpr(self, param) + def acceptParamRet(self, v: RetParamVisitor[T, R], param: T) -> R: + return v.visitFalseConst(self, param) diff --git a/parser/lex.py b/parser/lex.py index 264b1c3..6713630 100644 --- a/parser/lex.py +++ b/parser/lex.py @@ -3,7 +3,7 @@ class Lexer: - terminals: str = r"\&|\||\!|\(|\)|[A-Z]+" + terminals: str = r"\&|\||\^|\!|\(|\)|[A-Z]+|t|f" ws: str = r"\s|\t|\n|\r" eof: str = r"\Z" diff --git a/parser/parse.py b/parser/parse.py index 485e497..7e4c9ad 100644 --- a/parser/parse.py +++ b/parser/parse.py @@ -1,12 +1,19 @@ +from typing import Optional + from parser.ast import ( - AndExpr, Expr, ExprPrime, - NotExpr, - OrExpr, - ParenExpr, + Term, Var, - VarExpr, + TermExpr, + OrExpr, + AndExpr, + XorExpr, + ParenTerm, + NotTerm, + VarVar, + TrueConst, + FalseConst, ) @@ -15,10 +22,14 @@ class Parser: - """Parser object for analyzing tokens for errors""" + """ + Parser object for building a boolean expression AST. + """ def __init__(self) -> None: - """Initializes the parser with attributes to be used""" + """ + Initializes the parser. + """ self.pos: int = -1 @staticmethod @@ -27,13 +38,15 @@ def error(msg: str, pos: int): exit(1) def parse(self, tokens: list[str]) -> Expr: - """Parses given tokens""" + """ + Parses given tokens. + """ self.tokens: list[str] = tokens # Initializes the first token self.advance() - rv = self.expr() + rv: Expr = self.expr() self.assert_end() return rv @@ -42,7 +55,6 @@ def assert_end(self) -> None: Parser.error(f"Expected end '' but found {self.next_token}", self.pos) def eat(self, expected: str) -> None: - """Skips a token""" if self.next_token == expected: self.advance() else: @@ -51,52 +63,110 @@ def eat(self, expected: str) -> None: ) def advance(self) -> None: - """Moves to the next token""" + """ + Moves to the next token. + """ self.pos += 1 self.next_token: str = self.tokens[self.pos] def expr(self) -> Expr: - """Parses an expression""" - if re.match("[A-Z]+", self.next_token): - first: Var = self.var() - second: ExprPrime | None = self.expr_prime() - if not second: - return VarExpr(first) + """ + Parses an expression. + """ + if ( + self.next_token == "(" + or self.next_token == "!" + or re.match("[A-Z]+|t|f", self.next_token) + ): + first: Term = self.term() + second: Optional[ExprPrime] = self.expr_prime() + + if second is None: + return TermExpr(first) else: - return VarExpr(first, second) - elif self.next_token == "!": - self.eat("!") - first: Expr = self.expr() - return NotExpr(first) - elif self.next_token == "(": - self.eat("(") - first: Expr = self.expr() - self.eat(")") - return ParenExpr(first) + return TermExpr(first, second) else: Parser.error( - f"Expected [var, !, (] but found '{self.next_token}'", self.pos + f"Expected (, !, [A-Z]+, t, or f but found '{self.next_token}'", + self.pos, ) - def expr_prime(self) -> ExprPrime | None: - """Parses an expression prime (explain what this is later)""" - if self.next_token == "&": - self.eat("&") - first: Expr = self.expr() - return AndExpr(first) - elif self.next_token == "|": + def expr_prime(self) -> Optional[ExprPrime]: + """ + Parses an expression prime. + """ + if self.next_token == "|": self.eat("|") - first: Expr = self.expr() - return OrExpr(first) - else: + + first: Term = self.term() + second: Optional[ExprPrime] = self.expr_prime() + + if second is None: + return OrExpr(first) + else: + return OrExpr(first, second) + elif self.next_token == "&": + self.eat("&") + + first: Term = self.term() + second: Optional[ExprPrime] = self.expr_prime() + + if second is None: + return AndExpr(first) + else: + return AndExpr(first, second) + if self.next_token == "^": + self.eat("^") + + first: Term = self.term() + second: Optional[ExprPrime] = self.expr_prime() + + if second is None: + return XorExpr(first) + else: + return XorExpr(first, second) + elif self.next_token == ")" or self.next_token == "": # Handles epsilon case return None + else: + Parser.error( + f"Expected |, &, ^, ), or but found '{self.next_token}'", self.pos + ) + + def term(self) -> Term: + """ + Parses a term. + """ + if self.next_token == "(": + self.eat("(") + first: Expr = self.expr() + self.eat(")") + return ParenTerm(first) + elif self.next_token == "!": + self.eat("!") + first: Term = self.term() + return NotTerm(first) + elif re.match("[A-Z]+|t|f", self.next_token): + return self.var() + else: + Parser.error( + f"Expected (, !, [A-Z]+, t, or f but found '{self.next_token}'", + self.pos, + ) def var(self) -> Var: """Parses a variable that represents a boolean expression""" - if not re.match("[A-Z]+", self.next_token): - Parser.error(f"Expected [A-Z]+ but found '{self.next_token}'", self.pos) + if re.match("[A-Z]+", self.next_token): + v: VarVar = VarVar(self.next_token) + self.eat(self.next_token) + return v + elif self.next_token == "t": + self.eat("t") + return TrueConst() + elif self.next_token == "f": + self.eat("f") + return FalseConst() else: - t: str = self.next_token - self.eat(t) - return Var(t) + Parser.error( + f"Expected [A-Z]+, t, or f but found '{self.next_token}'", self.pos + ) diff --git a/parser/visitor.py b/parser/visitor.py index 54b0b85..d4603a0 100644 --- a/parser/visitor.py +++ b/parser/visitor.py @@ -2,7 +2,20 @@ from abc import abstractmethod if TYPE_CHECKING: - from parser.ast import VarExpr, NotExpr, ParenExpr, AndExpr, OrExpr + from parser.ast import ( + TermExpr, + OrExpr, + AndExpr, + XorExpr, + ParenTerm, + NotTerm, + VarVar, + TrueConst, + FalseConst, + ) + +R = TypeVar("R") +T = TypeVar("T") class Visitor: @@ -10,28 +23,40 @@ class Visitor: A visitor that visits each node in the AST. """ - def visitVarExpr(self, vex: "VarExpr") -> None: - vex.first.accept(self) - if vex.second: - vex.second.accept(self) + def visitTermExpr(self, node: "TermExpr") -> None: + node.first.accept(self) + if node.second: + node.second.accept(self) + + def visitOrExpr(self, node: "OrExpr") -> None: + node.first.accept(self) + if node.second: + node.second.accept(self) - def visitNotExpr(self, nex: "NotExpr") -> None: - nex.first.accept(self) + def visitAndExpr(self, node: "AndExpr") -> None: + node.first.accept(self) + if node.second: + node.second.accept(self) - def visitParenExpr(self, pex: "ParenExpr") -> None: - pex.first.accept(self) + def visitXorExpr(self, node: "XorExpr") -> None: + node.first.accept(self) + if node.second: + node.second.accept(self) - def visitAndExpr(self, aex: "AndExpr") -> None: - aex.first.accept(self) + def visitParenTerm(self, node: "ParenTerm") -> None: + node.first.accept(self) - def visitOrExpr(self, oex: "OrExpr") -> None: - oex.first.accept(self) + def visitNotTerm(self, node: "NotTerm") -> None: + node.first.accept(self) - def visitVar(self, _) -> None: + def visitVarVar(self, node: "VarVar") -> None: pass + def visitTrueConst(self, node: "TrueConst") -> None: + pass -T = TypeVar("T") + def visitFalseConst(self, node: "FalseConst") -> None: + pass class ParamVisitor(Generic[T]): @@ -40,28 +65,40 @@ class ParamVisitor(Generic[T]): passes a parameter. """ - def visitVarExpr(self, vex: "VarExpr", param: T) -> None: - vex.first.accept(self, param) - if vex.second: - vex.second.accept(self, param) + def visitTermExpr(self, node: "TermExpr", param: T) -> None: + node.first.accept(self, param) + if node.second: + node.second.accept(self, param) + + def visitOrExpr(self, node: "OrExpr", param: T) -> None: + node.first.accept(self, param) + if node.second: + node.second.accept(self, param) - def visitNotExpr(self, nex: "NotExpr", param: T) -> None: - nex.first.accept(self, param) + def visitAndExpr(self, node: "AndExpr", param: T) -> None: + node.first.accept(self, param) + if node.second: + node.second.accept(self, param) - def visitParenExpr(self, pex: "ParenExpr", param: T) -> None: - pex.first.accept(self, param) + def visitXorExpr(self, node: "XorExpr", param: T) -> None: + node.first.accept(self, param) + if node.second: + node.second.accept(self, param) - def visitAndExpr(self, aex: "AndExpr", param: T) -> None: - aex.first.accept(self, param) + def visitParenTerm(self, node: "ParenTerm", param: T) -> None: + node.first.accept(self, param) - def visitOrExpr(self, oex: "OrExpr", param: T) -> None: - oex.first.accept(self, param) + def visitNotTerm(self, node: "NotTerm", param: T) -> None: + node.first.accept(self, param) - def visitVar(self, _, param: T) -> None: + def visitVarVar(self, node: "VarVar", param: T) -> None: pass + def visitTrueConst(self, node: "TrueConst", param: T) -> None: + pass -R = TypeVar("R") + def visitFalseConst(self, node: "FalseConst", param: T) -> None: + pass class RetVisitor(Generic[R]): @@ -71,56 +108,80 @@ class RetVisitor(Generic[R]): """ @abstractmethod - def visitVarExpr(self, vex: "VarExpr") -> R: + def visitTermExpr(self, node: "TermExpr") -> R: pass @abstractmethod - def visitNotExpr(self, nex: "NotExpr") -> R: + def visitOrExpr(self, node: "OrExpr") -> R: pass @abstractmethod - def visitParenExpr(self, pex: "ParenExpr") -> R: + def visitAndExpr(self, node: "AndExpr") -> R: pass @abstractmethod - def visitAndExpr(self, aex: "AndExpr") -> R: + def visitXorExpr(self, node: "XorExpr") -> R: pass @abstractmethod - def visitOrExpr(self, oex: "OrExpr") -> R: + def visitParenTerm(self, node: "ParenTerm") -> R: pass @abstractmethod - def visitVar(self, _) -> R: + def visitNotTerm(self, node: "NotTerm") -> R: pass + @abstractmethod + def visitVarVar(self, node: "VarVar") -> R: + pass -class RetParamVisitor(Generic[R, T]): + @abstractmethod + def visitTrueConst(self, node: "TrueConst") -> R: + pass + + @abstractmethod + def visitFalseConst(self, node: "FalseConst") -> R: + pass + + +class RetParamVisitor(Generic[T, R]): """ A visitor that visits each node in the AST and returns a value and passes a parameter. """ @abstractmethod - def visitVarExpr(self, vex: "VarExpr", param: T) -> R: + def visitTermExpr(self, node: "TermExpr", param: T) -> R: + pass + + @abstractmethod + def visitOrExpr(self, node: "OrExpr", param: T) -> R: + pass + + @abstractmethod + def visitAndExpr(self, node: "AndExpr", param: T) -> R: + pass + + @abstractmethod + def visitXorExpr(self, node: "XorExpr", param: T) -> R: pass @abstractmethod - def visitNotExpr(self, nex: "NotExpr", param: T) -> R: + def visitParenTerm(self, node: "ParenTerm", param: T) -> R: pass @abstractmethod - def visitParenExpr(self, pex: "ParenExpr", param: T) -> R: + def visitNotTerm(self, node: "NotTerm", param: T) -> R: pass @abstractmethod - def visitAndExpr(self, aex: "AndExpr", param: T) -> R: + def visitVarVar(self, node: "VarVar", param: T) -> R: pass @abstractmethod - def visitOrExpr(self, oex: "OrExpr", param: T) -> R: + def visitTrueConst(self, node: "TrueConst", param: T) -> R: pass @abstractmethod - def visitVar(self, _, param: T) -> R: + def visitFalseConst(self, node: "FalseConst", param: T) -> R: pass diff --git a/solver/passes/sympy_pass.py b/solver/passes/sympy_pass.py index a17aaaf..d8e3362 100644 --- a/solver/passes/sympy_pass.py +++ b/solver/passes/sympy_pass.py @@ -1,19 +1,23 @@ from typing import override + from parser.ast import ( - Var, Expr, - VarExpr, - NotExpr, - ParenExpr, - AndExpr, + TermExpr, OrExpr, + AndExpr, + XorExpr, + ParenTerm, + NotTerm, + VarVar, + TrueConst, + FalseConst, ) from parser.visitor import Visitor, RetVisitor from parser.parse import Parser from parser.lex import Lexer import sympy -from sympy.logic.boolalg import And, Or, Not +from sympy.logic.boolalg import Or, And, Xor, Not def run_pass(ast: Expr) -> Expr: @@ -39,15 +43,15 @@ def run_pass(ast: Expr) -> Expr: class SympyMappingVisitor(Visitor): """ - A visitor that visits each node in the AST and adds Var nodes to the symbolMap. + A visitor that visits each node in the AST and adds Var nodes to the symbols. """ def __init__(self) -> None: - self.symbolMap: dict[str, sympy.Symbol] = {} + self.symbols: dict[str, sympy.Symbol] = {} @override - def visitVar(self, va: Var) -> None: - self.symbolMap[va.name] = sympy.Symbol(va.name) + def visitVarVar(self, node: VarVar) -> None: + self.symbols[node.name] = sympy.Symbol(node.name) class TranslateToSympy(RetVisitor[sympy.Basic]): @@ -60,35 +64,76 @@ def __init__(self, symbols: dict[str, sympy.Symbol]) -> None: self.symbols = symbols @override - def visitVarExpr(self, vex: VarExpr) -> sympy.Basic: - first: sympy.Basic = vex.first.acceptRet(self) - if vex.second: - second: sympy.Basic = vex.second.first.acceptRet(self) - if isinstance(vex.second, AndExpr): + def visitTermExpr(self, node: TermExpr) -> sympy.Basic: + first: sympy.Basic = node.first.acceptRet(self) + if node.second: + second: sympy.Basic = node.second.acceptRet(self) + if isinstance(node.second, OrExpr): + return Or(first, second) + elif isinstance(node.second, AndExpr): return And(first, second) - elif isinstance(vex.second, OrExpr): + elif isinstance(node.second, XorExpr): + return Xor(first, second) + return first + + @override + def visitOrExpr(self, node: OrExpr) -> sympy.Basic: + first: sympy.Basic = node.first.acceptRet(self) + if node.second: + second: sympy.Basic = node.second.acceptRet(self) + if isinstance(node.second, OrExpr): return Or(first, second) + elif isinstance(node.second, AndExpr): + return And(first, second) + elif isinstance(node.second, XorExpr): + return Xor(first, second) return first @override - def visitNotExpr(self, nex: NotExpr) -> sympy.Basic: - return Not(nex.first.acceptRet(self)) + def visitAndExpr(self, node: AndExpr) -> sympy.Basic: + first: sympy.Basic = node.first.acceptRet(self) + if node.second: + second: sympy.Basic = node.second.acceptRet(self) + if isinstance(node.second, OrExpr): + return Or(first, second) + elif isinstance(node.second, AndExpr): + return And(first, second) + elif isinstance(node.second, XorExpr): + return Xor(first, second) + return first @override - def visitParenExpr(self, pex: ParenExpr) -> sympy.Basic: - return pex.first.acceptRet(self) + def visitXorExpr(self, node: XorExpr) -> sympy.Basic: + first: sympy.Basic = node.first.acceptRet(self) + if node.second: + second: sympy.Basic = node.second.acceptRet(self) + if isinstance(node.second, OrExpr): + return Or(first, second) + elif isinstance(node.second, AndExpr): + return And(first, second) + elif isinstance(node.second, XorExpr): + return Xor(first, second) + return first @override - def visitAndExpr(self, aex: AndExpr) -> sympy.Basic: - pass + def visitParenTerm(self, node: ParenTerm) -> sympy.Basic: + return node.first.acceptRet(self) @override - def visitOrExpr(self, oex: OrExpr) -> sympy.Basic: - pass + def visitNotTerm(self, node: NotTerm) -> sympy.Basic: + return Not(node.first.acceptRet(self)) @override - def visitVar(self, va: Var) -> sympy.Basic: - return self.symbols[va.name] + def visitVarVar(self, node: VarVar) -> sympy.Basic: + return self.symbols[node.name] + + @override + def visitTrueConst(self, node: TrueConst) -> sympy.Basic: + return sympy.true + + @override + def visitFalseConst(self, node: FalseConst) -> sympy.Basic: + return sympy.false if __name__ == "__main__": @@ -99,4 +144,6 @@ def visitVar(self, va: Var) -> sympy.Basic: p: Parser = Parser() ast: Expr = p.parse(l.tokens) - run_pass(ast) + simplified_ast: Expr = run_pass(ast) + + assert str(simplified_ast) == "B" diff --git a/solver/passes/z3_pass.py b/solver/passes/z3_pass.py index c55d551..d3243f8 100644 --- a/solver/passes/z3_pass.py +++ b/solver/passes/z3_pass.py @@ -2,10 +2,21 @@ import html, re -from parser.ast import AndExpr, OrExpr, Expr, NotExpr, ParenExpr, Var, VarExpr +from parser.ast import ( + Expr, + TermExpr, + OrExpr, + AndExpr, + XorExpr, + ParenTerm, + NotTerm, + VarVar, + TrueConst, + FalseConst, +) from parser.visitor import Visitor, RetVisitor -from parser.lex import Lexer from parser.parse import Parser +from parser.lex import Lexer import z3 @@ -55,45 +66,86 @@ def __init__(self) -> None: self.symbols: dict[str, z3.Bool] = {} @override - def visitVar(self, va: Var): - self.symbols[va.name] = z3.Bool(va.name) + def visitVarVar(self, node: VarVar) -> None: + self.symbols[node.name] = z3.Bool(node.name) class TranslateToZ3(RetVisitor[z3.ExprRef]): def __init__(self, symbols: dict[str, z3.Bool]) -> None: - self.symbols = symbols + self.symbols: dict[str, z3.Bool] = symbols @override - def visitVarExpr(self, vex: VarExpr) -> z3.ExprRef: - first: z3.ExprRef = vex.first.acceptRet(self) - if vex.second: - second: z3.ExprRef = vex.second.first.acceptRet(self) - if isinstance(vex.second, AndExpr): + def visitTermExpr(self, node: TermExpr) -> z3.ExprRef: + first: z3.ExprRef = node.first.acceptRet(self) + if node.second: + second: z3.ExprRef = node.second.acceptRet(self) + if isinstance(node.second, OrExpr): + return z3.Or(first, second) + elif isinstance(node.second, AndExpr): return z3.And(first, second) - elif isinstance(vex.second, OrExpr): + elif isinstance(node.second, XorExpr): + return z3.Xor(first, second) + return first + + @override + def visitOrExpr(self, node: OrExpr) -> z3.ExprRef: + first: z3.ExprRef = node.first.acceptRet(self) + if node.second: + second: z3.ExprRef = node.second.acceptRet(self) + if isinstance(node.second, OrExpr): return z3.Or(first, second) + elif isinstance(node.second, AndExpr): + return z3.And(first, second) + elif isinstance(node.second, XorExpr): + return z3.Xor(first, second) + return first + + @override + def visitAndExpr(self, node: AndExpr) -> z3.ExprRef: + first: z3.ExprRef = node.first.acceptRet(self) + if node.second: + second: z3.ExprRef = node.second.acceptRet(self) + if isinstance(node.second, OrExpr): + return z3.Or(first, second) + elif isinstance(node.second, AndExpr): + return z3.And(first, second) + elif isinstance(node.second, XorExpr): + return z3.Xor(first, second) + return first + + @override + def visitXorExpr(self, node: XorExpr) -> z3.ExprRef: + first: z3.ExprRef = node.first.acceptRet(self) + if node.second: + second: z3.ExprRef = node.second.acceptRet(self) + if isinstance(node.second, OrExpr): + return z3.Or(first, second) + elif isinstance(node.second, AndExpr): + return z3.And(first, second) + elif isinstance(node.second, XorExpr): + return z3.Xor(first, second) return first @override - def visitNotExpr(self, nex: NotExpr) -> z3.ExprRef: - return z3.Not(nex.first.acceptRet(self)) + def visitParenTerm(self, node: ParenTerm) -> z3.ExprRef: + return node.first.acceptRet(self) @override - def visitParenExpr(self, pex: ParenExpr) -> z3.ExprRef: - return pex.first.acceptRet(self) + def visitNotTerm(self, node: NotTerm) -> z3.ExprRef: + return z3.Not(node.first.acceptRet(self)) @override - def visitAndExpr(self) -> None: - pass + def visitVarVar(self, node: VarVar) -> z3.ExprRef: + return self.symbols[node.name] @override - def visitOrExpr(self) -> None: - pass + def visitTrueConst(self, node: TrueConst) -> z3.ExprRef: + return z3.BoolVal(True) @override - def visitVar(self, va: Var) -> z3.ExprRef: - return self.symbols[va.name] + def visitFalseConst(self, node: FalseConst) -> z3.ExprRef: + return z3.BoolVal(False) if __name__ == "__main__": diff --git a/tests/test_lex.py b/tests/test_lex.py index 81bf255..4e64680 100644 --- a/tests/test_lex.py +++ b/tests/test_lex.py @@ -3,20 +3,35 @@ class TestLexer(unittest.TestCase): + def test_or(self) -> None: + prog = "A|B" + l: Lexer = Lexer() + l.lex(prog) + self.assertEqual(l.getTokens(), ["A", "|", "B", ""]) + + def test_and(self) -> None: + prog = "A & C" + l: Lexer = Lexer() + l.lex(prog) + self.assertEqual(l.getTokens(), ["A", "&", "C", ""]) + + def test_xor(self) -> None: + prog = "A ^ B" + l: Lexer = Lexer() + l.lex(prog) + self.assertEqual(l.getTokens(), ["A", "^", "B", ""]) + + def test_not(self) -> None: + prog = "! A" + l: Lexer = Lexer() + l.lex(prog) + self.assertEqual(l.getTokens(), ["!", "A", ""]) + def test_lex(self) -> None: - prog: str = "(A & B) | !C" + prog: str = "(A & B) | !C & t ^ f" l: Lexer = Lexer() l.lex(prog) self.assertEqual( - l.getTokens(), ["(", "A", "&", "B", ")", "|", "!", "C", ""] + l.getTokens(), + ["(", "A", "&", "B", ")", "|", "!", "C", "&", "t", "^", "f", ""], ) - - prog = "A|B" - l2: Lexer = Lexer() - l2.lex(prog) - self.assertEqual(l2.getTokens(), ["A", "|", "B", ""]) - - prog = "!(A & C)" - l3: Lexer = Lexer() - l3.lex(prog) - self.assertEqual(l3.getTokens(), ["!", "(", "A", "&", "C", ")", ""]) diff --git a/tests/test_parse.py b/tests/test_parse.py index 7dfec91..fd97268 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -1,123 +1,64 @@ -from typing import override from unittest import TestCase -from parser.ast import NotExpr, OrExpr, ParenExpr, VarExpr, AndExpr +from parser.ast import Expr from parser.parse import Parser -from parser.visitor import Visitor, RetParamVisitor class TestParse(TestCase): - def test_parse(self) -> None: + def test_or(self) -> None: + tokens: list[str] = ["A", "|", "B", ""] p: Parser = Parser() - tree = p.parse(["!", "(", "A", "&", "!", "B", "|", "C", ")", ""]) + ast: Expr = p.parse(tokens) self.assertEqual( - str(tree), - "!(A & !B | C)", + str(ast), + "A | B", ) - def test_visitor(self) -> None: - class CountVisitor(Visitor): - count: int = 0 - - def __init__(self): - self.visited: list[str] = [] - - @override - def visitVarExpr(self, vex: VarExpr) -> None: - CountVisitor.count += 1 - self.visited.append("VarExpr") - vex.first.accept(self) - if vex.second: - vex.second.accept(self) - - @override - def visitNotExpr(self, nex: NotExpr) -> None: - CountVisitor.count += 1 - self.visited.append("NotExpr") - nex.first.accept(self) - - @override - def visitParenExpr(self, pex: ParenExpr) -> None: - CountVisitor.count += 1 - self.visited.append("ParenExpr") - pex.first.accept(self) - - @override - def visitAndExpr(self, aex: AndExpr) -> None: - CountVisitor.count += 1 - self.visited.append("AndExpr") - aex.first.accept(self) - - @override - def visitOrExpr(self, oex: OrExpr) -> None: - CountVisitor.count += 1 - self.visited.append("OrExpr") - oex.first.accept(self) - - @override - def visitVar(self, _) -> None: - CountVisitor.count += 1 - self.visited.append("Var") - + def test_and(self) -> None: + tokens: list[str] = ["A", "&", "B", ""] p: Parser = Parser() - tree = p.parse(["!", "(", "A", "&", "!", "B", "|", "C", ")", ""]) - visitor: CountVisitor = CountVisitor() - tree.accept(visitor) - - self.assertEqual(visitor.count, 11) + ast: Expr = p.parse(tokens) self.assertEqual( - visitor.visited, - [ - "NotExpr", - "ParenExpr", - "VarExpr", - "Var", - "AndExpr", - "NotExpr", - "VarExpr", - "Var", - "OrExpr", - "VarExpr", - "Var", - ], + str(ast), + "A & B", ) - def test_ret_visitor(self) -> None: - class CountVisitor(RetParamVisitor[int, int]): - def __init__(self): - self.visited: list[str] = [] - - @override - def visitVarExpr(self, vex: VarExpr, param: int) -> int: - tmp = vex.first.acceptParamRet(self, param) + 1 - if vex.second: - return vex.second.acceptParamRet(self, tmp) - else: - return tmp - - @override - def visitNotExpr(self, nex: NotExpr, param: int) -> int: - return nex.first.acceptParamRet(self, param) + 1 - - @override - def visitParenExpr(self, pex: ParenExpr, param: int) -> int: - return pex.first.acceptParamRet(self, param) + 1 - - @override - def visitAndExpr(self, aex: AndExpr, param: int) -> int: - return aex.first.acceptParamRet(self, param) + 1 - - @override - def visitOrExpr(self, oex: OrExpr, param: int) -> int: - return oex.first.acceptParamRet(self, param) + 1 - - @override - def visitVar(self, _, param: int) -> int: - return param + 1 + def test_xor(self) -> None: + tokens: list[str] = ["A", "^", "B", ""] + p: Parser = Parser() + ast: Expr = p.parse(tokens) + self.assertEqual( + str(ast), + "A ^ B", + ) + def test_not(self) -> None: + tokens: list[str] = ["!", "A", ""] p: Parser = Parser() - tree = p.parse(["!", "(", "A", "&", "!", "B", "|", "C", ")", ""]) - visitor: CountVisitor = CountVisitor() - count = tree.acceptParamRet(visitor, 0) + ast: Expr = p.parse(tokens) + self.assertEqual( + str(ast), + "!A", + ) - self.assertEqual(count, 11) + def test_parse(self) -> None: + tokens: list[str] = [ + "!", + "(", + "A", + "&", + "!", + "B", + "|", + "f", + ")", + "^", + "t", + "", + ] + p: Parser = Parser() + ast: Expr = p.parse(tokens) + self.assertEqual( + str(ast), + "!(A & !B | f) ^ t", + ) diff --git a/utils/metrics.py b/utils/metrics.py index 31eecc8..1341405 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -1,28 +1,40 @@ from typing import override -from parser.ast import VarExpr, NotExpr, ParenExpr, AndExpr, OrExpr +from parser.ast import OrExpr, AndExpr, XorExpr, NotTerm from parser.visitor import Visitor class OpCounter(Visitor): """Counts the number of boolean operators visited""" - _count: int = 0 + def __init__(self) -> None: + self._count: int = 0 + + def getCount(self) -> int: + return self._count @override - def visitNotExpr(self, nex: NotExpr) -> None: - OpCounter._count += 1 - nex.first.accept(self) + def visitOrExpr(self, node: OrExpr) -> None: + self._count += 1 + node.first.accept(self) + if node.second: + node.second.accept(self) @override - def visitAndExpr(self, aex: AndExpr) -> None: - OpCounter._count += 1 - aex.first.accept(self) + def visitAndExpr(self, node: AndExpr) -> None: + self._count += 1 + node.first.accept(self) + if node.second: + node.second.accept(self) @override - def visitOrExpr(self, oex: OrExpr) -> None: - OpCounter._count += 1 - oex.first.accept(self) + def visitXorExpr(self, node: XorExpr) -> None: + self._count += 1 + node.first.accept(self) + if node.second: + node.second.accept(self) - def getCount(self) -> int: - return self._count + @override + def visitNotTerm(self, node: NotTerm) -> None: + self._count += 1 + node.first.accept(self)