Skip to content

Commit

Permalink
✨ Feature: Add feature: support setting rate limit for each model individually
Browse files Browse the repository at this point in the history
  • Loading branch information
yym68686 committed Nov 5, 2024
1 parent 1778d52 commit cdf3ed9
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 59 deletions.
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,12 @@ providers:
- gemini-1.5-flash-exp-0827 # Add this line, both gemini-1.5-flash-exp-0827 and gemini-1.5-flash can be requested
tools: true
preferences:
API_KEY_RATE_LIMIT: 15/min # Each API Key can request up to 15 times per minute, optional. The default is 999999/min.
# API_KEY_RATE_LIMIT: 15/min,10/day # Supports multiple frequency constraints
API_KEY_COOLDOWN_PERIOD: 60 # Each API Key will be cooled down for 60 seconds after encountering a 429 error. Optional, the default is 0 seconds. When set to 0, the cooling mechanism is not enabled.
api_key_rate_limit: 15/min # Each API Key can request up to 15 times per minute, optional. The default is 999999/min. Supports multiple frequency constraints: 15/min,10/day
# api_key_rate_limit: # You can set different frequency limits for each model
# gpt-4o: 3/min
# chatgpt-4o-latest: 2/min
# default: 4/min # If the model does not set the frequency limit, use the frequency limit of default
api_key_cooldown_period: 60 # Each API Key will be cooled down for 60 seconds after encountering a 429 error. Optional, the default is 0 seconds. When set to 0, the cooling mechanism is not enabled. When there are multiple API keys, the cooling mechanism will take effect.

- provider: vertex
project_id: gen-lang-client-xxxxxxxxxxxxxx # Description: Your Google Cloud project ID. Format: String, usually composed of lowercase letters, numbers, and hyphens. How to obtain: You can find your project ID in the project selector of the Google Cloud Console.
Expand Down
9 changes: 6 additions & 3 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,12 @@ providers:
- gemini-1.5-flash-exp-0827 # 加上这一行,gemini-1.5-flash-exp-0827 和 gemini-1.5-flash 都可以被请求
tools: true
preferences:
API_KEY_RATE_LIMIT: 15/min # 每个 API Key 每分钟最多请求次数,选填。默认为 999999/min
# API_KEY_RATE_LIMIT: 15/min,10/day # 支持多个频率约束条件
API_KEY_COOLDOWN_PERIOD: 60 # 每个 API Key 遭遇 429 错误后的冷却时间,单位为秒,选填。默认为 0 秒, 当设置为 0 秒时,不启用冷却机制。
api_key_rate_limit: 15/min # 每个 API Key 每分钟最多请求次数,选填。默认为 999999/min。支持多个频率约束条件:15/min,10/day
# api_key_rate_limit: # 可以为每个模型设置不同的频率限制
# gpt-4o: 3/min
# chatgpt-4o-latest: 2/min
# default: 4/min # 如果模型没有设置频率限制,使用 default 的频率限制
api_key_cooldown_period: 60 # 每个 API Key 遭遇 429 错误后的冷却时间,单位为秒,选填。默认为 0 秒, 当设置为 0 秒时,不启用冷却机制。当存在多个 API key 时才会生效。

- provider: vertex
project_id: gen-lang-client-xxxxxxxxxxxxxx # 描述: 您的Google Cloud项目ID。格式: 字符串,通常由小写字母、数字和连字符组成。获取方式: 在Google Cloud Console的项目选择器中可以找到您的项目ID。
Expand Down
27 changes: 14 additions & 13 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,20 +655,22 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
engine = "gpt"

model_dict = get_model_dict(provider)
if "claude" not in model_dict[request.model] \
and "gpt" not in model_dict[request.model] \
and "gemini" not in model_dict[request.model] \
original_model = model_dict[request.model]

if "claude" not in original_model \
and "gpt" not in original_model \
and "gemini" not in original_model \
and parsed_url.netloc != 'api.cloudflare.com' \
and parsed_url.netloc != 'api.cohere.com':
engine = "openrouter"

