Commit: some bug fix

xd2333 committed Jun 5, 2024
1 parent 8ff2208 commit aa2c884

Showing 5 changed files with 83 additions and 61 deletions.
58 changes: 31 additions & 27 deletions GalTransl/Backend/SakuraTranslate.py
@@ -14,7 +14,7 @@
Sakura_TRANS_PROMPT010,
Sakura_SYSTEM_PROMPT010,
GalTransl_SYSTEM_PROMPT,
-GalTransl_TRANS_PROMPT
+GalTransl_TRANS_PROMPT,
)


@@ -76,10 +76,14 @@ def __init__(

def init_chatbot(self, eng_type, config: CProjectConfig):
from GalTransl.Backend.revChatGPT.V3 import Chatbot as ChatbotV3
section_name = "SakuraLLM" if "SakuraLLM" in config.keyValues else "Sakura"

backendSpecific = config.projectConfig["backendSpecific"]
section_name = "SakuraLLM" if "SakuraLLM" in backendSpecific else "Sakura"
endpoint = self.endpoint
endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
eng_name = config.getBackendConfigSection(section_name).get("rewriteModelName", "gpt-3.5-turbo")
eng_name = config.getBackendConfigSection(section_name).get(
"rewriteModelName", "gpt-3.5-turbo"
)
if eng_type == "sakura-009":
self.system_prompt = Sakura_SYSTEM_PROMPT
self.trans_prompt = Sakura_TRANS_PROMPT
@@ -90,11 +94,11 @@ def init_chatbot(self, eng_type, config: CProjectConfig):
self.system_prompt = GalTransl_SYSTEM_PROMPT
self.trans_prompt = GalTransl_TRANS_PROMPT
self.chatbot = ChatbotV3(
api_key="sk-114514",
system_prompt=self.system_prompt,
engine=eng_name,
api_address=endpoint + "/v1/chat/completions",
timeout=60,
api_key="sk-114514",
system_prompt=self.system_prompt,
engine=eng_name,
api_address=endpoint + "/v1/chat/completions",
timeout=60,
)
self.chatbot.update_proxy(
self.proxyProvider.getProxy().addr if self.proxyProvider else None # type: ignore
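Note on the change above: init_chatbot now reads the section name from the project's backendSpecific block instead of probing config.keyValues, preferring the new "SakuraLLM" section while keeping the legacy "Sakura" name as a fallback. A minimal sketch of that backwards-compatible lookup, using hypothetical config dicts:

def resolve_backend_section(backend_specific: dict) -> dict:
    # Prefer the new "SakuraLLM" section; fall back to the legacy "Sakura" key.
    name = "SakuraLLM" if "SakuraLLM" in backend_specific else "Sakura"
    return backend_specific[name]

legacy = {"Sakura": {"endpoint": "http://127.0.0.1:8080"}}
current = {"SakuraLLM": {"endpoints": ["http://127.0.0.1:8080"], "rewriteModelName": ""}}
assert resolve_backend_section(legacy) == legacy["Sakura"]
assert resolve_backend_section(current) == current["SakuraLLM"]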
@@ -160,7 +164,7 @@ async def translate(self, trans_list: CTransList, gptdict=""):
LOGGER.info("-> [请求错误]报错:%s, 即将重试" % ex)
await asyncio.sleep(3)
continue

resp = resp.replace("*EOF*", "").strip()
result_list = resp.strip("\n").split("\n")
# fix trick
@@ -243,10 +247,8 @@ async def translate(self, trans_list: CTransList, gptdict=""):
if len(trans_list) > 1:
LOGGER.warning("-> 对半拆分重试")
half_len = len(trans_list) // 3
-half_len=1 if half_len<1 else half_len
-return await self.translate(
-    trans_list[: half_len], gptdict
-)
+half_len = 1 if half_len < 1 else half_len
+return await self.translate(trans_list[:half_len], gptdict)
# retries are only counted once the batch is down to single lines
self.retry_count += 1
# after 5 retries, fall back to the original text
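The retry branch above shrinks a failing batch to a third of its size (at least one line) and recurses; the retry counter only starts once the batch is a single line, and after five failures the source text is used as-is. A minimal sketch of that degradation strategy, with a hypothetical translate_batch standing in for the real request:

import asyncio

async def translate_batch(lines: list[str]) -> list[str]:
    raise NotImplementedError  # hypothetical stand-in for the real LLM request

async def translate_with_split(lines: list[str], retry_count: int = 0) -> list[str]:
    try:
        return await translate_batch(lines)
    except Exception:
        if len(lines) > 1:
            # Retry on a third of the batch (at least one line); the caller
            # resubmits whatever was not consumed.
            part = max(1, len(lines) // 3)
            return await translate_with_split(lines[:part])
        if retry_count >= 5:
            return lines  # give up: fall back to the untranslated source
        await asyncio.sleep(3)
        return await translate_with_split(lines, retry_count + 1)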
@@ -308,7 +310,7 @@ async def batch_translate(

trans_result_list = []
len_trans_list = len(trans_list_unhit)
-transl_step_count=0
+transl_step_count = 0
while i < len_trans_list:
# await asyncio.sleep(1)

@@ -320,16 +322,16 @@
)
num, trans_result = await self.translate(trans_list_split, dic_prompt)

-if self.transl_dropout>0 and num==num_pre_request:
-    if self.transl_dropout<num:
-        num-=self.transl_dropout
-        trans_result=trans_result[:num]
+if self.transl_dropout > 0 and num == num_pre_request:
+    if self.transl_dropout < num:
+        num -= self.transl_dropout
+        trans_result = trans_result[:num]

i += num if num > 0 else 0
-transl_step_count+=1
-if transl_step_count>=self.save_steps:
+transl_step_count += 1
+if transl_step_count >= self.save_steps:
     save_transCache_to_json(trans_list, cache_file_path)
-    transl_step_count=0
+    transl_step_count = 0
LOGGER.info("".join([repr(tran) for tran in trans_result]))
trans_result_list += trans_result
LOGGER.info(f"{filename}: {len(trans_result_list)}/{len_trans_list}")
@@ -355,12 +357,14 @@ def _del_previous_message(self) -> None:
if last_user_message:
self.chatbot.conversation["default"].append(last_user_message)
if last_assistant_message:
last_assistant_message["content"] = last_assistant_message["content"].replace("*EOF*", "").strip()
if self.transl_dropout>0:
sp=last_assistant_message["content"].split("\n")
if len(sp)>self.transl_dropout:
sp=sp[:0-self.transl_dropout]
last_assistant_message["content"]="\n".join(sp)
last_assistant_message["content"] = (
last_assistant_message["content"].replace("*EOF*", "").strip()
)
if self.transl_dropout > 0:
sp = last_assistant_message["content"].split("\n")
if len(sp) > self.transl_dropout:
sp = sp[: 0 - self.transl_dropout]
last_assistant_message["content"] = "\n".join(sp)
self.chatbot.conversation["default"].append(last_assistant_message)

def _del_last_answer(self):
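Since dropped lines are never returned to the caller, _del_previous_message has to trim the stored assistant reply the same way; otherwise the conversation history would contain answers for lines the model is about to be asked again. A minimal sketch of that trim, assuming the message is a plain dict:

def trim_dropped_lines(message: dict, dropout: int) -> dict:
    # Strip the *EOF* sentinel, then cut the last `dropout` lines so the
    # kept history matches what was actually delivered downstream.
    lines = message["content"].replace("*EOF*", "").strip().split("\n")
    if dropout > 0 and len(lines) > dropout:
        lines = lines[:-dropout]
    message["content"] = "\n".join(lines)
    return message

msg = {"content": "line1\nline2\nline3\n*EOF*"}
assert trim_dropped_lines(msg, 1)["content"] == "line1\nline2"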
74 changes: 46 additions & 28 deletions GalTransl/Frontend/GPT.py
@@ -32,7 +32,7 @@

async def doLLMTranslateSingleFile(
semaphore: Semaphore,
-endpoint_queue : Queue,
+endpoint_queue: Queue,
file_path: str,
projectConfig: CProjectConfig,
eng_type: str,
@@ -47,7 +47,7 @@ async def doLLMTranslateSingleFile(
async with semaphore:
if endpoint_queue is not None:
endpoint = await endpoint_queue.get()
try:
st = time()
proj_dir = projectConfig.getProjectDir()
input_dir = projectConfig.getInputPath()
@@ -60,25 +60,35 @@
output_file_dir = dirname(output_file_path)
makedirs(output_file_dir, exist_ok=True)
cache_file_path = joinpath(cache_dir, file_name)
LOGGER.info(f"engine type: {eng_type}, file: {file_name}, start translating..")
LOGGER.info(
f"engine type: {eng_type}, file: {file_name}, start translating.."
)

match eng_type:
case "gpt35-0613" | "gpt35-1106" | "gpt35-0125":
-gptapi = CGPT35Translate(projectConfig, eng_type, proxyPool, tokenPool)
+gptapi = CGPT35Translate(
+    projectConfig, eng_type, proxyPool, tokenPool
+)
case "gpt4" | "gpt4-turbo":
-gptapi = CGPT4Translate(projectConfig, eng_type, proxyPool, tokenPool)
+gptapi = CGPT4Translate(
+    projectConfig, eng_type, proxyPool, tokenPool
+)
case "newbing":
cookiePool: list[str] = []
for i in projectConfig.getBackendConfigSection("bingGPT4")["cookiePath"]:
for i in projectConfig.getBackendConfigSection("bingGPT4")[
"cookiePath"
]:
cookiePool.append(joinpath(projectConfig.getProjectDir(), i))
gptapi = CBingGPT4Translate(projectConfig, cookiePool, proxyPool)
case "sakura-009" | "sakura-010" | "galtransl-v1":
-gptapi = CSakuraTranslate(projectConfig, eng_type, endpoint, proxyPool)
+gptapi = CSakuraTranslate(
+    projectConfig, eng_type, endpoint, proxyPool
+)
case "rebuildr" | "rebuilda":
gptapi = CRebuildTranslate(projectConfig, eng_type)
case _:
raise ValueError(f"Unsupported translation engine type {eng_type}")

# 1. Initialize trans_list
origin_input = ""

@@ -196,11 +206,13 @@ async def doLLMTranslateSingleFile(
LOGGER.info(f"文件 {file_name} 翻译完成,用时 {et-st:.3f}s.")
return True


async def run_task(task, progress_bar):
result = await task # Wait for the individual task to complete
progress_bar.update(1) # Update the progress bar
return result


async def doLLMTranslate(
projectConfig: CProjectConfig,
tokenPool: COpenAITokenPool,
@@ -220,44 +232,50 @@ async def doLLMTranslate(
gpt_dic = CGptDict(initDictList(gpt_dic_dir, default_dic_dir, project_dir))

workersPerProject = projectConfig.getKey("workersPerProject")

if "sakura" in eng_type or "galtransl" in eng_type:
endpoint_queue = Queue()
section_name = "SakuraLLM" if "SakuraLLM" in projectConfig.keyValues else "Sakura"
backendSpecific = projectConfig.projectConfig["backendSpecific"]
section_name = "SakuraLLM" if "SakuraLLM" in backendSpecific else "Sakura"
if "endpoints" in projectConfig.getBackendConfigSection(section_name):
endpoints = projectConfig.getBackendConfigSection(section_name)["endpoints"]
else:
-endpoints = [projectConfig.getBackendConfigSection(section_name)["endpoint"]]
-repeated = (workersPerProject+ len(endpoints) -1) // len(endpoints)
+endpoints = [
+    projectConfig.getBackendConfigSection(section_name)["endpoint"]
+]
+repeated = (workersPerProject + len(endpoints) - 1) // len(endpoints)
for _ in range(repeated):
for endpoint in endpoints:
endpoint_queue.put_nowait(endpoint)
LOGGER.info(f"当前使用 {workersPerProject} 个Sakura worker引擎")
else:
endpoint_queue = None
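The repeated calculation is a ceiling division: every endpoint is enqueued often enough that the queue holds at least workersPerProject entries, so each worker can check an endpoint out and hand it back when its file is done. A minimal sketch of that distribution, assuming an asyncio Queue:

from asyncio import Queue

def build_endpoint_queue(endpoints: list[str], workers: int) -> Queue:
    queue: Queue = Queue()
    # Ceiling division: repeat the endpoint list until at least
    # `workers` entries are available.
    repeated = (workers + len(endpoints) - 1) // len(endpoints)
    for _ in range(repeated):
        for endpoint in endpoints:
            queue.put_nowait(endpoint)
    return queue

# e.g. 5 workers over 2 endpoints -> each endpoint queued 3 times
q = build_endpoint_queue(["http://127.0.0.1:8080", "http://127.0.0.1:8081"], 5)
assert q.qsize() == 6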

file_list = get_file_list(projectConfig.getInputPath())
if not file_list:
raise RuntimeError(f"No files to translate in {projectConfig.getInputPath()}")
semaphore = Semaphore(workersPerProject)
progress_bar = atqdm(total=len(file_list), desc="Processing files")
tasks = [
-run_task(doLLMTranslateSingleFile(
-    semaphore,
-    endpoint_queue,
-    file_name,
-    projectConfig,
-    eng_type,
-    pre_dic,
-    post_dic,
-    gpt_dic,
-    tPlugins,
-    fPlugins,
-    proxyPool,
-    tokenPool,
-), progress_bar)
+run_task(
+    doLLMTranslateSingleFile(
+        semaphore,
+        endpoint_queue,
+        file_name,
+        projectConfig,
+        eng_type,
+        pre_dic,
+        post_dic,
+        gpt_dic,
+        tPlugins,
+        fPlugins,
+        proxyPool,
+        tokenPool,
+    ),
+    progress_bar,
+)
for file_name in file_list
]
# await atqdm.gather(*tasks)
await gather(*tasks) # run
progress_bar.close()
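run_task exists so one shared progress bar ticks as each file finishes, while the Semaphore caps how many files are translated at once. A minimal sketch of the same pattern, with a hypothetical process_file coroutine in place of doLLMTranslateSingleFile:

import asyncio
from tqdm.asyncio import tqdm as atqdm

async def process_file(sem: asyncio.Semaphore, name: str) -> str:
    async with sem:               # at most `workers` files in flight
        await asyncio.sleep(0.1)  # stand-in for the real translation work
        return name

async def run_task(task, progress_bar):
    result = await task       # wait for the individual task to complete
    progress_bar.update(1)    # then tick the shared progress bar
    return result

async def main(files: list[str], workers: int):
    sem = asyncio.Semaphore(workers)
    bar = atqdm(total=len(files), desc="Processing files")
    results = await asyncio.gather(
        *(run_task(process_file(sem, f), bar) for f in files)
    )
    bar.close()
    return results

asyncio.run(main([f"file{i}" for i in range(8)], workers=3))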
@@ -6,7 +6,7 @@
from GalTransl.GTPlugin import GFilePlugin


webvtt_path = os.path.abspath("plugins/file_subtitle_vtt/webvtt")
webvtt_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(webvtt_path)
import webvtt
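The path fix anchors the vendored webvtt package to the plugin's own directory rather than a path relative to the current working directory, so the import succeeds no matter where GalTransl is launched from. The general pattern:

import os
import sys

# Resolve the directory containing *this* file, not the process CWD,
# and make vendored packages inside it importable.
plugin_dir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(plugin_dir)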

@@ -55,7 +55,7 @@ def load_file(self, file_path: str) -> list:
raise e
elif file_path.endswith(".lrc"):
try:
-matches = self.pattern.findall(text)
+matches = self.lrc_pattern.findall(text)
LOGGER.debug(f"matches: {matches}")
result = [
{
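The fix points the .lrc branch at a dedicated lrc_pattern attribute instead of a shared self.pattern. The actual regex is not shown in this hunk; a plausible sketch, assuming standard [mm:ss.xx] LRC timestamps:

import re

# Hypothetical pattern; the real lrc_pattern is defined elsewhere in the plugin.
lrc_pattern = re.compile(r"\[(\d{1,2}):(\d{1,2}(?:\.\d{1,3})?)\](.*)")

text = "[00:12.34]first line\n[00:15.00]second line"
matches = lrc_pattern.findall(text)
# -> [('00', '12.34', 'first line'), ('00', '15.00', 'second line')]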
@@ -1,13 +1,13 @@
Core:
Name: srt,lrc,vtt subtitle file reader
Type: file
-Module: file_subtitle_srt_lrc
+Module: file_subtitle_srt_lrc_vtt

Documentation:
Author: cx2333, PiDanShouRouZhou
Version: 1.0
Description: Reads and translates .srt|.lrc|.vtt subtitle files directly.

Settings: # plugin settings go here
-保存双语字幕: true # whether to save bilingual subtitles
+保存双语字幕: false # whether to save bilingual subtitles [true|false]
上下双语1左右双语2: 1 # bilingual layout: 1 = stacked (top/bottom), 2 = side by side (left/right)
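The two settings decide whether a bilingual subtitle file is written and how each original/translation pair is laid out. A minimal sketch of what the two layout modes plausibly produce (the plugin's real formatting may differ):

def combine_bilingual(original: str, translated: str, mode: int) -> str:
    # Mode 1: stacked, translation above the original (top/bottom).
    # Mode 2: side by side on one line (left/right).
    if mode == 1:
        return f"{translated}\n{original}"
    return f"{translated}  {original}"

print(combine_bilingual("こんにちは", "Hello", 1))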
4 changes: 2 additions & 2 deletions sampleProject/config.inc.yaml
@@ -74,8 +74,8 @@ backendSpecific:
- newbing_cookies/cookie2.json # your second cookies file; switched to automatically
SakuraLLM: # Sakura/Galtransl API
endpoints:
-- http://127.0.0.1:5000
-- http://127.0.0.1:5001
+- http://127.0.0.1:8080
+#- http://127.0.0.1:5001 # multiple endpoints can be listed, for multithreading
rewriteModelName: "" # 使用指定的模型型号替换默认模型
# automatic problem analysis configuration
problemAnalyze: