Skip to content

Commit

Permalink
Add benchmark: CFG rejection sampling + CFG no rejection sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Aug 12, 2024
1 parent 9f9b48e commit a0df818
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions benchmarks/bench_cfg_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,31 @@ def setup(self, grammar_name):
)

@staticmethod
def _run_random_cfg(guide):
def _run_random_cfg(guide, rejection_sampling=True):
state = guide.initial_state
token_ids = list(guide.tokenizer.vocabulary.values())
for i in range(40):
# simulate ordering of logits top prob to lowest prob
random.shuffle(token_ids)
# simulate sampling and state update
next_token_id = next(guide.iter_valid_token_ids(state, token_ids))
state = guide.get_next_state(state, next_token_id)
if rejection_sampling:
next_token_id = next(guide.iter_valid_token_ids(state, token_ids))
state = guide.get_next_state(state, next_token_id)
else:
next_token_id = random.choice(guide.get_next_instruction(state).tokens)
state = guide.get_next_state(state, next_token_id)

@cache_disabled()
def time_cfg_guide_setup(self, grammar_name):
CFGGuide(benched_grammars[grammar_name], self.tokenizer)

@cache_disabled()
def time_cfg_guide_run_rejection_sampling(self, grammar):
self._run_random_cfg(self.prebuilt_cfg_guide, rejection_sampling=True)

@cache_disabled()
def time_cfg_guide_run(self, grammar):
self._run_random_cfg(self.prebuilt_cfg_guide)
self._run_random_cfg(self.prebuilt_cfg_guide, rejection_sampling=False)

@cache_disabled()
def peakmem_cfg_guide_run(self, grammar):
Expand Down

0 comments on commit a0df818

Please sign in to comment.