if "claude" in model_dict[request.model] and engine == "vertex":
if "claude" in original_model and engine == "vertex":
engine = "vertex-claude"

if "gemini" in model_dict[request.model] and engine == "vertex":
if "gemini" in original_model and engine == "vertex":
engine = "vertex-gemini"

if "o1-preview" in model_dict[request.model] or "o1-mini" in model_dict[request.model]:
if "o1-preview" in original_model or "o1-mini" in original_model:
engine = "o1"
request.stream = False

Expand Down Expand Up @@ -702,17 +704,16 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
logger.info(json.dumps(payload, indent=4, ensure_ascii=False))

current_info = request_info.get()
model = model_dict[request.model]

timeout_value = None
# 先尝试精确匹配

if model in app.state.timeouts:
timeout_value = app.state.timeouts[model]
if original_model in app.state.timeouts:
timeout_value = app.state.timeouts[original_model]
else:
# 如果没有精确匹配,尝试模糊匹配
for timeout_model in app.state.timeouts:
if timeout_model in model:
if timeout_model in original_model:
timeout_value = app.state.timeouts[timeout_model]
break

Expand All @@ -723,11 +724,11 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
try:
async with app.state.client_manager.get_client(timeout_value) as client:
if request.stream:
generator = fetch_response_stream(client, url, headers, payload, engine, model)
generator = fetch_response_stream(client, url, headers, payload, engine, original_model)
wrapped_generator, first_response_time = await error_handling_wrapper(generator)
response = StarletteStreamingResponse(wrapped_generator, media_type="text/event-stream")
else:
generator = fetch_response(client, url, headers, payload, engine, model)
generator = fetch_response(client, url, headers, payload, engine, original_model)
wrapped_generator, first_response_time = await error_handling_wrapper(generator)
first_element = await anext(wrapped_generator)
first_element = first_element.lstrip("data: ")
Expand Down Expand Up @@ -1013,7 +1014,7 @@ async def request_model(self, request: Union[RequestModel, ImageGenerationReques
num_matching_providers = len(matching_providers)
index = 0

cooling_time = safe_get(provider, "preferences", "API_KEY_COOLDOWN_PERIOD", default=0)
cooling_time = safe_get(provider, "preferences", "api_key_cooldown_period", default=0)
api_key_count = provider_api_circular_list[channel_id].get_items_count()
if cooling_time > 0 and api_key_count > 1:
current_api = await provider_api_circular_list[channel_id].after_next_current()
Expand Down
45 changes: 22 additions & 23 deletions request.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ async def get_gemini_payload(request, engine, provider):
gemini_stream = "streamGenerateContent"
url = provider['base_url']
if url.endswith("v1beta"):
url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next())
url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next(model))
if url.endswith("v1"):
url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next())
url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next(model))

messages = []
systemInstruction = None
Expand Down Expand Up @@ -596,8 +596,10 @@ async def get_gpt_payload(request, engine, provider):
headers = {
'Content-Type': 'application/json',
}
model_dict = get_model_dict(provider)
model = model_dict[request.model]
if provider.get("api"):
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
url = provider['base_url']

messages = []
Expand Down Expand Up @@ -637,8 +639,6 @@ async def get_gpt_payload(request, engine, provider):
else:
messages.append({"role": msg.role, "content": content})

