Commit 8ded605 (1 parent: e62deca). Showing 5 changed files with 245 additions and 5 deletions.
@@ -0,0 +1,81 @@
import time
import os
from toolbox import update_ui, get_conf, update_ui_lastest_msg, log_chat
from toolbox import check_packages, report_exception, have_any_recent_upload_image_files
from toolbox import ChatBotWithCookies

model_name = 'Hugging Face Playground'


def validate_key():
    HUGGINGFACE_ACCESS_TOKEN = get_conf("HUGGINGFACE_ACCESS_TOKEN")
    if HUGGINGFACE_ACCESS_TOKEN == '': return False
    return True


def predict_no_ui_long_connection(inputs: str, llm_kwargs: dict, history: list = [], sys_prompt: str = "",
                                  observe_window: list = [], console_slience: bool = False):
    """
    ⭐ Multi-threaded method
    See request_llms/bridge_all.py for a description of this function
    """
    watch_dog_patience = 5
    response = ""

    llm_kwargs["llm_model"] = llm_kwargs["llm_model"].replace("HF:", "").strip()

    if validate_key() is False:
        raise RuntimeError('Please configure HUGGINGFACE_ACCESS_TOKEN')

    # Start receiving the streamed reply
    from .com_hfplayground import HFPlaygroundInit
    hfp_init = HFPlaygroundInit()
    for chunk, response in hfp_init.generate_chat(inputs, llm_kwargs, history, sys_prompt):
        if len(observe_window) >= 1:
            observe_window[0] = response
        if len(observe_window) >= 2:
            if (time.time() - observe_window[1]) > watch_dog_patience:
                raise RuntimeError("Program terminated: watchdog timeout.")
    return response


def predict(inputs: str, llm_kwargs: dict, plugin_kwargs: dict, chatbot: ChatBotWithCookies,
            history: list = [], system_prompt: str = '', stream: bool = True, additional_fn: str = None):
    """
    ⭐ Single-threaded method
    See request_llms/bridge_all.py for a description of this function
    """
    chatbot.append([inputs, ""])
    yield from update_ui(chatbot=chatbot, history=history)

    # Try to import dependencies; if any are missing, suggest how to install them
    try:
        check_packages(["openai"])
    except Exception:
        yield from update_ui_lastest_msg(
            "Failed to import dependencies. This model requires an extra package; install it with ```pip install --upgrade openai```.",
            chatbot=chatbot, history=history, delay=0)
        return

    if validate_key() is False:
        yield from update_ui_lastest_msg(lastmsg="[Local Message] Please configure HUGGINGFACE_ACCESS_TOKEN", chatbot=chatbot,
                                         history=history, delay=0)
        return

    if additional_fn is not None:
        from core_functional import handle_core_functionality
        inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
        chatbot[-1] = [inputs, ""]
        yield from update_ui(chatbot=chatbot, history=history)

    llm_kwargs["llm_model"] = llm_kwargs["llm_model"].replace("HF:", "").strip()

    # Start receiving the streamed reply
    from .com_hfplayground import HFPlaygroundInit
    hfp_init = HFPlaygroundInit()
    for chunk, response in hfp_init.generate_chat(inputs, llm_kwargs, history, system_prompt):
        chatbot[-1] = [inputs, response]
        yield from update_ui(chatbot=chatbot, history=history)
    history.extend([inputs, response])
    log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=response)
    yield from update_ui(chatbot=chatbot, history=history)
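
A minimal usage sketch (not part of this commit) of how the multi-threaded entry point above could be driven. The module path request_llms.bridge_hfplayground, the "HF:"-prefixed model name, and the llm_kwargs layout are assumptions inferred from the code, not confirmed by the diff.

import time
from request_llms.bridge_hfplayground import predict_no_ui_long_connection  # assumed module path

# Requires HUGGINGFACE_ACCESS_TOKEN to be configured, otherwise validate_key() fails.
llm_kwargs = {"llm_model": "HF:Qwen/Qwen2.5-72B-Instruct", "temperature": 0.95, "top_p": 0.70}
observe_window = ["", time.time()]  # [0] receives partial text, [1] is the watchdog timestamp the caller keeps fresh
reply = predict_no_ui_long_connection("Hello", llm_kwargs, history=[], sys_prompt="You are a helpful assistant.",
                                      observe_window=observe_window)
print(reply)

If the caller stops refreshing observe_window[1], the loop raises RuntimeError after watch_dog_patience (5 seconds), which gives the UI thread a way to abandon a stalled request.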
@@ -0,0 +1,122 @@
import httpx
import openai
from openai import OpenAI
from toolbox import get_conf, encode_image, get_pictures_list
from loguru import logger
import os

proxies = get_conf("proxies")
if proxies is not None:
    proxies = {k + "://": v for k, v in proxies.items()}


def input_encode_handler(inputs: str, llm_kwargs: dict):
    if llm_kwargs["most_recent_uploaded"].get("path"):
        image_paths = get_pictures_list(llm_kwargs["most_recent_uploaded"]["path"])
        md_encode = []
        for md_path in image_paths:
            type_ = os.path.splitext(md_path)[1].replace(".", "")
            type_ = "jpeg" if type_ == "jpg" else type_
            md_encode.append({"data": encode_image(md_path), "type": type_})
        return inputs, md_encode


class HFPlaygroundInit:

    def __init__(self):
        HUGGINGFACE_ACCESS_TOKEN = get_conf("HUGGINGFACE_ACCESS_TOKEN")
        self.client = OpenAI(base_url="https://api-inference.huggingface.co/v1/", api_key=HUGGINGFACE_ACCESS_TOKEN,
                             http_client=httpx.Client(proxies=proxies))
        self.model = ''

    def __conversation_user(self, user_input: str, llm_kwargs: dict):
        return {"role": "user", "content": user_input}

    def __conversation_history(self, history: list, llm_kwargs: dict):
        messages = []
        conversation_cnt = len(history) // 2
        if conversation_cnt:
            for index in range(0, 2 * conversation_cnt, 2):
                what_i_have_asked = self.__conversation_user(history[index], llm_kwargs)
                what_gpt_answer = {
                    "role": "assistant",
                    "content": history[index + 1]
                }
                messages.append(what_i_have_asked)
                messages.append(what_gpt_answer)
        return messages

    @staticmethod
    def preprocess_param(param, default=0.95, min_val=0.01, max_val=0.99):
        """Preprocess a sampling parameter: clamp it into the allowed range and round it to handle precision issues."""
        try:
            param = float(param)
        except ValueError:
            return default

        if param <= min_val:
            return min_val
        elif param >= max_val:
            return max_val
        else:
            return round(param, 2)  # precision is adjustable; currently two decimal places

    def __conversation_message_payload(self, inputs: str, llm_kwargs: dict, history: list, system_prompt: str):
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        self.model = llm_kwargs['llm_model']
        messages.extend(self.__conversation_history(history, llm_kwargs))  # handle history
        if inputs.strip() == "":  # guard against empty input, which triggers an API error; see https://github.com/binary-husky/gpt_academic/issues/1640 ({"error":{"code":"1214","message":"messages[1]: content and tool_calls cannot both be empty"})
            inputs = "."  # spaces, newlines and empty strings all raise errors, so substitute the most meaningless single dot
        messages.append(self.__conversation_user(inputs, llm_kwargs))  # handle the user's message
        """
        Sampling temperature controls the randomness of the output and must be positive.
        The valid range is (0.0, 1.0); it cannot equal 0, and the default is 0.95.
        Larger values make the output more random and more creative;
        smaller values make the output more stable and deterministic.
        Adjust either top_p or temperature for your use case, but do not tune both at the same time.
        """
        temperature = self.preprocess_param(
            param=llm_kwargs.get('temperature', 0.95),
            default=0.95,
            min_val=0.01,
            max_val=0.99
        )
        """
        An alternative to temperature sampling, known as nucleus sampling.
        The valid range is the open interval (0.0, 1.0);
        it cannot equal 0 or 1, and the default is 0.7.
        The model considers the tokens that make up the top_p probability mass,
        e.g. 0.1 means the decoder only samples from the top 10% of the probability mass.
        Adjust either top_p or temperature for your use case,
        but do not tune both at the same time.
        """
        top_p = self.preprocess_param(
            param=llm_kwargs.get('top_p', 0.70),
            default=0.70,
            min_val=0.01,
            max_val=0.99
        )
        response = self.client.chat.completions.create(
            model=self.model, messages=messages, stream=True,
            temperature=temperature,
            top_p=top_p,
            max_tokens=llm_kwargs.get('max_tokens', 1024 * 4),
        )
        return response

    def generate_chat(self, inputs: str, llm_kwargs: dict, history: list, system_prompt: str):
        self.model = llm_kwargs['llm_model']
        response = self.__conversation_message_payload(inputs, llm_kwargs, history, system_prompt)
        bro_results = ''
        for chunk in response:
            delta = chunk.choices[0].delta.content
            if delta is None:  # the final stream chunk may carry no content
                continue
            bro_results += delta
            yield delta, bro_results


if __name__ == '__main__':
    HFP = HFPlaygroundInit()
    r = HFP.generate_chat('你好', {'llm_model': 'Qwen/Qwen2.5-72B-Instruct'}, [], '你是WPSAi')
    for i in r:
        print(i)
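
A small, self-contained check (not part of the commit) of the clamping and rounding behaviour of preprocess_param; the import path request_llms.com_hfplayground is an assumption about where this file lives.

from request_llms.com_hfplayground import HFPlaygroundInit  # assumed module path

print(HFPlaygroundInit.preprocess_param("abc"))   # not a number -> falls back to default 0.95
print(HFPlaygroundInit.preprocess_param(1.50))    # above max_val -> clamped to 0.99
print(HFPlaygroundInit.preprocess_param(0.005))   # below min_val -> clamped to 0.01
print(HFPlaygroundInit.preprocess_param(0.837))   # in range -> rounded to 0.84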