# Extracted from a git patch; the patch metadata is preserved here for reference.
#
#   From 969887ee3b626380921fa0cb8f6360cb11ff3ed9 Mon Sep 17 00:00:00 2001
#   From: Andrew Lapp
#   Date: Mon, 14 Oct 2024 13:03:26 -0400
#   Subject: [PATCH] test fsm_union and walk_fsm
#
#   tests/fsm/test_parsing.py | 101 ++++++++++++++++++++++++++++++++++++++
#   1 file changed, 101 insertions(+)
#
#   diff --git a/tests/fsm/test_parsing.py b/tests/fsm/test_parsing.py
#   index 3f4c1ba42..b7446fa0c 100644
#   --- a/tests/fsm/test_parsing.py
#   +++ b/tests/fsm/test_parsing.py
#   @@ -204,3 +204,104 @@ def test_sequential_parse_example(cleanup_lark_import):
#
# Hunk context (tail of the preceding test, not modified by the patch):
#
#       if i + 1 == len(input_tokens):
#           assert all(tk in next_vocab for tk in ["\n", "\nde", " ", " + 1"])
#
# Everything below is the code the patch appends to tests/fsm/test_parsing.py.


# TODO: Remove once fsm_union and walk_fsm are implemented in Outlines-Core
import interegular  # noqa

from outlines.fsm.parsing import fsm_union, walk_fsm  # noqa


def test_outlines_interegular_union_consistency():
    """Check that ``fsm_union`` accepts the same strings as interegular's ``|``.

    Three FSMs over disjoint literals are unioned both ways and the strings
    enumerated by each resulting FSM are compared as ordered lists.
    """
    fsm0 = interegular.parse_pattern(r"abc").to_fsm()
    fsm1 = interegular.parse_pattern(r"WXYZ").to_fsm()
    fsm2 = interegular.parse_pattern(r"12345").to_fsm()

    interegular_unioned_fsm = fsm0 | fsm1 | fsm2
    # The second element returned by fsm_union is not needed for this check.
    outlines_unioned_fsm, _ = fsm_union([fsm0, fsm1, fsm2])

    assert list(outlines_unioned_fsm.strings()) == list(
        interegular_unioned_fsm.strings()
    )


def _reconstruct_fsms(fsm, fsms_to_trans_finals):
    """Reconstruct the original fsms for testing purposes.

    Parameters
    ----------
    fsm
        The unioned FSM (first element returned by ``fsm_union``). Its
        ``initial``, ``map`` and ``alphabet.by_transition`` attributes are read.
    fsms_to_trans_finals
        Mapping whose values are ``(transitions, finals, state_map)`` triples:
        ``transitions`` is an iterable of ``(from_state, to_state)`` pairs and
        ``finals`` an iterable of states, both expressed in the unioned FSM's
        state ids; ``state_map`` maps each original state to the unioned
        states it corresponds to.

    Returns
    -------
    list
        One ``interegular.fsm.FSM`` per entry of ``fsms_to_trans_finals``, in
        the mapping's iteration order.
    """
    reconstructed_fsms = []
    for transitions, finals, state_map in fsms_to_trans_finals.values():
        # Invert state_map: unioned state id -> original state id.
        inv_state_map = {new: orig for orig, news in state_map.items() for new in news}
        states = set(inv_state_map.values())

        # A single lookup is sufficient: inv_state_map already contains every
        # unioned state mentioned in state_map. Guarding it with ``or`` would
        # treat a falsy original state id (e.g. 0) as "missing" and trigger a
        # redundant second scan of the same mapping.
        initial = inv_state_map.get(fsm.initial)
        orig_finals = {inv_state_map[s] for s in finals}

        transition_map = {}
        alphabet = {}
        for trans_id, (from_state, to_state) in enumerate(transitions):
            orig_from, orig_to = inv_state_map[from_state], inv_state_map[to_state]
            # Collect symbols associated with the transition
            symbols = {
                symbol
                for trans, dest in fsm.map.get(from_state, {}).items()
                if dest == to_state
                for symbol in fsm.alphabet.by_transition.get(trans, [])
            }
            if symbols:
                # NOTE: this reconstructor does not work for more than one
                # transition per symbol.
                assert len(symbols) == 1, (
                    "reconstruction supports exactly one symbol per transition, "
                    f"got {sorted(symbols, key=repr)!r}"
                )
                symbol = next(iter(symbols))
                # Each (from, to) pair gets its own transition id, reused as
                # the alphabet entry for its single symbol.
                alphabet[symbol] = trans_id
                transition_map.setdefault(orig_from, {})[trans_id] = orig_to

        reconstructed_fsms.append(
            interegular.fsm.FSM(
                alphabet=interegular.fsm.Alphabet(alphabet),
                states=frozenset(states),
                initial=initial,
                finals=frozenset(orig_finals),
                map=transition_map,
                __no_validation__=True,
            )
        )
    return reconstructed_fsms


def test_fsm_to_trans_finals_reconstruction():
    """Assert that _fsms_to_trans_finals is correct by reconstructing original fsms"""
    originals = [
        interegular.parse_pattern(r"abc").to_fsm(),
        interegular.parse_pattern(r"XYZ").to_fsm(),
        interegular.parse_pattern(r"12345").to_fsm(),
    ]

    fsm, _fsms_to_trans_finals = fsm_union(originals)

    reconstructed = _reconstruct_fsms(fsm, _fsms_to_trans_finals)

    # Guard the pairing below: zip() would silently ignore a missing FSM.
    assert len(reconstructed) == len(originals)

    # assert reconstruction equivalent
    for original, rebuilt in zip(originals, reconstructed):
        assert list(original.strings()) == list(rebuilt.strings())


def test_walk_fsm():
    """Exercise ``walk_fsm`` with ``full_match=True`` on the pattern ``abc*d``.

    Asserts that a fully matching input yields one state per transition and
    ends in a final state, and that both a non-matching input and an input
    that stops before a final state yield an empty list.
    """
    fsm = interegular.parse_pattern(r"abc*d").to_fsm()
    # convert to BetterFSM
    fsm = fsm_union([fsm])[0]

    def to_transitions(text):
        # Translate each character into the FSM's transition key.
        return [fsm.alphabet[letter] for letter in text]

    # if match, produce equivalent number of states, assert state can terminate
    transitions = to_transitions("abcccd")
    accepted_states = walk_fsm(fsm, transitions, fsm.initial, full_match=True)
    assert len(accepted_states) == len(transitions)
    assert accepted_states[-1] in fsm.finals

    # if no match, assert empty
    accepted_states = walk_fsm(
        fsm, to_transitions("b"), fsm.initial, full_match=True
    )
    assert accepted_states == []

    # if full_match, but last state not present, assert empty
    accepted_states = walk_fsm(
        fsm, to_transitions("abc"), fsm.initial, full_match=True
    )
    assert accepted_states == []