-
Notifications
You must be signed in to change notification settings - Fork 1
/
star_coder_chat.py
75 lines (65 loc) · 2.41 KB
/
star_coder_chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import AsyncIterable
import httpx
import httpx_sse
from fastapi_poe import PoeBot
from fastapi_poe.types import QueryRequest
from sse_starlette.sse import ServerSentEvent
# Together.ai inference endpoint that every completion request is POSTed to.
BASE_URL = "https://api.together.xyz/inference"

# Seed dialogue prepended to every conversation. It uses the same
# <|user|> / <|assistant|> / <|end|> markers that construct_prompt emits for
# the real turns. The literal must open with exactly three quotes: a fourth
# one would become a stray '"' character at the very start of the prompt.
BASE_PROMPT = """
<|user|>
Hi!
<|end|>
<|assistant|>
I am the StarCoderChat bot. I help users with programming and code related questions. \
I wrap any code in my response in backticks so that it can be rendered using Markdown.
<|end|>
"""
@dataclass
class StarCoderChatBot(PoeBot):
    """Poe bot that answers by streaming completions from Together.ai.

    Each incoming query is flattened into a single text prompt, sent to the
    ``HuggingFaceH4/starchat-alpha`` model at ``BASE_URL``, and the streamed
    tokens are relayed back to the caller as text events.
    """

    TOGETHER_API_KEY: str  # Together.ai api key

    def construct_prompt(self, query: QueryRequest) -> str:
        """Build the model prompt from the conversation in *query*.

        The prompt starts with ``BASE_PROMPT``, followed by one tagged
        section per message, and ends with an open ``<|assistant|>`` tag for
        the model to continue from.

        Args:
            query: Request whose ``query`` attribute is iterated as the
                message list; each message exposes ``role`` and ``content``.

        Returns:
            The assembled prompt string.

        Raises:
            ValueError: If a message role is not "user", "bot" or "system".
        """
        parts = [BASE_PROMPT]
        for message in query.query:
            if message.role == "user":
                parts.append(f"<|user|>\n {message.content}\n<|end|>\n")
            elif message.role == "bot":
                parts.append(f"<|assistant|>\n {message.content}\n<|end|>\n")
            elif message.role != "system":
                # System messages are deliberately left out of the prompt;
                # any other role is unexpected.
                raise ValueError(f"unknown role {message.role}.")
        parts.append("<|assistant|>")
        return "".join(parts)

    async def query_together_ai(self, prompt: str) -> AsyncIterable[str]:
        """Stream generated text for *prompt* from the Together.ai endpoint.

        This is an async generator: it POSTs the request to ``BASE_URL`` and
        reads the response as server-sent events.

        Args:
            prompt: Full prompt text, e.g. the output of ``construct_prompt``.

        Yields:
            The ``choices[0].text`` field of each JSON event payload.
        """
        payload = {
            "model": "HuggingFaceH4/starchat-alpha",
            "prompt": prompt,
            "max_tokens": 1000,
            "stop": ["<|endoftext|>", "<|end|>"],
            "stream_tokens": True,
            "temperature": 0.7,
            "top_p": 0.7,
            "top_k": 50,
            "repetition_penalty": 1,
        }
        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "Authorization": f"Bearer {self.TOGETHER_API_KEY}",
        }

        async with httpx.AsyncClient() as aclient:
            async with httpx_sse.aconnect_sse(
                aclient, "POST", BASE_URL, headers=headers, json=payload
            ) as event_source:
                async for event in event_source.aiter_sse():
                    # "[DONE]" is not JSON (presumably the end-of-stream
                    # marker - confirm against the API docs), so skip it
                    # instead of trying to parse it.
                    if event.data == "[DONE]":
                        continue
                    yield json.loads(event.data)["choices"][0]["text"]

    async def get_response(self, query: QueryRequest) -> AsyncIterable[ServerSentEvent]:
        """Answer *query* by relaying the model's streamed output.

        Args:
            query: The incoming request to answer.

        Yields:
            One ``self.text_event(...)`` result per chunk produced by
            ``query_together_ai``.
        """
        prompt = self.construct_prompt(query)
        async for word in self.query_together_ai(prompt):
            yield self.text_event(word)