Skip to content

Commit

Permalink
finish
Browse files (browse the repository at this point in the history)
  • Loading branch information
Ubospica committed Dec 12, 2024
1 parent dd7feea commit 70104a3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
20 changes: 11 additions & 9 deletions examples/benchmark/bench_grammar_compile_mask_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
build_token_enforcer_tokenizer_data,
)
from outlines.fsm.guide import Guide, RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema, convert_json_schema_to_str
from outlines.fsm.json_schema import convert_json_schema_to_str
from outlines.generate.generator import bias_logits
from outlines.generate.json import build_regex_from_schema
from outlines.models import TransformerTokenizer
from tqdm import tqdm
from transformers import AutoTokenizer
Expand All @@ -23,18 +24,18 @@
wrong_data_indices = [1]


def xgrammar_build(schema: str, tokenizer_info: TokenizerInfo):
grammar = BuiltinGrammar.json_schema(schema, strict_mode=False)
matcher = GrammarMatcher(grammar, tokenizer_info)
def xgrammar_build(schema: str, grammar_compiler: xgr.GrammarCompiler):
grammar = grammar_compiler.compile_json_schema(schema)
matcher = xgr.GrammarMatcher(grammar)
return matcher


def xgrammar_exec(
matcher: GrammarMatcher, logits: torch.Tensor, bitmask: torch.Tensor, token_id: int
matcher: xgr.GrammarMatcher, logits: torch.Tensor, bitmask: torch.Tensor, token_id: int
):
# Logits processing
matcher.fill_next_token_bitmask(bitmask)
matcher.apply_token_bitmask_inplace(logits, bitmask)
xgr.apply_token_bitmask_inplace(logits, bitmask)
# Update state
assert matcher.accept_token(token_id)
return
Expand Down Expand Up @@ -93,7 +94,8 @@ def lmformatenforcer_exec(token_enforcer: TokenEnforcer, logits: torch.Tensor, t
hf_model_path = "meta-llama/Llama-3.1-8B-Instruct"

hf_tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
xgrammar_tokenizer_info = TokenizerInfo.from_huggingface(hf_tokenizer)
xgrammar_tokenizer_info = xgr.TokenizerInfo.from_huggingface(hf_tokenizer)
xgrammar_grammar_compiler = xgr.GrammarCompiler(xgrammar_tokenizer_info)
outlines_tokenizer = TransformerTokenizer(hf_tokenizer)
lmformatenforcer_tokenizer = build_token_enforcer_tokenizer_data(hf_tokenizer)

Expand Down Expand Up @@ -137,8 +139,8 @@ def lmformatenforcer_exec(token_enforcer: TokenEnforcer, logits: torch.Tensor, t
start = time.perf_counter()
try:
if backend == "xgrammar":
worker = xgrammar_build(schema, xgrammar_tokenizer_info)
bitmask = GrammarMatcher.allocate_token_bitmask(worker.vocab_size)
worker = xgrammar_build(schema, xgrammar_grammar_compiler)
bitmask = xgr.allocate_token_bitmask(worker.vocab_size)
elif backend == "outlines":
worker = outlines_build(schema, outlines_tokenizer)
elif backend == "lmformatenforcer":
Expand Down
3 changes: 1 addition & 2 deletions examples/hf_transformers/transformers_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,4 @@
]
responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
for response in responses:
print(response)
print()
print(response, end="\n\n")

0 comments on commit 70104a3

Please sign in to comment.