Skip to content

Commit

Permalink
Merge pull request #9 from mgoin/patch-2
Browse files Browse the repository at this point in the history
Update SUT_API offline to work for OpenAI
  • Loading branch information
arjunsuresh authored Jul 16, 2024
2 parents 99ee8b6 + 280a294 commit 4e4aff3
Showing 1 changed file with 12 additions and 25 deletions.
37 changes: 12 additions & 25 deletions language/llama2-70b/SUT_API.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,30 +220,21 @@ def process_queries(self):

tik1 = time.time()

# OpenAI-API servers don't require padding and can take input tokens
# directly, so we build our input_ids_tensor as a jagged list
input_ids_tensor = []
input_masks_tensor = []
input_len = []
for q in qitem:
input_ids_tensor.append(pad(self.data_object.input_ids[q.index],
(max_seq_len - self.data_object.input_lens[q.index], 0, 0, 0),
value=self.tokenizer.pad_token_id))
input_masks_tensor.append(pad(self.data_object.attention_masks[q.index],
(max_seq_len - self.data_object.input_lens[q.index], 0, 0, 0),
value=0))
input_len.append(self.data_object.input_lens[q.index])
input_ids_tensor = torch.cat(input_ids_tensor)
input_masks_tensor = torch.cat(input_masks_tensor)

assert input_ids_tensor.shape == input_masks_tensor.shape
assert input_ids_tensor.shape[0] <= self.batch_size
input_ids_tensor.append(self.data_object.input_ids[q.index])

# NOTE(mgoin): I don't think this has to be a torch tensor
# input_ids_tensor = torch.cat(input_ids_tensor)

if self.api_servers:
decoded = self.tokenizer.batch_decode(input_ids_tensor)
cleaned = [entry.replace('</s>','').replace('<s>','') for entry in decoded]
cleaned_chunks = [list(c) for c in mit.divide(len(self.api_servers), cleaned)]
assert len(input_ids_tensor) <= self.batch_size

tik2 = time.time()

# NOTE(mgoin): I don't think threading is necessary since we are submitting all queries in one request
# The API server should take care of mini-batches and scheduling
if self.api_servers:
with ThreadPoolExecutor(max_workers=len(self.api_servers)) as executor:
#needs to be tested
Expand All @@ -257,14 +248,10 @@ def process_queries(self):

tik3 = time.time()

processed_output = self.data_object.postProcess(pred_output_tokens,
input_seq_lens=input_len,
query_id_list=query_ids)
if self.api_servers:
processed_output = np.array(self.tokenizer(output, padding='longest')['input_ids'])

processed_output = self.tokenizer(output)['input_ids']
for i in range(len(qitem)):
unpadded = np.delete(processed_output[i], np.where(processed_output[i] == 2))
# NOTE(mgoin): Not optimal to make numpy arrays just to serialize
unpadded = np.array(processed_output[i])
n_tokens = unpadded.shape[0]
response_array = array.array("B", unpadded.tobytes())
bi = response_array.buffer_info()
Expand Down

0 comments on commit 4e4aff3

Please sign in to comment.