Remove broken final state loop #874
A related matter: https://github.com/outlines-dev/outlines/blob/4f8433d8d6633b0780c3a6c27981f9adffbe49f5/outlines/generate/generator.py#L94
This code is fundamentally broken, in my opinion, because it always stops generation when a final state is reached, regardless of outgoing transitions it may have. Instead, the condition for stopping should be that a stop-token has been generated. Right?
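The difference between the two stopping rules can be illustrated with a toy state machine. This is only a sketch, not Outlines code: it works on characters instead of token ids, the state numbers and the `prefer_continuing` stand-in for a model are made up, and the regex `(12){1,2}` is borrowed from the plotting example further down the thread.

```python
# Toy character-level FSM for the regex (12){1,2}; state numbers are hypothetical.
TRANSITIONS = {
    0: {"1": 1},
    1: {"2": 2},
    2: {"1": 3},  # final state 2 still has an outgoing transition
    3: {"2": 4},
}
FINAL_STATES = {2, 4}
EOS = "<eos>"


def allowed(state):
    """Tokens permitted in `state`; EOS is only permitted in final states."""
    tokens = list(TRANSITIONS.get(state, {}))
    if state in FINAL_STATES:
        tokens.append(EOS)
    return tokens


def generate(choose, stop_at_final):
    """Run the FSM, letting `choose` pick one of the allowed tokens."""
    state, output = 0, ""
    while True:
        token = choose(allowed(state))
        if token == EOS:
            return output  # stop because a stop token was generated
        output += token
        state = TRANSITIONS[state][token]
        if stop_at_final and state in FINAL_STATES:
            return output  # stop merely because a final state was reached


def prefer_continuing(tokens):
    """Stand-in for a model that emits EOS only when nothing else is allowed."""
    non_eos = [t for t in tokens if t != EOS]
    return non_eos[0] if non_eos else EOS


print(generate(prefer_continuing, stop_at_final=True))   # 12
print(generate(prefer_continuing, stop_at_final=False))  # 1212
```

With `stop_at_final=True` the string `1212` can never be produced, even though the regex accepts it.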
Is there a minimal reproducing example we could add as a test?
The issue does not show in the |
There are some test errors. I believe there is a condition still to be checked if the state does not exist in the transitions table. I’ll invest some time later today. @rlouf I didn’t run the tests, because I didn’t know how.
I believe this should do it now.
Okay, obviously not. I will invest some time into this today and hopefully come to a solution.
I've looked into the breaking tests:
I believe these tests are testing an assumption that is fundamentally wrong. Final states can have outbound transitions, including into non-terminal states.
I'm wondering if the right thing to do would be to remove these tests, or what we would want to test instead.
I have changed the tests to verify that if we are in a final state with no outbound transitions, a new generation will lead to us staying in a final state.
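A test of that shape might look roughly like the sketch below. It is not the test from this PR: the token ids, the `-1` sentinel for "generation is over", the tiny transition table and the standalone `get_next_state` function are all assumptions made for illustration.

```python
EOS_TOKEN_ID = 0  # hypothetical token id
FINAL_STATE = -1  # hypothetical sentinel meaning "generation is over"

# State 2 is final and has no entry, i.e. no outbound transitions.
STATES_TO_TOKEN_MAPS = {0: {5: 1}, 1: {6: 2}}


def get_next_state(state, token_id):
    """Simplified stand-in for the guide's transition function."""
    if token_id == EOS_TOKEN_ID or state not in STATES_TO_TOKEN_MAPS:
        return FINAL_STATE
    return STATES_TO_TOKEN_MAPS[state].get(token_id, FINAL_STATE)


def test_final_state_without_outbound_transitions_is_absorbing():
    # Walk 0 -> 1 -> 2 to reach the final state that has no outbound transitions.
    state = get_next_state(get_next_state(0, 5), 6)
    assert state == 2
    # Whatever token comes next, we end up in (and stay in) the final state.
    for token_id in (5, 6, EOS_TOKEN_ID):
        assert get_next_state(state, token_id) == FINAL_STATE
    assert get_next_state(FINAL_STATE, 5) == FINAL_STATE
```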
Basically the same as #884
Should we leave `test_fsm.py` as is? Otherwise looks good.
Do you mean reverting the changes in |
Thanks, looks good to me!
32cc20d to 21247f8
Damn! I was facing this issue on Fri and spent a couple of days to finally figure out the solution only to find that this PR existed :)
@br3no can you pls share how did you generate the FSM plot here? #874 (comment)
@ekagra-ranjan sure. I used graphviz for that. Here's an example:

    import outlines
    from transformers import AutoTokenizer
    from graphviz import Digraph


    def draw_state_machine(graph: dict, final_states: set, tokenizer):
        dot = Digraph()

        # Add nodes
        for state in graph:
            if state in final_states:
                dot.node(str(state), str(state), color='salmon', style='filled', fillcolor='salmon')
            else:
                dot.node(str(state), str(state), color='lightblue', style='filled', fillcolor='lightblue')

        # Prepare edge labels by aggregating transitions between the same nodes
        edge_labels = {}
        for state, transitions in graph.items():
            for transition, end_state in transitions.items():
                if end_state not in graph:
                    # Add end states not in the state map
                    dot.node(str(end_state), str(end_state), color='salmon', style='filled', fillcolor='salmon')
                label = tokenizer.decode(int(transition))
                edge_key = (str(state), str(end_state))
                if edge_key not in edge_labels:
                    edge_labels[edge_key] = label
                else:
                    # Append new label to existing label, separated by a comma
                    edge_labels[edge_key] += ", " + label

        # Add edges with aggregated labels
        for (start_state, end_state), label in edge_labels.items():
            dot.edge(start_state, end_state, label=label)

        # Render and view the graph
        dot.render('state_machine', view=True, format='svg')


    tokenizer_zephyr = AutoTokenizer.from_pretrained("stabilityai/stablelm-2-zephyr-1_6b")
    regex = r"(12){1,2}"
    model = outlines.models.transformers("stabilityai/stablelm-2-zephyr-1_6b")
    generator_zephyr = outlines.generate.regex(
        model,
        regex,
    )
    draw_state_machine(generator_zephyr.fsm.states_to_token_maps, generator_zephyr.fsm.final_states, tokenizer_zephyr)
@@ -193,12 +193,8 @@ def get_next_state(self, state: int, token_id: int) -> int:
             The new state of the guide.

         """
-        if token_id == self.eos_token_id:
+        if token_id == self.eos_token_id or state not in self.states_to_token_maps:
@br3no I was wondering if we really need the 2nd condition `state not in self.states_to_token_maps`? The condition basically checks for states which do not have outgoing edges. But such states would be part of the final states in the FSM, and this block of code adds EOS as an edge to such states, which gives them at least one outgoing edge. Therefore, no state in the FSM will be absent from `states_to_token_maps`. Wdyt?
@ekagra-ranjan yes, we do need this second condition.
The block of code you linked to does not add an EOS outbound transition to these states. It only adds transitions to final states which are present in `states_to_token_maps`. But these states are not present there. `states_to_token_subsets.get(state)` will return `None` for these states.
I'm not really knowledgeable about the way Outlines (and interegular) build the state machines out of regexes. The fact of the matter is that `states_to_token_maps` does not contain all states that are reachable. I have noticed this while debugging the code for some example regexes.
This is not a problem in principle, as these states are considered to be final and `states_to_token_subsets.get(state) is None` is used all over the code to handle this special case (as in the block you linked to).
I actually believe this could be improved and Outlines would profit from removing this special case that needs to be thought of all over the place and could lead to bugs. But this is, as I said, not a problem in principle.
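One way the special case could be removed is to normalize the table once, so that every reachable state has an explicit entry. The sketch below only illustrates that idea and is not how Outlines builds its tables: the helper names `reachable_states` and `normalize`, the token ids, the `-1` sentinel and the assumption that state `0` is the initial state are all hypothetical.

```python
def reachable_states(states_to_token_maps, initial_state=0):
    """All states reachable from `initial_state`, including ones with no entry."""
    seen, stack = set(), [initial_state]
    while stack:
        state = stack.pop()
        if state in seen:
            continue
        seen.add(state)
        # States without an entry contribute no further successors.
        stack.extend(states_to_token_maps.get(state, {}).values())
    return seen


def normalize(states_to_token_maps, eos_token_id, final_state, initial_state=0):
    """Return a copy in which every reachable state has an explicit entry.

    States that had no entry get a single EOS transition to `final_state`.
    """
    table = {s: dict(t) for s, t in states_to_token_maps.items()}
    for state in reachable_states(states_to_token_maps, initial_state):
        table.setdefault(state, {eos_token_id: final_state})
    return table


# State 2 is reachable but has no entry in the original table.
table = {0: {5: 1}, 1: {6: 2}}
print(normalize(table, eos_token_id=0, final_state=-1))
# {0: {5: 1}, 1: {6: 2}, 2: {0: -1}}
```

After such a pass, lookups for reachable states would no longer need a `None` check, at the cost of one extra traversal when the table is built.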
Fixes #856
The code this PR removes introduces an artificial and erroneous loop transition in every final state that is always traversed, regardless of the generation.
The comment doesn't make sense in my opinion, as the `if` above just handles exactly this case.
Removing this piece of code fixes the bug that surfaced in the upgrade of outlines in the vLLM integration.