Implement prompt/generation alignment #161
First thoughts on how this could be achieved. Let's consider the prompt used in the code below.
What do we do when (2) gives several matches?
In case of multiple matches, we can rank by the integer id of the token(s). A smaller integer implies a more frequently occurring token in the BPE tokenization process. How about evaluating the log probs of the sequence? I think the smaller integer should have a higher log prob.
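A minimal sketch of that ranking, assuming we already have the candidate token strings from step (2); the candidate list here is hypothetical:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
vocabulary = tokenizer.get_vocab()  # token string -> integer id

# Hypothetical candidates returned by the matching step.
candidates = ["Ġday", "Ġdays", "Ġdaylight"]

# Smaller ids were merged earlier during BPE training and tend to be more frequent.
ranked = sorted(candidates, key=lambda tok: vocabulary[tok])
print(ranked[0])  # the most "common" match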
Here is a quick outline of a solution that uses `find_partial_matches`:

from outlines.text.parsing import find_partial_matches
from transformers import AutoTokenizer
import interegular
from collections import defaultdict
tokenizer = AutoTokenizer.from_pretrained("gpt2")
vocabulary = tokenizer.get_vocab()
# The BPE tokenizer encodes spaces as Ġ, so anything based on string matching
# requires converting the tokens back to plain strings first.
sorted_vocabulary = [
    tokenizer.convert_tokens_to_string([k])
    for k, v in sorted(vocabulary.items(), key=lambda kv: kv[1])
]
prompt = "This is a new day"
fsm = interegular.parse_pattern(prompt).to_fsm()
tokenized_prompt = tokenizer.encode(prompt)
prompt_tokens = [tokenizer.decode(t) for t in tokenized_prompt]
token_idx_in_prompt = [prompt.rfind(pt) for pt in prompt_tokens]  # fails if a token appears several times; do something better by tracking the position in the prompt string
found = defaultdict(list)
for vocab_str in sorted_vocabulary:
    pmatch = find_partial_matches(fsm, vocab_str)
    if pmatch != set():
        end_idx, states = pmatch.pop()  # We need to loop over the matches instead
        if end_idx is not None and states[-1] == len(prompt):
            if states[0] in token_idx_in_prompt:
                found[token_idx_in_prompt.index(states[0])].append(vocab_str)
print(found)
# {4: [' day', ' days', ' daylight', ' daytime']}

We then need to back up one token, generate the next token using the masks we can build from the above list, and then generate the sequence as usual. Now I understand my intuition behind leaving a space at the end of my prompt: this tells the model that it shouldn't complete the word. However, if you run the code above for the prompt "This is a new day ", you can see that it backs up one token (the whitespace) and suggests the ~33,000 tokens that start with a whitespace as potential continuations.
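A minimal sketch of that back-up-and-mask step, reusing the names defined in the snippet above; the greedy argmax and the way the mask is built are assumptions for illustration, not the library's actual implementation:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Candidate continuations for the last prompt token, e.g. [' day', ' days', ...]
allowed = set(found[max(found)])
allowed_ids = [
    token_id
    for token_str, token_id in vocabulary.items()
    if tokenizer.convert_tokens_to_string([token_str]) in allowed
]

# Back up one token: run the model on the prompt without its last token.
with torch.no_grad():
    logits = model(torch.tensor([tokenized_prompt[:-1]])).logits[0, -1]

# Mask out every token that is not one of the allowed continuations.
mask = torch.full_like(logits, float("-inf"))
mask[allowed_ids] = 0
next_token_id = torch.argmax(logits + mask).item()
print(tokenizer.decode(next_token_id))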
Another fun one:

import outlines.models as models
import outlines.text.generate as generate
model = models.transformers("gpt2")
prompt = "Is regex-guided generation useful? "
unguided = generate.continuation(model, max_tokens=30)(prompt)
guided = generate.regex(model, r"(Yes|No)", max_tokens=30)(prompt)
print(guided)
# Is regex-guided generation useful? No
prompt = "Is regex-guided generation useful?"
guided = generate.regex(model, r"(Yes|No)", max_tokens=30)(prompt)
print(guided)
# Is regex-guided generation useful?No
prompt = "Is regex-guided generation useful?"
guided = generate.regex(model, r"( )?(Yes|No)", max_tokens=30)(prompt)
print(guided)
# Is regex-guided generation useful? No
print([k for k in model.tokenizer.vocabulary.keys() if k.endswith("Yes")])
# ['Yes', 'ĠYes']

The "right" prompting here would be to leave a whitespace after the question, since we don't want "useful?" to be completed. However, " Yes" might be the most likely answer, as this is typically how the model would have tokenized "Is regex-guided generation useful? Yes". So we need to back up one token and add this character to the regex. In this case, we should be able to match " Yes", " No", and also a succession of whitespace followed by "Yes" or "No".
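A minimal sketch of that idea: strip the trailing whitespace from the prompt and let the regex emit it again. The pattern rewrite here is an assumption for illustration, not the library's built-in behaviour:

import outlines.models as models
import outlines.text.generate as generate

model = models.transformers("gpt2")

prompt = "Is regex-guided generation useful? "
pattern = r"(Yes|No)"

# If the prompt ends with whitespace, back up over it and let the regex
# produce it again, so tokens like 'ĠYes' stay reachable.
stripped = prompt.rstrip()
if stripped != prompt:
    prompt = stripped
    pattern = r"( )+" + pattern

guided = generate.regex(model, pattern, max_tokens=30)(prompt)
print(guided)
# e.g. Is regex-guided generation useful? Yes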
Do you think it would be the right strategy to still create the
Guidance implements a method called token healing, which consists in correcting for the quirks introduced by modern encodings like BPE. See this notebook for a thorough explanation of why this is necessary. The implementation for Transformers models is here. It consists of backtracking one or several tokens and starting generation by imposing that we reproduce the text that corresponds to the removed tokens. This can be integrated in the `__call__` method of the `Sequence` class.
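A minimal sketch of what such backtracking could look like for a bare Transformers model; the prefix-matching constraint and the greedy pick are assumptions for illustration, not Guidance's or Outlines' actual implementation:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Is regex-guided generation useful?"
ids = tokenizer.encode(prompt)
removed_text = tokenizer.decode([ids[-1]])  # text of the token we back up over

# Token healing: the first generated token must reproduce (and may extend)
# the text of the removed token.
allowed_ids = [
    token_id
    for token_str, token_id in tokenizer.get_vocab().items()
    if tokenizer.convert_tokens_to_string([token_str]).startswith(removed_text)
]

with torch.no_grad():
    logits = model(torch.tensor([ids[:-1]])).logits[0, -1]
mask = torch.full_like(logits, float("-inf"))
mask[allowed_ids] = 0
healed_id = torch.argmax(logits + mask).item()
print(tokenizer.decode(ids[:-1] + [healed_id]))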