Add support for DeepSeek R1 model and display CoT #2118

Open · wants to merge 2 commits into base: master
16 changes: 13 additions & 3 deletions request_llms/bridge_all.py
@@ -1071,18 +1071,18 @@ def decode(self, *args, **kwargs):
except:
logger.error(trimmed_format_exc())
# -=-=-=-=-=-=- High-Flyer DeepSeek online model API -=-=-=-=-=-=-
if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS:
if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS or "deepseek-reasoner" in AVAIL_LLM_MODELS:
try:
deepseekapi_noui, deepseekapi_ui = get_predict_function(
api_key_conf_name="DEEPSEEK_API_KEY", max_output_token=4096, disable_proxy=False
)
)
model_info.update({
"deepseek-chat":{
"fn_with_ui": deepseekapi_ui,
"fn_without_ui": deepseekapi_noui,
"endpoint": deepseekapi_endpoint,
"can_multi_thread": True,
"max_token": 32000,
"max_token": 64000,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
@@ -1095,6 +1095,16 @@ def decode(self, *args, **kwargs):
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"deepseek-reasoner":{
"fn_with_ui": deepseekapi_ui,
"fn_without_ui": deepseekapi_noui,
"endpoint": deepseekapi_endpoint,
"can_multi_thread": True,
"max_token": 64000,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
"enable_reasoning": True
},
})
except:
logger.error(trimmed_format_exc())
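A note on the new flag: `enable_reasoning` is just an extra key in the `model_info` entry, and the request template reads it with `.get('enable_reasoning', False)`, so models that never declare it behave exactly as before. A minimal sketch of that lookup (the dictionary below is a trimmed stand-in for the real `model_info`, not the full entry):

```python
# Minimal sketch: how the streaming template can branch on the optional
# "enable_reasoning" flag without affecting models that do not define it.
model_info = {
    "deepseek-reasoner": {
        "max_token": 64000,
        "can_multi_thread": True,
        "enable_reasoning": True,   # new flag added by this PR
    },
    "deepseek-chat": {
        "max_token": 64000,
        "can_multi_thread": True,
        # no "enable_reasoning" key -> treated as False
    },
}

def wants_cot(llm_model: str) -> bool:
    # .get(..., False) keeps backward compatibility for every existing model.
    return model_info.get(llm_model, {}).get("enable_reasoning", False)

assert wants_cot("deepseek-reasoner") is True
assert wants_cot("deepseek-chat") is False
```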
108 changes: 58 additions & 50 deletions request_llms/oai_std_model_template.py
@@ -36,10 +36,11 @@ def get_full_error(chunk, stream_response):

def decode_chunk(chunk):
"""
Parses the "content" and "finish_reason" fields
Parses the "content" and "finish_reason" fields (if chain-of-thought is supported, also returns the "reasoning_content" field)
"""
chunk = chunk.decode()
respose = ""
reasoning_content = ""
finish_reason = "False"
try:
chunk = json.loads(chunk[6:])
@@ -57,14 +58,20 @@ def decode_chunk(chunk):
return respose, reasoning_content, finish_reason

try:
respose = chunk["choices"][0]["delta"]["content"]
if chunk["choices"][0]["delta"]["content"] is not None:
respose = chunk["choices"][0]["delta"]["content"]
except:
pass
try:
if chunk["choices"][0]["delta"]["reasoning_content"] is not None:
reasoning_content = chunk["choices"][0]["delta"]["reasoning_content"]
except:
pass
try:
finish_reason = chunk["choices"][0]["finish_reason"]
except:
pass
return respose, finish_reason
return respose, reasoning_content, finish_reason


def generate_message(input, model, key, history, max_output_token, system_prompt, temperature):
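To make the new three-value return of `decode_chunk` concrete, here is a stand-alone sketch of the same parsing logic, assuming DeepSeek streams OpenAI-style SSE lines where the delta carries `reasoning_content` while the model is still thinking and `content` once the final answer starts (the sample chunk is invented for illustration):

```python
import json

def decode_chunk_sketch(chunk: bytes):
    # Simplified stand-in for decode_chunk: parse one SSE line of the form
    # b'data: {...}' and extract content / reasoning_content / finish_reason.
    respose, reasoning_content, finish_reason = "", "", "False"
    try:
        payload = json.loads(chunk.decode()[6:])        # strip the "data: " prefix
        choice = payload["choices"][0]
        respose = choice["delta"].get("content") or ""
        reasoning_content = choice["delta"].get("reasoning_content") or ""
        finish_reason = choice.get("finish_reason") or "False"
    except Exception:
        pass
    return respose, reasoning_content, finish_reason

# Hypothetical chunk emitted while the model is still "thinking":
thinking = b'data: {"choices": [{"delta": {"content": null, "reasoning_content": "Let me check..."}, "finish_reason": null}]}'
print(decode_chunk_sketch(thinking))   # ('', 'Let me check...', 'False')
```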
@@ -163,29 +170,23 @@ def predict_no_ui_long_connection(
system_prompt=sys_prompt,
temperature=llm_kwargs["temperature"],
)

from .bridge_all import model_info

reasoning = model_info[llm_kwargs['llm_model']].get('enable_reasoning', False)

retry = 0
while True:
try:
from .bridge_all import model_info

endpoint = model_info[llm_kwargs["llm_model"]]["endpoint"]
if not disable_proxy:
response = requests.post(
endpoint,
headers=headers,
proxies=proxies,
json=playload,
stream=True,
timeout=TIMEOUT_SECONDS,
)
else:
response = requests.post(
endpoint,
headers=headers,
json=playload,
stream=True,
timeout=TIMEOUT_SECONDS,
)
response = requests.post(
endpoint,
headers=headers,
proxies=None if disable_proxy else proxies,
json=playload,
stream=True,
timeout=TIMEOUT_SECONDS,
)
break
except:
retry += 1
@@ -194,10 +195,13 @@
raise TimeoutError
if MAX_RETRY != 0:
logger.error(f"请求超时,正在重试 ({retry}/{MAX_RETRY}) ……")

stream_response = response.iter_lines()

result = ""
finish_reason = ""
if reasoning:
reasoning_buffer = ""

stream_response = response.iter_lines()
while True:
try:
chunk = next(stream_response)
@@ -207,9 +211,9 @@
break
except requests.exceptions.ConnectionError:
chunk = next(stream_response)  # Failed; retry once. If it fails again, there is nothing more we can do.
response_text, finish_reason = decode_chunk(chunk)
response_text, reasoning_content, finish_reason = decode_chunk(chunk)
# The first chunk of the returned stream may be empty; keep waiting
if response_text == "" and finish_reason != "False":
if response_text == "" and (reasoning == False or reasoning_content == "") and finish_reason != "False":
continue
if response_text == "API_ERROR" and (
finish_reason != "False" or finish_reason != "stop"
@@ -227,6 +231,8 @@
print(f"[response] {result}")
break
result += response_text
if reasoning:
reasoning_buffer += reasoning_content
if observe_window is not None:
# Observation window: show the data received so far
if len(observe_window) >= 1:
@@ -241,6 +247,8 @@
error_msg = chunk_decoded
logger.error(error_msg)
raise RuntimeError("Unexpected JSON structure")
if reasoning:
return '\n'.join(map(lambda x: '> ' + x, reasoning_buffer.split('\n'))) + '\n\n' + result
return result

def predict(
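The non-UI path above returns the chain of thought as a Markdown blockquote stacked on top of the answer. A quick illustration of that `'> ' + line` join, with made-up sample strings:

```python
# Illustration of the quoting step: every reasoning line is prefixed with
# "> " so it renders as a blockquote, then the answer follows after a blank line.
reasoning_buffer = "First, compare both options.\nOption A is cheaper."
result = "Option A is the better choice."

output = "\n".join("> " + line for line in reasoning_buffer.split("\n")) + "\n\n" + result
print(output)
# > First, compare both options.
# > Option A is cheaper.
#
# Option A is the better choice.
```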
@@ -298,32 +306,25 @@ def predict(
system_prompt=system_prompt,
temperature=llm_kwargs["temperature"],
)

from .bridge_all import model_info

reasoning = model_info[llm_kwargs['llm_model']].get('enable_reasoning', False)

history.append(inputs)
history.append("")
retry = 0
while True:
try:
from .bridge_all import model_info

endpoint = model_info[llm_kwargs["llm_model"]]["endpoint"]
if not disable_proxy:
response = requests.post(
endpoint,
headers=headers,
proxies=proxies,
json=playload,
stream=True,
timeout=TIMEOUT_SECONDS,
)
else:
response = requests.post(
endpoint,
headers=headers,
json=playload,
stream=True,
timeout=TIMEOUT_SECONDS,
)
response = requests.post(
endpoint,
headers=headers,
proxies=None if disable_proxy else proxies,
json=playload,
stream=True,
timeout=TIMEOUT_SECONDS,
)
break
except:
retry += 1
@@ -338,6 +339,8 @@ def predict(
raise TimeoutError

gpt_replying_buffer = ""
if reasoning:
gpt_reasoning_buffer = ""

stream_response = response.iter_lines()
while True:
@@ -347,9 +350,9 @@ def predict(
break
except requests.exceptions.ConnectionError:
chunk = next(stream_response)  # Failed; retry once. If it fails again, there is nothing more we can do.
response_text, finish_reason = decode_chunk(chunk)
response_text, reasoning_content, finish_reason = decode_chunk(chunk)
# The first chunk of the returned stream may be empty; keep waiting
if response_text == "" and finish_reason != "False":
if response_text == "" and (not reasoning or reasoning_content == "") and finish_reason != "False":
status_text = f"finish_reason: {finish_reason}"
yield from update_ui(
chatbot=chatbot, history=history, msg=status_text
@@ -379,9 +382,14 @@ def predict(
logger.info(f"[response] {gpt_replying_buffer}")
break
status_text = f"finish_reason: {finish_reason}"
gpt_replying_buffer += response_text
# If an exception is raised here, the text is usually too long; see get_full_error's output for details
history[-1] = gpt_replying_buffer
if reasoning:
gpt_replying_buffer += response_text
gpt_reasoning_buffer += reasoning_content
history[-1] = '\n'.join(map(lambda x: '> ' + x, gpt_reasoning_buffer.split('\n'))) + '\n\n' + gpt_replying_buffer
else:
gpt_replying_buffer += response_text
# If an exception is raised here, the text is usually too long; see get_full_error's output for details
history[-1] = gpt_replying_buffer
chatbot[-1] = (history[-2], history[-1])
yield from update_ui(
chatbot=chatbot, history=history, msg=status_text
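Finally, a toy simulation of the UI streaming loop in `predict`: with `deepseek-reasoner`, the `reasoning_content` deltas typically arrive before the `content` deltas, and `history[-1]` is rebuilt on every chunk so the chat window shows the quoted CoT above the growing answer (the chunk sequence below is invented):

```python
# Toy simulation of the streaming update: (content, reasoning_content) pairs
# as they might arrive from the reasoner, reasoning first, answer second.
stream = [
    ("", "Need to recall the capital of France."),
    ("", " It is Paris."),
    ("The capital of France", ""),
    (" is Paris.", ""),
]

gpt_replying_buffer = ""
gpt_reasoning_buffer = ""
history = ["What is the capital of France?", ""]

for response_text, reasoning_content in stream:
    gpt_replying_buffer += response_text
    gpt_reasoning_buffer += reasoning_content
    # Same formatting as the PR: quoted CoT, blank line, then the answer so far.
    history[-1] = (
        "\n".join("> " + x for x in gpt_reasoning_buffer.split("\n"))
        + "\n\n"
        + gpt_replying_buffer
    )

print(history[-1])
# > Need to recall the capital of France. It is Paris.
#
# The capital of France is Paris.
```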