model_dict = get_model_dict(provider)
model = model_dict[request.model]
payload = {
"model": model,
"messages": messages,
Expand All @@ -663,8 +663,10 @@ async def get_openrouter_payload(request, engine, provider):
headers = {
'Content-Type': 'application/json'
}
model_dict = get_model_dict(provider)
model = model_dict[request.model]
if provider.get("api"):
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

url = provider['base_url']

Expand Down Expand Up @@ -696,8 +698,6 @@ async def get_openrouter_payload(request, engine, provider):
else:
messages.append({"role": msg.role, "content": content})

model_dict = get_model_dict(provider)
model = model_dict[request.model]
payload = {
"model": model,
"messages": messages,
Expand Down Expand Up @@ -730,8 +730,10 @@ async def get_cohere_payload(request, engine, provider):
headers = {
'Content-Type': 'application/json'
}
model_dict = get_model_dict(provider)
model = model_dict[request.model]
if provider.get("api"):
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

url = provider['base_url']

Expand Down Expand Up @@ -759,8 +761,6 @@ async def get_cohere_payload(request, engine, provider):
else:
messages.append({"role": role_map[msg.role], "message": content})

model_dict = get_model_dict(provider)
model = model_dict[request.model]
chat_history = messages[:-1]
query = messages[-1].get("message")
payload = {
Expand Down Expand Up @@ -798,11 +798,11 @@ async def get_cloudflare_payload(request, engine, provider):
headers = {
'Content-Type': 'application/json'
}
if provider.get("api"):
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"

model_dict = get_model_dict(provider)
model = model_dict[request.model]
if provider.get("api"):
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

url = "https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/ai/run/{cf_model_id}".format(cf_account_id=provider['cf_account_id'], cf_model_id=model)

msg = request.messages[-1]
Expand All @@ -816,7 +816,6 @@ async def get_cloudflare_payload(request, engine, provider):
content = msg.content
name = msg.name

model = model_dict[request.model]
payload = {
"prompt": content,
}
Expand Down Expand Up @@ -848,8 +847,10 @@ async def get_o1_payload(request, engine, provider):
headers = {
'Content-Type': 'application/json'
}
model_dict = get_model_dict(provider)
model = model_dict[request.model]
if provider.get("api"):
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

url = provider['base_url']

Expand All @@ -871,8 +872,6 @@ async def get_o1_payload(request, engine, provider):
elif msg.role != "system":
messages.append({"role": msg.role, "content": content})

model_dict = get_model_dict(provider)
model = model_dict[request.model]
payload = {
"model": model,
"messages": messages,
Expand Down Expand Up @@ -925,7 +924,7 @@ async def get_claude_payload(request, engine, provider):
model = model_dict[request.model]
headers = {
"content-type": "application/json",
"x-api-key": f"{await provider_api_circular_list[provider['provider']].next()}",
"x-api-key": f"{await provider_api_circular_list[provider['provider']].next(model)}",
"anthropic-version": "2023-06-01",
"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in model else "tools-2024-05-16",
}
Expand Down Expand Up @@ -1068,7 +1067,7 @@ async def get_dalle_payload(request, engine, provider):
"Content-Type": "application/json",
}
if provider.get("api"):
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
url = provider['base_url']
url = BaseAPI(url).image_url

Expand All @@ -1088,7 +1087,7 @@ async def get_whisper_payload(request, engine, provider):
# "Content-Type": "multipart/form-data",
}
if provider.get("api"):
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
url = provider['base_url']
url = BaseAPI(url).audio_transcriptions

Expand All @@ -1115,7 +1114,7 @@ async def get_moderation_payload(request, engine, provider):
"Content-Type": "application/json",
}
if provider.get("api"):
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
url = provider['base_url']
url = BaseAPI(url).moderations

Expand All @@ -1132,7 +1131,7 @@ async def get_embedding_payload(request, engine, provider):
"Content-Type": "application/json",
}
if provider.get("api"):
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
url = provider['base_url']
url = BaseAPI(url).embeddings

Expand Down
64 changes: 47 additions & 17 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,21 @@ async def get_user_rate_limit(app, api_index: str = None):
import asyncio

