Commit 3cc2414

Fix issues with lookup table
1 parent 57bb62d commit 3cc2414

File tree

syncode/evaluation/code_eval.py
syncode/grammar_decoder.py
syncode/language_model.py
syncode/mask_store/fsm_set.py
syncode/mask_store/mask_store.py
tests/test_language_model.py

6 files changed: +45 additions, -17 deletions

syncode/evaluation/code_eval.py

Lines changed: 10 additions & 5 deletions
@@ -36,14 +36,18 @@ def run_code_eval(
         else:
             stop_words = None

-        pbar = tqdm(total=len(problems) * num_samples_per_task)
         if debug_task_id is None:
             time1 = time.time()
+            pbar = tqdm(total=len(problems) * num_samples_per_task)

             # Run evaluation for all tasks
             for task_id in list(problems.keys())[:num_tasks]:
-                outputs.append(CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id, stop_words=stop_words))
+                outputs.append(
+                    CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, task_id, stop_words=stop_words)
+                )
+                pbar.update(num_samples_per_task)

+            pbar.close()
             if out_path is not None: write_jsonl(out_path, samples)

             avg_time = (time.time() - time1) / len(problems)
@@ -54,10 +58,12 @@ def run_code_eval(
             CodeEval.write_results(syncode, out_path, avg_time, functional_result, num_tasks)
         else: # Debugging a specific task
             debug_task_id = list(problems.keys())[debug_task_id]
-            return CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, debug_task_id, logger=logger, stop_words=stop_words)
+            return CodeEval.run_eval_for_task(
+                syncode, num_samples_per_task, format_tabs, problems, samples, debug_task_id, logger=logger, stop_words=stop_words
+            )
         return outputs

-    def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id, logger=common.EmptyLogger(), stop_words=None):
+    def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, task_id, logger=common.EmptyLogger(), stop_words=None):
        """
        run evaluation for a specific task
        """
@@ -96,7 +102,6 @@ def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samp
             )
             samples += [result]
             all_completions.append(completion)
-            pbar.update(num_samples_per_task)

         # Clear the cache
         torch.cuda.empty_cache()

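Note on the progress-bar change above: the tqdm bookkeeping now lives entirely in the caller, so run_eval_for_task no longer takes a pbar argument. A minimal sketch of the same pattern, where problems and run_one are hypothetical stand-ins for the real evaluation code:

from tqdm import tqdm

def run_one(task_id: str, num_samples_per_task: int) -> list[str]:
    # Placeholder worker; the real code generates completions for the task.
    return [f"completion {i} for {task_id}" for i in range(num_samples_per_task)]

def run_all(problems: dict, num_samples_per_task: int) -> list:
    outputs = []
    # The caller owns the bar: one update per task, closed once the loop ends.
    pbar = tqdm(total=len(problems) * num_samples_per_task)
    for task_id in problems:
        outputs.append(run_one(task_id, num_samples_per_task))
        pbar.update(num_samples_per_task)
    pbar.close()
    return outputs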
syncode/grammar_decoder.py

Lines changed: 8 additions & 2 deletions
@@ -206,9 +206,15 @@ def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
         output = []
         for idx in range(len(input_ids)):
             if self.parse_output_only:
-                partial_code, remainder_bytes = self._bytes_to_string(self.byte_tokenizer.decode(input_ids[idx, self.start_from:].tolist(), skip_special_tokens=True))
+                partial_code, remainder_bytes = self._bytes_to_string(
+                    self.byte_tokenizer.decode(
+                        input_ids[idx, self.start_from:].to('cpu', non_blocking=True).tolist(), skip_special_tokens=True)
+                )
             else:
-                partial_code, remainder_bytes = self._bytes_to_string(self.byte_tokenizer.decode(input_ids[idx].tolist(), skip_special_tokens=True))
+                partial_code, remainder_bytes = self._bytes_to_string(
+                    self.byte_tokenizer.decode(
+                        input_ids[idx].to('cpu', non_blocking=True).tolist(), skip_special_tokens=True)
+                )
             output.append((partial_code, remainder_bytes))
         return output


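Note on the decode change above: the token-id slice is moved to host memory explicitly before .tolist(). A hedged sketch of the pattern outside SynCode's classes, assuming torch and a HuggingFace-style tokenizer; ids_to_list is a made-up helper name:

import torch

def ids_to_list(input_ids: torch.LongTensor, row: int, start_from: int = 0) -> list[int]:
    # Copy one sequence to the CPU, then convert it to a plain Python list.
    # When input_ids lives on a CUDA device, non_blocking=True allows the copy
    # to overlap with other work in some cases; .tolist() still synchronizes.
    return input_ids[row, start_from:].to('cpu', non_blocking=True).tolist()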
syncode/language_model.py

Lines changed: 7 additions & 0 deletions
@@ -122,6 +122,12 @@ def generate_grammar_constrained_completion(
                 print("WARNING: Opportunistic mode requires SAMPLE or GREEDY_SEARCH generation mode.")
             if not batch_size == 1:
                 print("WARNING: Opportunistic mode requires batch_size of 1.")
+
+        # Ensure pad_token_id is set
+        if 'pad_token_id' not in dir(self.tokenizer):
+            if self.tokenizer.pad_token_id is None:
+                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
         # Use generate from transformers library for other modes
         generated_ids = self.model.generate(
             **inputs,
@@ -190,6 +196,7 @@ def _generate(
         logits_processor = self.model._get_logits_processor(gen_config, token_ids.size(1), token_ids, prefix_allowed_tokens_fn=None, logits_processor=[])

         max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1)
+        self.model.config.pad_token_id = pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id

         while True:
             try:

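Note on the pad-token handling above: both hunks fall back to the EOS id when the tokenizer defines no pad token, which is common for causal LM tokenizers and otherwise trips up padding during generation. A minimal sketch of that fallback, assuming a HuggingFace-style tokenizer and model; ensure_pad_token is a hypothetical helper:

def ensure_pad_token(tokenizer, model) -> int:
    # Reuse the EOS id as the pad id when none is configured, and mirror the
    # choice onto the model config so generate() picks it up.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    tokenizer.pad_token_id = pad_id
    model.config.pad_token_id = pad_id
    return pad_id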
syncode/mask_store/fsm_set.py

Lines changed: 10 additions & 2 deletions
@@ -12,14 +12,22 @@ class JointFSMState:
     def __init__(self, terminal: str, state_id: int):
         self.terminal = terminal
         self.state_id = state_id
-        self._hash = hash((self.terminal, self.state_id)) # Pre-compute hash on creation
-
+        self._hash = JointFSMState.det_hash(self.terminal, self.state_id)
+
     def __eq__(self, other: 'JointFSMState'):
         return self.terminal == other.terminal and self.state_id == other.state_id

     def __hash__(self):
         return self._hash

+    @staticmethod
+    def det_hash(terminal: str, state_id: int):
+        h = 0
+        for char in terminal:
+            h = (h * 31 + ord(char)) & 0xFFFFFFFF
+        h = (h * 31 + state_id) & 0xFFFFFFFF
+        return h
+
     def __repr__(self):
         return f"({self.terminal}, {self.state_id})"


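Note on det_hash above: Python's built-in string hashing is salted per process (PYTHONHASHSEED), so values computed when a lookup table is built and pickled need not match values computed in a later run that queries it with freshly constructed states. The explicit 31-based polynomial hash is stable across runs. A small self-contained check of that determinism, using the same arithmetic as the committed code (the inputs "NAME" and 3 are arbitrary examples):

def det_hash(terminal: str, state_id: int) -> int:
    # 32-bit truncated polynomial hash of the terminal name, mixed with the
    # state id; no per-process salt, so the value never changes between runs.
    h = 0
    for char in terminal:
        h = (h * 31 + ord(char)) & 0xFFFFFFFF
    return (h * 31 + state_id) & 0xFFFFFFFF

# Prints the same number in every interpreter session, unlike the built-in
# hash(("NAME", 3)), which depends on the per-process hash seed.
print(det_hash("NAME", 3))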
syncode/mask_store/mask_store.py

Lines changed: 8 additions & 7 deletions
@@ -65,7 +65,7 @@ def __init__(self,

         followings_terminas_map = None
         if parse_table is not None:
-            followings_terminas_map = self._compute_following_terminals_map(terminal_names, parse_table)
+            followings_terminas_map = self._compute_following_terminals_map(terminal_names, parse_table, ignore_terminals)

         # Create consume prefix cache
         self._consume_prefix_cache = {}
@@ -105,8 +105,9 @@ def init_mask_store(

         if use_cache and os.path.exists(fsm_path):
             try:
-                mask_store = pickle.load(open(fsm_path, 'rb'))
-                return mask_store
+                with open(fsm_path, 'rb') as f:
+                    mask_store = pickle.load(f)
+                    return mask_store
             except Exception as e:
                 logger.warning(f"Error loading mask store: {e}")

@@ -134,7 +135,8 @@ def init_mask_store(
    def _compute_following_terminals_map(
        self,
        terminals: Iterable[str],
-        parse_table
+        parse_table,
+        ignore_terminals: Iterable[str]
    ) -> defaultdict:
        """
        From terminals, filter out terminals that cannot follow the current terminal
@@ -150,9 +152,8 @@ def _compute_following_terminals_map(
        # We iterate through each cur_terminal:
        for cur_terminal in terminals:
            # Add all ignore terminals to the following terminals
-            for next_terminal in terminals:
-                if 'IGNORE' in next_terminal:
-                    following_terminals_map[cur_terminal].add(next_terminal)
+            for next_terminal in ignore_terminals:
+                following_terminals_map[cur_terminal].add(next_terminal)

        # We iterate through each parser_state:
        for _, row in parse_table.states.items():

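Note on the following-terminals change above: the ignore terminals are now passed in explicitly instead of being recovered through an 'IGNORE' substring test on terminal names. A hedged sketch of the resulting map construction, with made-up terminal names for illustration:

from collections import defaultdict

def following_terminals_map(terminals, ignore_terminals):
    # Every terminal may be followed by any ignored terminal (whitespace,
    # comments, ...), independent of what the parse table allows.
    following = defaultdict(set)
    for cur_terminal in terminals:
        for next_terminal in ignore_terminals:
            following[cur_terminal].add(next_terminal)
    # The real implementation then extends the map from parse-table rows.
    return following

print(following_terminals_map({"NAME", "NUMBER"}, {"WS", "COMMENT"}))

The cache-loading hunk in the same file also switches to a with block, so the pickle file handle is closed even when loading fails.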
tests/test_language_model.py

Lines changed: 2 additions & 1 deletion
@@ -38,7 +38,8 @@ class TestTokenizer:
     def __init__(self) -> None:
         vocab = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/', '(', ')', ' ', '\n', '\t', '=']
         self.vocab = vocab
-        self.eos_token_id = ''
+        self.eos_token_id = 1
+        self.pad_token_id = 2

     def __call__(self, input_batch: list[str], return_tensors="pt") -> BatchEncoding:
         # This works since we have single character tokens

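Note on the test change above: token ids are integers, so an empty-string eos_token_id cannot behave like one, and the fake tokenizer previously had no pad id at all. A tiny illustration of the kind of check the integer ids now support, with an arbitrary generated sequence:

import torch

eos_token_id = 1  # was '' before this commit
pad_token_id = 2  # newly added

generated = torch.tensor([5, 7, 1])
finished = bool((generated == eos_token_id).any())  # True once EOS (id 1) appears
print(finished)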