diff --git a/README.md b/README.md
index cdbc6df..8ed80f7 100644
--- a/README.md
+++ b/README.md
@@ -90,9 +90,12 @@ providers:
       - gemini-1.5-flash-exp-0827 # Add this line, both gemini-1.5-flash-exp-0827 and gemini-1.5-flash can be requested
     tools: true
     preferences:
-      API_KEY_RATE_LIMIT: 15/min # Each API Key can request up to 15 times per minute, optional. The default is 999999/min.
-      # API_KEY_RATE_LIMIT: 15/min,10/day # Supports multiple frequency constraints
-      API_KEY_COOLDOWN_PERIOD: 60 # Each API Key will be cooled down for 60 seconds after encountering a 429 error. Optional, the default is 0 seconds. When set to 0, the cooling mechanism is not enabled.
+      api_key_rate_limit: 15/min # Each API key can make up to 15 requests per minute, optional. The default is 999999/min. Multiple rate constraints are supported: 15/min,10/day
+      # api_key_rate_limit: # A different rate limit can be set for each model
+      #   gpt-4o: 3/min
+      #   chatgpt-4o-latest: 2/min
+      #   default: 4/min # Models without a rate limit of their own use the default limit
+      api_key_cooldown_period: 60 # Cooldown time, in seconds, for an API key after it encounters a 429 error. Optional; the default is 0 seconds, which disables the cooldown mechanism. The cooldown only takes effect when there are multiple API keys.
 - provider: vertex
   project_id: gen-lang-client-xxxxxxxxxxxxxx # Description: Your Google Cloud project ID. Format: String, usually composed of lowercase letters, numbers, and hyphens. How to obtain: You can find your project ID in the project selector of the Google Cloud Console.
diff --git a/README_CN.md b/README_CN.md
index bb048d3..4afb0bf 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -90,9 +90,12 @@ providers:
       - gemini-1.5-flash-exp-0827 # Add this line; both gemini-1.5-flash-exp-0827 and gemini-1.5-flash can be requested
     tools: true
     preferences:
-      API_KEY_RATE_LIMIT: 15/min # Maximum number of requests per minute for each API key, optional. The default is 999999/min
-      # API_KEY_RATE_LIMIT: 15/min,10/day # Multiple rate constraints are supported
-      API_KEY_COOLDOWN_PERIOD: 60 # Cooldown time in seconds for each API key after a 429 error, optional. The default is 0 seconds; when set to 0, the cooldown mechanism is disabled.
+      api_key_rate_limit: 15/min # Maximum number of requests per minute for each API key, optional. The default is 999999/min. Multiple rate constraints are supported: 15/min,10/day
+      # api_key_rate_limit: # A different rate limit can be set for each model
+      #   gpt-4o: 3/min
+      #   chatgpt-4o-latest: 2/min
+      #   default: 4/min # Models without a rate limit of their own use the default limit
+      api_key_cooldown_period: 60 # Cooldown time in seconds for each API key after a 429 error, optional. The default is 0 seconds; when set to 0, the cooldown mechanism is disabled. The cooldown only takes effect when there are multiple API keys.
 - provider: vertex
   project_id: gen-lang-client-xxxxxxxxxxxxxx # Description: Your Google Cloud project ID. Format: a string, usually composed of lowercase letters, numbers, and hyphens. How to obtain: you can find your project ID in the project selector of the Google Cloud Console.
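For orientation, here is a minimal sketch of how the new per-model limits sit inside a complete provider entry. The channel name, `base_url`, and API key below are placeholders; only the `preferences` block mirrors the options documented above:

```yaml
providers:
  - provider: my-openai-channel        # placeholder channel name
    base_url: https://api.openai.com/v1/chat/completions
    api: sk-xxxxxxxxxxxxxxxx           # placeholder API key
    model:
      - gpt-4o
      - chatgpt-4o-latest
    preferences:
      api_key_rate_limit:
        gpt-4o: 3/min                  # per-model limit
        chatgpt-4o-latest: 2/min
        default: 4/min                 # any other model falls back to this
      api_key_cooldown_period: 60      # only effective with multiple API keys
```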
"vertex-gemini" - if "o1-preview" in model_dict[request.model] or "o1-mini" in model_dict[request.model]: + if "o1-preview" in original_model or "o1-mini" in original_model: engine = "o1" request.stream = False @@ -702,17 +704,16 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A logger.info(json.dumps(payload, indent=4, ensure_ascii=False)) current_info = request_info.get() - model = model_dict[request.model] timeout_value = None # 先尝试精确匹配 - if model in app.state.timeouts: - timeout_value = app.state.timeouts[model] + if original_model in app.state.timeouts: + timeout_value = app.state.timeouts[original_model] else: # 如果没有精确匹配,尝试模糊匹配 for timeout_model in app.state.timeouts: - if timeout_model in model: + if timeout_model in original_model: timeout_value = app.state.timeouts[timeout_model] break @@ -723,11 +724,11 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A try: async with app.state.client_manager.get_client(timeout_value) as client: if request.stream: - generator = fetch_response_stream(client, url, headers, payload, engine, model) + generator = fetch_response_stream(client, url, headers, payload, engine, original_model) wrapped_generator, first_response_time = await error_handling_wrapper(generator) response = StarletteStreamingResponse(wrapped_generator, media_type="text/event-stream") else: - generator = fetch_response(client, url, headers, payload, engine, model) + generator = fetch_response(client, url, headers, payload, engine, original_model) wrapped_generator, first_response_time = await error_handling_wrapper(generator) first_element = await anext(wrapped_generator) first_element = first_element.lstrip("data: ") @@ -1013,7 +1014,7 @@ async def request_model(self, request: Union[RequestModel, ImageGenerationReques num_matching_providers = len(matching_providers) index = 0 - cooling_time = safe_get(provider, "preferences", "API_KEY_COOLDOWN_PERIOD", default=0) + cooling_time = safe_get(provider, "preferences", "api_key_cooldown_period", default=0) api_key_count = provider_api_circular_list[channel_id].get_items_count() if cooling_time > 0 and api_key_count > 1: current_api = await provider_api_circular_list[channel_id].after_next_current() diff --git a/request.py b/request.py index 241c234..623a357 100644 --- a/request.py +++ b/request.py @@ -125,9 +125,9 @@ async def get_gemini_payload(request, engine, provider): gemini_stream = "streamGenerateContent" url = provider['base_url'] if url.endswith("v1beta"): - url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next()) + url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next(model)) if url.endswith("v1"): - url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next()) + url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next(model)) messages = [] systemInstruction = None @@ -596,8 +596,10 @@ async def get_gpt_payload(request, engine, provider): headers = { 'Content-Type': 
diff --git a/request.py b/request.py
index 241c234..623a357 100644
--- a/request.py
+++ b/request.py
@@ -125,9 +125,9 @@ async def get_gemini_payload(request, engine, provider):
         gemini_stream = "streamGenerateContent"
     url = provider['base_url']
     if url.endswith("v1beta"):
-        url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next())
+        url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next(model))
     if url.endswith("v1"):
-        url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next())
+        url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next(model))

     messages = []
     systemInstruction = None
@@ -596,8 +596,10 @@ async def get_gpt_payload(request, engine, provider):
     headers = {
         'Content-Type': 'application/json',
     }
+    model_dict = get_model_dict(provider)
+    model = model_dict[request.model]
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

     url = provider['base_url']

     messages = []
@@ -637,8 +639,6 @@ async def get_gpt_payload(request, engine, provider):
         else:
             messages.append({"role": msg.role, "content": content})

-    model_dict = get_model_dict(provider)
-    model = model_dict[request.model]
     payload = {
         "model": model,
         "messages": messages,
@@ -663,8 +663,10 @@ async def get_openrouter_payload(request, engine, provider):
     headers = {
         'Content-Type': 'application/json'
     }
+    model_dict = get_model_dict(provider)
+    model = model_dict[request.model]
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

     url = provider['base_url']
@@ -696,8 +698,6 @@ async def get_openrouter_payload(request, engine, provider):
         else:
             messages.append({"role": msg.role, "content": content})

-    model_dict = get_model_dict(provider)
-    model = model_dict[request.model]
     payload = {
         "model": model,
         "messages": messages,
@@ -730,8 +730,10 @@ async def get_cohere_payload(request, engine, provider):
     headers = {
         'Content-Type': 'application/json'
     }
+    model_dict = get_model_dict(provider)
+    model = model_dict[request.model]
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

     url = provider['base_url']
@@ -759,8 +761,6 @@ async def get_cohere_payload(request, engine, provider):
         else:
             messages.append({"role": role_map[msg.role], "message": content})

-    model_dict = get_model_dict(provider)
-    model = model_dict[request.model]
     chat_history = messages[:-1]
     query = messages[-1].get("message")
     payload = {
@@ -798,11 +798,11 @@ async def get_cloudflare_payload(request, engine, provider):
     headers = {
         'Content-Type': 'application/json'
     }
-    if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
-
     model_dict = get_model_dict(provider)
     model = model_dict[request.model]
+    if provider.get("api"):
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
+
     url = "https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/ai/run/{cf_model_id}".format(cf_account_id=provider['cf_account_id'], cf_model_id=model)

     msg = request.messages[-1]
@@ -816,7 +816,6 @@ async def get_cloudflare_payload(request, engine, provider):
             content = msg.content
             name = msg.name

-    model = model_dict[request.model]
     payload = {
         "prompt": content,
     }
@@ -848,8 +847,10 @@ async def get_o1_payload(request, engine, provider):
     headers = {
         'Content-Type': 'application/json'
     }
+    model_dict = get_model_dict(provider)
+    model = model_dict[request.model]
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

     url = provider['base_url']
@@ -871,8 +872,6 @@ async def get_o1_payload(request, engine, provider):
         elif msg.role != "system":
             messages.append({"role": msg.role, "content": content})

-    model_dict = get_model_dict(provider)
-    model = model_dict[request.model]
     payload = {
         "model": model,
         "messages": messages,
@@ -925,7 +924,7 @@ async def get_claude_payload(request, engine, provider):
     model = model_dict[request.model]
     headers = {
         "content-type": "application/json",
-        "x-api-key": f"{await provider_api_circular_list[provider['provider']].next()}",
+        "x-api-key": f"{await provider_api_circular_list[provider['provider']].next(model)}",
         "anthropic-version": "2023-06-01",
         "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in model else "tools-2024-05-16",
     }
@@ -1068,7 +1067,7 @@ async def get_dalle_payload(request, engine, provider):
         "Content-Type": "application/json",
     }
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

     url = provider['base_url']
     url = BaseAPI(url).image_url
@@ -1088,7 +1087,7 @@ async def get_whisper_payload(request, engine, provider):
         # "Content-Type": "multipart/form-data",
     }
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

     url = provider['base_url']
     url = BaseAPI(url).audio_transcriptions
@@ -1115,7 +1114,7 @@ async def get_moderation_payload(request, engine, provider):
         "Content-Type": "application/json",
     }
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

     url = provider['base_url']
     url = BaseAPI(url).moderations
@@ -1132,7 +1131,7 @@ async def get_embedding_payload(request, engine, provider):
         "Content-Type": "application/json",
     }
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

     url = provider['base_url']
     url = BaseAPI(url).embeddings
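The recurring edit across these payload builders is the same: resolve the upstream model name via `get_model_dict` before the Authorization header is built, so that `next(model)` can apply per-model key rotation. A condensed, runnable sketch of that shared pattern — `build_headers` is a hypothetical helper, and the two stubs merely stand in for the repository's `get_model_dict` and `provider_api_circular_list`:

```python
import asyncio
from types import SimpleNamespace

def get_model_dict(provider):
    # Stub: the real helper maps request aliases to upstream model names.
    return {m: m for m in provider["model"]}

class _FakeKeyPool:
    async def next(self, model=None):
        return "sk-placeholder"  # the real pool rotates keys per model

provider_api_circular_list = {"openai": _FakeKeyPool()}

async def build_headers(request, provider) -> dict:
    headers = {"Content-Type": "application/json"}
    model = get_model_dict(provider)[request.model]  # resolve the model first ...
    if provider.get("api"):
        # ... so the key pool can enforce per-model rate limits
        key = await provider_api_circular_list[provider["provider"]].next(model)
        headers["Authorization"] = f"Bearer {key}"
    return headers

provider = {"provider": "openai", "api": "sk-xxx", "model": ["gpt-4o"]}
request = SimpleNamespace(model="gpt-4o")
print(asyncio.run(build_headers(request, provider)))
```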
messages.append({"role": msg.role, "content": content}) - model_dict = get_model_dict(provider) - model = model_dict[request.model] payload = { "model": model, "messages": messages, @@ -925,7 +924,7 @@ async def get_claude_payload(request, engine, provider): model = model_dict[request.model] headers = { "content-type": "application/json", - "x-api-key": f"{await provider_api_circular_list[provider['provider']].next()}", + "x-api-key": f"{await provider_api_circular_list[provider['provider']].next(model)}", "anthropic-version": "2023-06-01", "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in model else "tools-2024-05-16", } @@ -1068,7 +1067,7 @@ async def get_dalle_payload(request, engine, provider): "Content-Type": "application/json", } if provider.get("api"): - headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}" + headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}" url = provider['base_url'] url = BaseAPI(url).image_url @@ -1088,7 +1087,7 @@ async def get_whisper_payload(request, engine, provider): # "Content-Type": "multipart/form-data", } if provider.get("api"): - headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}" + headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}" url = provider['base_url'] url = BaseAPI(url).audio_transcriptions @@ -1115,7 +1114,7 @@ async def get_moderation_payload(request, engine, provider): "Content-Type": "application/json", } if provider.get("api"): - headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}" + headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}" url = provider['base_url'] url = BaseAPI(url).moderations @@ -1132,7 +1131,7 @@ async def get_embedding_payload(request, engine, provider): "Content-Type": "application/json", } if provider.get("api"): - headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}" + headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}" url = provider['base_url'] url = BaseAPI(url).embeddings diff --git a/utils.py b/utils.py index cda0fba..156dcc2 100644 --- a/utils.py +++ b/utils.py @@ -80,13 +80,21 @@ async def get_user_rate_limit(app, api_index: str = None): import asyncio class ThreadSafeCircularList: - def __init__(self, items, rate_limit="99999/min"): + def __init__(self, items, rate_limit={"default": "999999/min"}): self.items = items self.index = 0 self.lock = asyncio.Lock() - self.requests = defaultdict(list) # 用于追踪每个 API key 的请求时间 + # 修改为二级字典,第一级是item,第二级是model + self.requests = defaultdict(lambda: defaultdict(list)) self.cooling_until = defaultdict(float) - self.rate_limits = parse_rate_limit(rate_limit) # 现在返回一个限制条件列表 + self.rate_limits = {} + if isinstance(rate_limit, dict): + for rate_limit_model, rate_limit_value in rate_limit.items(): + self.rate_limits[rate_limit_model] = parse_rate_limit(rate_limit_value) + elif isinstance(rate_limit, str): + self.rate_limits["default"] = parse_rate_limit(rate_limit) + else: + logger.error(f"Error ThreadSafeCircularList: Unknown rate_limit type: {type(rate_limit)}, rate_limit: {rate_limit}") async def set_cooling(self, item: str, cooling_time: int = 60): """设置某个 item 进入冷却状态 @@ -102,36 +110,58 @@ async def set_cooling(self, item: str, cooling_time: int = 
@@ -102,36 +110,58 @@ async def set_cooling(self, item: str, cooling_time: int = 60):
             # self.requests[item] = []
         logger.warning(f"API key {item} has entered cooldown for {cooling_time} seconds")

-    async def is_rate_limited(self, item) -> bool:
+    async def is_rate_limited(self, item, model: str = None) -> bool:
         now = time()
         # Check whether the key is still cooling down
         if now < self.cooling_until[item]:
             return True

+        # Determine the applicable rate limit
+        if model:
+            model_key = model
+        else:
+            model_key = "default"
+
+        rate_limit = None
+        # Try an exact match first
+        if model and model in self.rate_limits:
+            rate_limit = self.rate_limits[model]
+        else:
+            # If there is no exact match, fall back to fuzzy (substring) matching
+            for limit_model in self.rate_limits:
+                if limit_model != "default" and model and limit_model in model:
+                    rate_limit = self.rate_limits[limit_model]
+                    break
+
+        # If nothing matched, fall back to the default
+        if rate_limit is None:
+            rate_limit = self.rate_limits.get("default", [(999999, 60)]) # default limit
+
         # Check every rate limit condition
-        for limit_count, limit_period in self.rate_limits:
-            # Count the requests inside the current time window instead of mutating the request list
-            recent_requests = sum(1 for req in self.requests[item] if req > now - limit_period)
+        for limit_count, limit_period in rate_limit:
+            # Count using the request records of this specific model
+            recent_requests = sum(1 for req in self.requests[item][model_key] if req > now - limit_period)
             if recent_requests >= limit_count:
-                logger.warning(f"API key {item} has reached the rate limit ({limit_count}/{limit_period}s)")
+                logger.warning(f"API key {item} has reached the rate limit for model {model_key} ({limit_count}/{limit_period}s)")
                 return True

-        # Clean up request records that are too old (older than the longest time window)
-        max_period = max(period for _, period in self.rate_limits)
-        self.requests[item] = [req for req in self.requests[item] if req > now - max_period]
+        # Clean up request records that are too old
+        max_period = max(period for _, period in rate_limit)
+        self.requests[item][model_key] = [req for req in self.requests[item][model_key] if req > now - max_period]

-        # All limit conditions passed; record the new request
-        self.requests[item].append(now)
+        # Record the new request
+        self.requests[item][model_key].append(now)
         return False

-    async def next(self):
+    async def next(self, model: str = None):
         async with self.lock:
             start_index = self.index
             while True:
                 item = self.items[self.index]
                 self.index = (self.index + 1) % len(self.items)

-                if not await self.is_rate_limited(item):
+                if not await self.is_rate_limited(item, model):
                     return item

                 # If every API key has been checked and all of them are rate-limited
@@ -220,12 +250,12 @@ def update_config(config_data, use_config_url=False):
             if isinstance(provider_api, str):
                 provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(
                     [provider_api],
-                    safe_get(provider, "preferences", "API_KEY_RATE_LIMIT", default="999999/min")
+                    safe_get(provider, "preferences", "api_key_rate_limit", default={"default": "999999/min"})
                 )
             if isinstance(provider_api, list):
                 provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(
                     provider_api,
-                    safe_get(provider, "preferences", "API_KEY_RATE_LIMIT", default="999999/min")
+                    safe_get(provider, "preferences", "api_key_rate_limit", default={"default": "999999/min"})
                 )

         if not provider.get("model"):
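Putting the pieces together, a minimal usage sketch of the per-model key pool, assuming `ThreadSafeCircularList` is importable from the repository's `utils.py`; the keys and limits below are placeholders:

```python
import asyncio
from utils import ThreadSafeCircularList

async def main():
    pool = ThreadSafeCircularList(
        ["key-A", "key-B"],                        # placeholder API keys
        {"gpt-4o": "3/min", "default": "4/min"},   # per-model limits
    )
    # Requests are now counted per (key, model): a key that has exhausted
    # its gpt-4o window can still be handed out for other models.
    key_for_gpt = await pool.next("gpt-4o")
    key_for_gemini = await pool.next("gemini-1.5-flash")  # no match, so "default" applies
    print(key_for_gpt, key_for_gemini)

asyncio.run(main())
```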