
Commit

Sync changes
manjieqi committed Jul 31, 2024
1 parent 633afd7 commit f7ac28c
Showing 3 changed files with 201 additions and 69 deletions.
11 changes: 7 additions & 4 deletions .env.example
@@ -1,14 +1,17 @@
# Listening address
HOST=0.0.0.0
HOST='0.0.0.0'

# Listening port
PORT=5000

# Access password
PASSWORD=xxxxxxxxxxxxxxxx
# Access password (the application's api key)
PASSWORD=

#GCP Project ID
PROJECT_ID=gen-lang-client-xxxxxxxxx
PROJECT_ID=project1, project2, project3

# Access region
REGION=us-east5

# Debug switch
DEBUG=False
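
Note on the new PROJECT_ID format: proxy_server.py simply splits the variable on ', ' into a list of project ids. A minimal sketch of that parsing, using the placeholder ids project1, project2, project3 from this example file:

import os
from dotenv import load_dotenv

load_dotenv()  # assumes .env contains PROJECT_ID=project1, project2, project3
project_ids = os.getenv('PROJECT_ID').split(', ')
print(project_ids)  # ['project1', 'project2', 'project3']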
52 changes: 30 additions & 22 deletions build/main.py
@@ -1,3 +1,5 @@
#!/usr/bin/env python3

import os
import sys
import json
@@ -7,6 +9,8 @@
import pkg_resources
import importlib.util

colorama.init()

def check_requirements():
if getattr(sys, 'frozen', False):
print("从可执行文件启动,跳过依赖项检查。")
@@ -45,7 +49,7 @@ def check_requirements():
print("所有依赖项检查通过。")
return True

#Get the directory location
# Get the directory location
def get_base_path():
if getattr(sys, 'frozen', False):
# If running as a packaged executable
@@ -55,7 +59,7 @@ def get_base_path():
return os.path.dirname(os.path.abspath(__file__))


def check_directory_structure():
def check_directory_structure(project_ids: str):
# Check the current directory
app_base_dir = get_base_path()

@@ -69,10 +73,26 @@

# Check the auth.json file
print(f"Checking the auth.json file..")
auth_file = os.path.join(auth_dir, 'auth.json')
if not os.path.exists(auth_file):
print(f"错误: {auth_dir}下'auth.json'谷歌验证文件缺失!")
return False
default_auth_file = os.path.join(auth_dir, 'auth.json')

if os.path.exists(default_auth_file):
print(f"检测到 auth.json 文件存在,使用单项目:{project_ids[0]}。")
else:
print(f"未检测到 auth.json 文件,检查账号配置...")

missing_files = []
for project_id in project_ids:
project_auth_file = os.path.join(auth_dir, f"{project_id}.json")
if not os.path.exists(project_auth_file):
missing_files.append(f"{project_id}.json")

if missing_files:
print(colorama.Back.RED + colorama.Fore.WHITE + f"Warning: the following project credential files were not found:")
for file in missing_files:
print(f" - {file}")
input(colorama.Style.RESET_ALL + "Press Ctrl+C to exit or press Enter to ignore...")
else:
print(f"Multi-account mode configuration check passed.")
time.sleep(0.1)

# Check the .env file
@@ -103,17 +123,6 @@ def check_directory_structure():

return True

def manage_gcp_auth():
# Check whether a service account is already activated
check_cmd = "gcloud auth list --filter=status:ACTIVE --format='value(account)' | grep -q '@.*\\.iam\\.gserviceaccount\\.com'"
if subprocess.run(check_cmd, shell=True).returncode == 0:
print("GCP service account is already activated.")
else:
print("Activating GCP service account...")
key_file = os.path.join('auth', 'service-account-key.json')
activate_cmd = f"gcloud auth activate-service-account --key-file={key_file}"
subprocess.run(activate_cmd, shell=True, check=True)

def load_proxy_server():
if getattr(sys, 'frozen', False):
module_path = os.path.join(sys._MEIPASS, 'proxy_server.py')
@@ -127,22 +136,21 @@ def load_proxy_server():


def main():
proxy_server = load_proxy_server()
if not check_requirements():
input("依赖项未满足。请安装后重启。")
sys.exit(1)

if not check_directory_structure():
if not check_directory_structure(proxy_server.project_ids):
input("目录必要文件验证失败,取消启动。")
sys.exit(1)

print("目录文件验证成功。")
# manage_gcp_auth()

# Start proxy_server.py
# time.sleep(0.5)
# time.sleep(0.5)
print("Starting server...")
colorama.init()
proxy_server = load_proxy_server()
print(f"DEBUG mode: {proxy_server.debug_mode}")

import uvicorn
uvicorn.run(proxy_server.app, host=proxy_server.hostaddr, port=proxy_server.lsnport)
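
For orientation, a minimal sketch of the credential layout the updated check_directory_structure(project_ids) looks for. The directory path mirrors get_base_path() for the non-frozen case, and the project ids are the placeholders from .env.example, not real settings:

import os

auth_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'auth')
project_ids = ['project1', 'project2', 'project3']  # parsed from PROJECT_ID

if os.path.exists(os.path.join(auth_dir, 'auth.json')):
    # single-project mode: auth/auth.json plus the first PROJECT_ID entry
    print(f"single project: {project_ids[0]}")
else:
    # multi-project mode: one auth/<project_id>.json credential file per entry
    missing = [f"{pid}.json" for pid in project_ids
               if not os.path.exists(os.path.join(auth_dir, f"{pid}.json"))]
    print(f"multi project, missing files: {missing}")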
207 changes: 164 additions & 43 deletions build/proxy_server.py
@@ -1,9 +1,12 @@
import json
import os
import sys
import re
import ast
import random

from typing import Optional
from anthropic import AnthropicVertex
from typing import Optional, Dict, Any
from anthropic import AsyncAnthropicVertex
from dotenv import load_dotenv
from fastapi import FastAPI, Header, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
@@ -30,15 +33,12 @@ def get_base_path():
env_path = os.path.join(get_base_path(), '.env')
load_dotenv(env_path)

os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.join(get_base_path(), 'auth', 'auth.json')
hostaddr = '0.0.0.0' if is_docker else os.getenv('HOST', '127.0.0.1')
lsnport = int(os.getenv('PORT', 5000))
project_id = os.getenv('PROJECT_ID')
project_ids = os.getenv('PROJECT_ID').split(', ')
region = os.getenv('REGION')
password = os.getenv('PASSWORD')

# VertexAI configuration
vertex_client = AnthropicVertex(project_id=project_id, region=region)
debug_mode = os.getenv('DEBUG', 'False').lower() == 'true'

# CORS configuration
app.add_middleware(
@@ -54,6 +54,61 @@ class MessageRequest(BaseModel):
stream: Optional[bool] = False
# Add other possible fields

# Weighted load-balancing algorithm
class WeightedRandomSelector:
def __init__(self, project_ids):
self.projects = {pid: 1 for pid in project_ids}

def get_project(self):
if len(self.projects) <= 1:
return next(iter(self.projects))

total_weight = sum(self.projects.values())
r = random.uniform(0, total_weight)
for pid, weight in self.projects.items():
r -= weight
if r <= 0:
self._update_weights(pid)
return pid
return list(self.projects.keys())[-1]

def _update_weights(self, selected_pid):
decrease = min(self.projects[selected_pid], 0.5)
increase = decrease / (len(self.projects) - 1)

for pid in self.projects:
if pid == selected_pid:
self.projects[pid] -= decrease
else:
self.projects[pid] += increase

def print_weights(self):
print("Current project weights:")
for pid, weight in self.projects.items():
print(f" {pid}: {weight:.4f}")
print() # print an empty line to make the output easier to read

# Create a global WeightedRandomSelector instance
global_selector = WeightedRandomSelector(project_ids)

def load_balance_selector():
default_auth_file = os.path.join(os.path.join(get_base_path(), 'auth'), 'auth.json')
# Check whether auth.json exists
if os.path.exists(default_auth_file):
# If auth.json exists, return the first project_id and auth.json
return project_ids[0], default_auth_file
else:
# If auth.json does not exist, fall back to weighted random selection
project_id = global_selector.get_project()
auth_file = os.path.join(get_base_path(), 'auth', f'{project_id}.json')
if not os.path.exists(auth_file):
# If the credential file does not exist, raise an HTTPException
raise HTTPException(
status_code=500,
detail="No valid authentication file found. Please check your configuration."
)
return project_id, auth_file

def vertex_model(original_model):
# Define the model name mapping
mapping_file = os.path.join(get_base_path(), 'model_mapping.json')
@@ -64,54 +119,120 @@ def vertex_model(original_model):
# Compare the password
def check_auth(api_key: Optional[str]) -> bool:
if not password: # if the password is not set or is an empty string
debug_mode and print("No password set. Skipping...")
return True
return api_key and compare_digest(api_key, password)

def prepare_vertex_request(data: Dict[Any, Any]) -> Dict[Any, Any]:
vertex_request = {}
for key, value in data.items():
if key == 'model':
vertex_request[key] = vertex_model(value)
else:
vertex_request[key] = value
return vertex_request

def parse_vertex_error(error_string):
# Split the error string into two parts
parts = error_string.split(' - ', 1)

# Extract the error code from the first part
error_code_match = re.search(r'Error code: (\d+)', parts[0])
error_code = int(error_code_match.group(1)) if error_code_match else 500

# Parse the second part with ast.literal_eval
try:
error_json = ast.literal_eval(parts[1])
error_message = error_json[0]['error'].get('message', 'Unknown error')
error_type = error_json[0]['error'].get('status') or error_json[0]['error'].get('type') or 'UNKNOWN_ERROR'
except Exception as e:
print(e)
# If parsing fails, fall back to default values
error_message = "Failed to parse error details"
error_type = "PARSE_ERROR"

error_content = {
"type": "error",
"error": {
"type": error_type,
"message": error_message
}
}

return error_code, error_content

async def handle_stream_request(vertex_client, vertex_request: Dict[Any, Any]):
try:
message_iterator = await vertex_client.messages.create(**vertex_request)
async def generate():
try:
is_first_chunk = True
debug_mode and print("Start streaming:")
async for chunk in message_iterator:
response = f"event: {chunk.type}\ndata: {json.dumps(chunk.model_dump())}\n\n"
yield response
# print the response in debug mode
debug_mode and print(response)
if is_first_chunk:
is_first_chunk = False
await vertex_client.close()
debug_mode and print("Stream ended.")
except Exception as e:
error_code, error_content = parse_vertex_error(str(e))
print(e)
yield f"event: error\ndata: {json.dumps(error_content)}\n\n"
await vertex_client.close()

return generate()
except Exception as e:
print(e)
error_code, error_content = parse_vertex_error(str(e))
await vertex_client.close()
return error_code, error_content

async def handle_non_stream_request(vertex_client, vertex_request: Dict[Any, Any]):
try:
response = await vertex_client.messages.create(**vertex_request)
# print the response in debug mode
debug_mode and print(response)
await vertex_client.close()
return JSONResponse(content=response.model_dump(), status_code=200)
except Exception as e:
# Catch any exception and return the error information
print(e)
error_code, error_content = parse_vertex_error(str(e))
await vertex_client.close()
return JSONResponse(content=error_content, status_code=error_code)

@app.post("/v1/messages")
async def proxy_request(request: Request, x_api_key: Optional[str] = Header(None)):
# Password check
if not check_auth(x_api_key):
raise HTTPException(status_code=401, detail="Unauthorized")

# Get the original request data
data = await request.json()

# print("Original request:")
# print(data)
debug_mode and print(f"received request: {data}")

# Prepare the request to send to VertexAI
try:
# Create a new dict to store the request parameters
vertex_request = {}

# Iterate over every key-value pair in the original request
for key, value in data.items():
if key == 'model':
# Convert the model name
vertex_request[key] = vertex_model(value)
else:
# Copy all other parameters as-is
vertex_request[key] = value

# Print the processed request
# print("Processed request:")
# print(json.dumps(vertex_request, indent=2))

# Send the request to VertexAI
# Check whether this is a streaming request
if vertex_request.get('stream', False):
def generate():
yield 'event: ping\ndata: {"type": "ping"}\n\n'
for chunk in vertex_client.messages.create(**vertex_request):
response = f"event: {chunk.type}\ndata: {json.dumps(chunk.model_dump())}\n\n"
# print(f"{response}")
yield response
vertex_request = prepare_vertex_request(data)
project_id, auth_file = load_balance_selector()
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = auth_file
vertex_client = AsyncAnthropicVertex(project_id=project_id, region=region)
print(f"accessing VertexAI via project {project_id}")

return StreamingResponse(generate(), media_type='text/event-stream', headers={'X-Accel-Buffering': 'no'})
if vertex_request.get('stream', False):
result = await handle_stream_request(vertex_client=vertex_client, vertex_request=vertex_request)

if isinstance(result, tuple):
error_code, error_content = result
return JSONResponse(content=error_content, status_code=error_code)
else:
return StreamingResponse(result, media_type='text/event-stream', headers={'X-Accel-Buffering': 'no'})
else:
response = vertex_client.messages.create(**vertex_request)
# print(f"{response}")
return JSONResponse(content=response.model_dump(), status_code=200)

return await handle_non_stream_request(vertex_client=vertex_client, vertex_request=vertex_request)
except Exception as e:
print(e)
return JSONResponse(content={"error": str(e)}, status_code=500)
finally:
# Print the current weights and reset the credentials variable
debug_mode and global_selector.print_weights()
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = ""
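
To illustrate the weighted load-balancing behaviour added above: each call to get_project() lowers the weight of the project it returns (by at most 0.5) and spreads that amount over the others, so recently used projects become less likely to be picked while the total weight stays constant. A small standalone sketch, assuming WeightedRandomSelector has been copied out of this file and using the placeholder ids from .env.example:

import random  # required by WeightedRandomSelector

selector = WeightedRandomSelector(['project1', 'project2', 'project3'])
for _ in range(5):
    pid = selector.get_project()   # picks by weight, then shifts weight away from pid
    print(f"selected: {pid}")
selector.print_weights()           # weights now differ per project but still sum to 3.0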

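And a hypothetical example of the error-string shape parse_vertex_error() expects: an 'Error code: NNN' prefix, a ' - ' separator, then a Python-literal payload. The sample string below is made up for illustration; the exact text raised by the SDK may differ:

sample = "Error code: 429 - [{'error': {'message': 'Quota exceeded', 'status': 'RESOURCE_EXHAUSTED'}}]"
code, content = parse_vertex_error(sample)
print(code)     # 429
print(content)  # {'type': 'error', 'error': {'type': 'RESOURCE_EXHAUSTED', 'message': 'Quota exceeded'}}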