
Commit

isthaison committed Dec 4, 2024
2 parents 3260642 + 0c849bd commit 4447a6c
Showing 2 changed files with 63 additions and 27 deletions.
82 changes: 55 additions & 27 deletions rag/llm/chat_model.py
@@ -22,14 +22,16 @@
from openai import OpenAI
import openai
from ollama import Client
from rag.nlp import is_english
from rag.nlp import is_chinese
from rag.utils import num_tokens_from_string
from groq import Groq
import os
import json
import requests
import asyncio

LENGTH_NOTIFICATION_CN = "······\n由于长度的原因,回答被截断了,要继续吗?"
LENGTH_NOTIFICATION_EN = "...\nFor the content length reason, it stopped, continue?"

class Base(ABC):
def __init__(self, key, model_name, base_url):
@@ -47,8 +49,10 @@ def chat(self, system, history, gen_conf):
**gen_conf)
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, response.usage.total_tokens
except openai.APIError as e:
return "**ERROR**: " + str(e), 0
@@ -80,8 +84,10 @@ def chat_streamly(self, system, history, gen_conf):
else: total_tokens = resp.usage.total_tokens

if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans

except openai.APIError as e:
@@ -167,8 +173,10 @@ def chat(self, system, history, gen_conf):
**self._format_params(gen_conf))
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, response.usage.total_tokens
except openai.APIError as e:
return "**ERROR**: " + str(e), 0
@@ -207,8 +215,10 @@ def chat_streamly(self, system, history, gen_conf):
else resp.usage["total_tokens"]
)
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans

except Exception as e:
@@ -242,8 +252,10 @@ def chat(self, system, history, gen_conf):
ans += response.output.choices[0]['message']['content']
tk_count += response.usage.total_tokens
if response.output.choices[0].get("finish_reason", "") == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, tk_count

return "**ERROR**: " + response.message, tk_count
@@ -276,8 +288,10 @@ def _chat_streamly(self, system, history, gen_conf, incremental_output=False):
ans = resp.output.choices[0]['message']['content']
tk_count = resp.usage.total_tokens
if resp.output.choices[0].get("finish_reason", "") == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans
else:
yield ans + "\n**ERROR**: " + resp.message if not re.search(r" (key|quota)", str(resp.message).lower()) else "Out of credit. Please set the API key in **settings > Model providers.**"
@@ -308,8 +322,10 @@ def chat(self, system, history, gen_conf):
)
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, response.usage.total_tokens
except Exception as e:
return "**ERROR**: " + str(e), 0
@@ -333,8 +349,10 @@ def chat_streamly(self, system, history, gen_conf):
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
yield ans
@@ -525,8 +543,10 @@ def chat(self, system, history, gen_conf):
response = response.json()
ans = response["choices"][0]["message"]["content"].strip()
if response["choices"][0]["finish_reason"] == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, response["usage"]["total_tokens"]
except Exception as e:
return "**ERROR**: " + str(e), 0
@@ -594,8 +614,10 @@ def chat(self, system, history, gen_conf):
**gen_conf)
ans = response.choices[0].message.content
if response.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, response.usage.total_tokens
except openai.APIError as e:
return "**ERROR**: " + str(e), 0
@@ -618,8 +640,10 @@ def chat_streamly(self, system, history, gen_conf):
ans += resp.choices[0].delta.content
total_tokens += 1
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans

except openai.APIError as e:
@@ -811,8 +835,10 @@ def chat(self, system, history, gen_conf):
)
ans = response.choices[0].message.content
if response.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, response.usage.total_tokens
except Exception as e:
return ans + "\n**ERROR**: " + str(e), 0
@@ -838,8 +864,10 @@ def chat_streamly(self, system, history, gen_conf):
ans += resp.choices[0].delta.content
total_tokens += 1
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans

except Exception as e:
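Every hunk in this file follows the same pattern: when the provider reports `finish_reason == "length"`, a language-matched continuation prompt is appended to the answer. Below is a minimal sketch of that pattern; `append_length_notice` is an illustrative name, not a function in the repository, and the constants mirror the ones defined at the top of rag/llm/chat_model.py.

```python
# Minimal sketch of the truncation handling repeated across the providers above.
from rag.nlp import is_chinese  # helper added in rag/nlp/__init__.py by this commit

# Mirror of the module-level constants in rag/llm/chat_model.py.
LENGTH_NOTIFICATION_CN = "······\n由于长度的原因,回答被截断了,要继续吗?"
LENGTH_NOTIFICATION_EN = "...\nFor the content length reason, it stopped, continue?"

def append_length_notice(ans: str, finish_reason: str) -> str:
    """Append a continuation prompt when generation stopped on the length limit."""
    if finish_reason != "length":
        return ans
    # Pick the notice language from the answer text itself.
    if is_chinese(ans):
        return ans + LENGTH_NOTIFICATION_CN
    return ans + LENGTH_NOTIFICATION_EN
```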
8 changes: 8 additions & 0 deletions rag/nlp/__init__.py
@@ -230,6 +230,14 @@ def is_english(texts):
return True
return False

def is_chinese(text):
chinese = 0
for ch in text:
if '\u4e00' <= ch <= '\u9fff':
chinese += 1
if chinese / len(text) > 0.2:
return True
return False

def tokenize(d, t, eng):
d["content_with_weight"] = t
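The new `is_chinese` helper classifies text by the share of characters in the CJK Unified Ideographs range (U+4E00–U+9FFF), returning True once that share exceeds 20%. A short usage sketch with illustrative inputs:

```python
# Usage sketch for the is_chinese helper added above (inputs are illustrative).
from rag.nlp import is_chinese

print(is_chinese("由于长度的原因,回答被截断了"))                 # True: mostly CJK ideographs
print(is_chinese("For the content length reason, it stopped"))  # False: no CJK ideographs
print(is_chinese("RAGFlow 支持中文回答"))                        # True: CJK share well above 20%
```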
