update codegen inference script
zfj1998 committed May 3, 2024
1 parent c2ec0dd commit cecce36
Showing 2 changed files with 78 additions and 0 deletions.
1 change: 1 addition & 0 deletions RepoCoder/README.md
@@ -22,6 +22,7 @@ This project contains the basic components of RepoCoder. Here is an overview:
|-- build_prompt.py # build the prompt with the unfinished code and the retrieved code snippets
|-- run_pipeline.py # run the code completion pipeline
|-- compute_score.py # evaluate the performance of the code completion
|-- codegen_inference.py # an example script for using CodeGen to generate code completions
|-- utils.py # utility functions
|-- datasets/datasets.zip # the input data for the code completion task
|-- function_level_completion_4k_context_codex.test.jsonl
77 changes: 77 additions & 0 deletions RepoCoder/codegen_inference.py
@@ -0,0 +1,77 @@
import torch
import tqdm
import json
from transformers import AutoModelForCausalLM, AutoTokenizer


class Tools:
    @staticmethod
    def load_jsonl(path):
        # read a JSONL file into a list of dicts, one per line
        with open(path, 'r') as f:
            return [json.loads(line) for line in f]

    @staticmethod
    def dump_jsonl(obj, path):
        # write a list of dicts to a JSONL file, one JSON object per line
        with open(path, 'w') as f:
            for line in obj:
                f.write(json.dumps(line) + '\n')


class CodeGen:
    def __init__(self, model_name, batch_size):
        self.model_name = model_name
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        # left padding so batched generation continues from the end of each prompt
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        # CodeGen has no pad token; reuse the EOS token for padding
        self.tokenizer.add_special_tokens({'pad_token': self.tokenizer.eos_token})
        self.model.cuda()
        self.batch_size = batch_size
        print('done loading model')

    def _get_batches(self, prompts, batch_size):
        # split the prompt list into consecutive batches of size batch_size
        batches = []
        for i in range(0, len(prompts), batch_size):
            batches.append(prompts[i:i + batch_size])
        return batches

    def _generate_batch(self, prompt_batch, max_new_tokens=100):
        prompts = self.tokenizer(prompt_batch, return_tensors='pt', padding=True, truncation=True)

        with torch.no_grad():
            gen_tokens = self.model.generate(
                input_ids=prompts['input_ids'].cuda(),
                attention_mask=prompts['attention_mask'].cuda(),
                do_sample=False,  # greedy decoding
                max_new_tokens=max_new_tokens,
            )
        gen_text = self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
        # strip the prompt prefix so only the generated completion remains
        for i in range(len(gen_text)):
            gen_text[i] = gen_text[i][len(prompt_batch[i]):]
        return gen_text

    def batch_generate(self, file):
        print(f'generating from {file}')
        lines = Tools.load_jsonl(file)
        # append a trailing newline so generation starts on a fresh line
        prompts = [f"{line['prompt']}\n" for line in lines]
        batches = self._get_batches(prompts, self.batch_size)
        gen_text = []
        for batch in tqdm.tqdm(batches):
            gen_text.extend(self._generate_batch(batch))
        print(f'generated {len(gen_text)} samples')
        assert len(gen_text) == len(prompts)
        # pair each input record with its generated completion
        new_lines = []
        for line, gen in zip(lines, gen_text):
            new_lines.append({
                'prompt': line['prompt'],
                'metadata': line['metadata'],
                'choices': [{'text': gen}]
            })
        Tools.dump_jsonl(new_lines, file.replace('.jsonl', f'_{self.model_name.split("/")[-1]}.jsonl'))


if __name__ == '__main__':
    file_path = 'datasets/line_level_completion_1k_context_codegen.test.jsonl'
    tiny_codegen = 'Salesforce/codegen-350M-mono'

    cg = CodeGen(tiny_codegen, batch_size=8)
    cg.batch_generate(file_path)
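The script writes its predictions next to the input file, appending the model name to the filename (the naming rule in batch_generate above). A minimal sketch, not part of this commit, for spot-checking the output; the path below assumes the default file_path and model from __main__:

import json

# output path derived from the naming rule in batch_generate (assumption)
out_path = 'datasets/line_level_completion_1k_context_codegen.test_codegen-350M-mono.jsonl'
with open(out_path) as f:
    for raw in list(f)[:3]:  # inspect the first few records
        record = json.loads(raw)
        print(record['metadata'])
        print(record['choices'][0]['text'])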
