Skip to content

Commit

Permalink
Add pairwise helper function (#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
eliotwrobson authored Sep 8, 2023
1 parent 4d99436 commit 274d142
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 11 deletions.
13 changes: 12 additions & 1 deletion automata/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""Miscellaneous utility functions and classes."""

from collections import defaultdict
from itertools import count
from itertools import count, tee, zip_longest
from typing import Any, Callable, Dict, Generic, Iterable, List, Set, Tuple, TypeVar

from frozendict import frozendict
Expand Down Expand Up @@ -106,3 +106,14 @@ def refine(self, S: Iterable[T]) -> List[Tuple[int, int]]:
output.append((id(AintS), Aid))

return output


def pairwise(iterable: Iterable[T], final_none: bool = False) -> Iterable[Tuple[T, T]]:
"""Based on https://docs.python.org/3/library/itertools.html#itertools.pairwise"""
a, b = tee(iterable)
next(b, None)

if final_none:
return zip_longest(a, b)

return zip(a, b)
11 changes: 7 additions & 4 deletions automata/fa/dfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import automata.base.exceptions as exceptions
import automata.fa.fa as fa
import automata.fa.nfa as nfa
from automata.base.utils import PartitionRefinement, get_renaming_function
from automata.base.utils import PartitionRefinement, get_renaming_function, pairwise

DFAStateT = fa.FAStateT

Expand Down Expand Up @@ -1072,8 +1072,7 @@ def successors(
include_input = not strict
sorted_symbols = sorted(self.input_symbols, reverse=reverse, key=key)
symbol_succ: Dict[str, Optional[str]] = {
symbol_a: symbol_b
for symbol_a, symbol_b in zip(sorted_symbols, sorted_symbols[1:])
symbol_a: symbol_b for symbol_a, symbol_b in pairwise(sorted_symbols)
}
symbol_succ[sorted_symbols[-1]] = None
# Special case for None
Expand Down Expand Up @@ -1757,7 +1756,11 @@ def _get_input_path(
"""

state_history = list(self.read_input_stepwise(input_str, ignore_rejection=True))
path = list(zip(state_history, state_history[1:], input_str))

path = [
(*state_pair, char)
for state_pair, char in zip(pairwise(state_history), input_str)
]

last_state = state_history[-1] if state_history else self.initial_state
accepted = last_state in self.final_states
Expand Down
6 changes: 3 additions & 3 deletions automata/regex/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import copy
import re
from collections import deque
from itertools import chain, count, product, repeat, zip_longest
from itertools import chain, count, product, repeat
from typing import AbstractSet, Deque, Dict, Iterable, List, Optional, Set, Tuple, Type

from typing_extensions import NoReturn, Self

import automata.base.exceptions as exceptions
from automata.base.utils import get_renaming_function
from automata.base.utils import get_renaming_function, pairwise
from automata.regex.lexer import Lexer, Token
from automata.regex.postfix import (
InfixOperator,
Expand Down Expand Up @@ -515,7 +515,7 @@ def add_concat_and_empty_string_tokens(
# Pairs of tokens to insert empty string literals between
empty_string_pairs = [(LeftParen, RightParen)]

for curr_token, next_token in zip_longest(token_list, token_list[1:]):
for curr_token, next_token in pairwise(token_list, True):
final_token_list.append(curr_token)

if next_token is not None:
Expand Down
6 changes: 3 additions & 3 deletions automata/regex/postfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import abc
from collections import deque
from itertools import zip_longest
from typing import Deque, List, Optional, Tuple, TypeVar, cast
from itertools import chain, zip_longest
from typing import Deque, Iterable, List, Optional, Tuple, TypeVar, cast

import automata.base.exceptions as exceptions
from automata.regex.lexer import Token
Expand Down Expand Up @@ -75,7 +75,7 @@ def __repr__(self) -> str:
def validate_tokens(token_list: List[Token]) -> None:
"""Validate the inputted tokens list (in infix ordering)."""

token_list_prev: List[Optional[Token]] = [None] + token_list
token_list_prev: Iterable[Optional[Token]] = chain([None], token_list)

paren_counter = 0

Expand Down

0 comments on commit 274d142

Please sign in to comment.