Commit 8b4c88f

Minor fixes
arjunsuresh committed Jul 16, 2024
1 parent 4e4aff3 commit 8b4c88f
Showing 1 changed file with 7 additions and 5 deletions.
language/llama2-70b/SUT_API.py: 12 changes, 7 additions & 5 deletions
@@ -170,10 +170,9 @@ def query_api_vllm(self, inputs, idx):
             'Content-Type': 'application/json',
         }
         json_data = {
-            'model': self.api_model_name,
-            'prompt': inputs,
-            'max_tokens': 8,
-            'temperature': 0,
+            "model": self.api_model_name,
+            "prompt": inputs,
+            "max_tokens": 8
         }

         response_code = 0
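
For context, the json_data payload above targets vLLM's OpenAI-compatible completions API. Below is a minimal standalone sketch of how such a request can be issued; the helper name, base-URL argument, and response handling are assumptions for illustration, not code from this repository:

import requests

def query_vllm_completions(base_url, model_name, prompts, max_tokens=8):
    # Hypothetical helper mirroring the json_data built in query_api_vllm.
    headers = {'Content-Type': 'application/json'}
    json_data = {
        "model": model_name,
        "prompt": prompts,        # vLLM accepts a single string or a list of prompts
        "max_tokens": max_tokens,
    }
    # vLLM exposes an OpenAI-compatible completions endpoint at /v1/completions.
    response = requests.post(f"{base_url}/v1/completions",
                             headers=headers, json=json_data)
    response.raise_for_status()
    return [choice["text"] for choice in response.json()["choices"]]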
@@ -227,7 +226,7 @@ def process_queries(self):
             input_ids_tensor.append(self.data_object.input_ids[q.index])

         # NOTE(mgoin): I don't think this has to be a torch tensor
-        # input_ids_tensor = torch.cat(input_ids_tensor)
+        input_ids_tensor = torch.cat(input_ids_tensor)

         assert len(input_ids_tensor) <= self.batch_size
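
The re-enabled torch.cat call joins a list of per-query tensors into one tensor, so the len() in the assert counts stacked rows rather than list items. A toy illustration, with invented shapes and token ids:

import torch

# Two queries' input_ids as (1, seq_len) tensors; values and shapes are made up.
a = torch.tensor([[1, 15043, 3186]])
b = torch.tensor([[1, 9038, 2501]])

batch = torch.cat([a, b])   # concatenates along dim 0 -> shape (2, 3)
assert len(batch) == 2      # len() now reports the number of rows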

@@ -236,6 +235,9 @@ def process_queries(self):
         # NOTE(mgoin): I don't think threading is necessary since we are submitting all queries in one request
         # The API server should take care of mini-batches and scheduling
         if self.api_servers:
+            decoded = self.tokenizer.batch_decode(input_ids_tensor)
+            cleaned = [entry.replace('</s>','').replace('<s>','') for entry in decoded]
+            cleaned_chunks = [list(c) for c in mit.divide(len(self.api_servers), cleaned)]
             with ThreadPoolExecutor(max_workers=len(self.api_servers)) as executor:
                 #needs to be tested
                 output_chunks = list(executor.map(self.api_action_handler,cleaned_chunks,range(len(self.api_servers))))
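
Here mit is presumably the more_itertools package: divide(n, iterable) splits the cleaned prompts into one chunk per API server, and executor.map pairs each chunk with a server index. A self-contained sketch of that fan-out, with placeholder server URLs and a stand-in for api_action_handler:

from concurrent.futures import ThreadPoolExecutor
import more_itertools as mit

api_servers = ["http://server-a:8000", "http://server-b:8000"]  # hypothetical
cleaned = ["p0", "p1", "p2", "p3", "p4"]

# divide(n, iterable) yields n roughly equal chunks, earlier chunks larger.
cleaned_chunks = [list(c) for c in mit.divide(len(api_servers), cleaned)]
# -> [['p0', 'p1', 'p2'], ['p3', 'p4']]

def handle(chunk, server_idx):
    # Stand-in for self.api_action_handler: send one batch to one server.
    return [f"{api_servers[server_idx]} <- {prompt}" for prompt in chunk]

with ThreadPoolExecutor(max_workers=len(api_servers)) as executor:
    output_chunks = list(executor.map(handle, cleaned_chunks,
                                      range(len(api_servers))))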