class ThreadSafeCircularList:
def __init__(self, items, rate_limit="99999/min"):
def __init__(self, items, rate_limit={"default": "999999/min"}):
self.items = items
self.index = 0
self.lock = asyncio.Lock()
self.requests = defaultdict(list) # 用于追踪每个 API key 的请求时间
# 修改为二级字典,第一级是item,第二级是model
self.requests = defaultdict(lambda: defaultdict(list))
self.cooling_until = defaultdict(float)
self.rate_limits = parse_rate_limit(rate_limit) # 现在返回一个限制条件列表
self.rate_limits = {}
if isinstance(rate_limit, dict):
for rate_limit_model, rate_limit_value in rate_limit.items():
self.rate_limits[rate_limit_model] = parse_rate_limit(rate_limit_value)
elif isinstance(rate_limit, str):
self.rate_limits["default"] = parse_rate_limit(rate_limit)
else:
logger.error(f"Error ThreadSafeCircularList: Unknown rate_limit type: {type(rate_limit)}, rate_limit: {rate_limit}")

async def set_cooling(self, item: str, cooling_time: int = 60):
"""设置某个 item 进入冷却状态
Expand All @@ -102,36 +110,58 @@ async def set_cooling(self, item: str, cooling_time: int = 60):
# self.requests[item] = []
logger.warning(f"API key {item} 已进入冷却状态,冷却时间 {cooling_time} 秒")

async def is_rate_limited(self, item) -> bool:
async def is_rate_limited(self, item, model: str = None) -> bool:
now = time()
# 检查是否在冷却中
if now < self.cooling_until[item]:
return True

# 获取适用的速率限制

if model:
model_key = model
else:
model_key = "default"

rate_limit = None
# 先尝试精确匹配
if model and model in self.rate_limits:
rate_limit = self.rate_limits[model]
else:
# 如果没有精确匹配,尝试模糊匹配
for limit_model in self.rate_limits:
if limit_model != "default" and model and limit_model in model:
rate_limit = self.rate_limits[limit_model]
break

# 如果都没匹配到,使用默认值
if rate_limit is None:
rate_limit = self.rate_limits.get("default", [(999999, 60)]) # 默认限制

# 检查所有速率限制条件
for limit_count, limit_period in self.rate_limits:
# 计算在当前时间窗口内的请求数量,而不是直接修改请求列表
recent_requests = sum(1 for req in self.requests[item] if req > now - limit_period)
for limit_count, limit_period in rate_limit:
# 使用特定模型的请求记录进行计算
recent_requests = sum(1 for req in self.requests[item][model_key] if req > now - limit_period)
if recent_requests >= limit_count:
logger.warning(f"API key {item} 已达到速率限制 ({limit_count}/{limit_period}秒)")
logger.warning(f"API key {item} 对模型 {model_key} 已达到速率限制 ({limit_count}/{limit_period}秒)")
return True

# 清理太旧的请求记录(比最长时间窗口还要老的记录)
max_period = max(period for _, period in self.rate_limits)
self.requests[item] = [req for req in self.requests[item] if req > now - max_period]
# 清理太旧的请求记录
max_period = max(period for _, period in rate_limit)
self.requests[item][model_key] = [req for req in self.requests[item][model_key] if req > now - max_period]

# 所有限制条件都通过,记录新的请求
self.requests[item].append(now)
# 记录新的请求
self.requests[item][model_key].append(now)
return False

async def next(self):
async def next(self, model: str = None):
async with self.lock:
start_index = self.index
while True:
item = self.items[self.index]
self.index = (self.index + 1) % len(self.items)

if not await self.is_rate_limited(item):
if not await self.is_rate_limited(item, model):
return item

# 如果已经检查了所有的 API key 都被限制
Expand Down Expand Up @@ -220,12 +250,12 @@ def update_config(config_data, use_config_url=False):
if isinstance(provider_api, str):
provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(
[provider_api],
safe_get(provider, "preferences", "API_KEY_RATE_LIMIT", default="999999/min")
safe_get(provider, "preferences", "api_key_rate_limit", default={"default": "999999/min"})
)
if isinstance(provider_api, list):
provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(
provider_api,
safe_get(provider, "preferences", "API_KEY_RATE_LIMIT", default="999999/min")
safe_get(provider, "preferences", "api_key_rate_limit", default={"default": "999999/min"})
)

if not provider.get("model"):
Expand Down

0 comments on commit cdf3ed9

Please sign in to comment.