diff --git a/config/examples/unify.yaml b/config/examples/unify.yaml
new file mode 100644
index 000000000..16c0870ea
--- /dev/null
+++ b/config/examples/unify.yaml
@@ -0,0 +1,5 @@
+llm:
+  api_type: "unify"
+  model: "llama-3-8b-chat@together-ai"  # list of available models: https://docs.unify.ai/python/utils#list-models
+  base_url: "https://api.unify.ai/v0"
+  api_key: "Enter your Unify API key here"  # get your API key from https://console.unify.ai
\ No newline at end of file
diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
index e7c280ee3..8736e3234 100644
--- a/metagpt/configs/llm_config.py
+++ b/metagpt/configs/llm_config.py
@@ -34,6 +34,7 @@ class LLMType(Enum):
     OPENROUTER = "openrouter"
     BEDROCK = "bedrock"
     ARK = "ark"  # https://www.volcengine.com/docs/82379/1263482#python-sdk
+    UNIFY = "unify"
 
     def __missing__(self, key):
         return self.OPENAI
diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py
index 31907d9e8..d3533595d 100644
--- a/metagpt/provider/openai_api.py
+++ b/metagpt/provider/openai_api.py
@@ -50,6 +50,7 @@
         LLMType.MISTRAL,
         LLMType.YI,
         LLMType.OPENROUTER,
+        LLMType.UNIFY,
     ]
 )
 class OpenAILLM(BaseLLM):
diff --git a/metagpt/provider/unify.py b/metagpt/provider/unify.py
new file mode 100644
index 000000000..61e9f4c56
--- /dev/null
+++ b/metagpt/provider/unify.py
@@ -0,0 +1,122 @@
+from typing import Dict, Optional, Union
+from openai.types import CompletionUsage
+from openai.types.chat import ChatCompletion
+
+from metagpt.configs.llm_config import LLMConfig, LLMType
+from metagpt.const import USE_CONFIG_TIMEOUT
+from metagpt.logs import log_llm_stream, logger
+from metagpt.provider.base_llm import BaseLLM
+from metagpt.provider.llm_provider_registry import register_provider
+from metagpt.utils.cost_manager import CostManager
+from metagpt.utils.token_counter import count_message_tokens, OPENAI_TOKEN_COSTS
+from unify.clients import Unify, AsyncUnify
+
+@register_provider([LLMType.UNIFY])
+class UnifyLLM(BaseLLM):
+    def __init__(self, config: LLMConfig):
+        self.config = config
+        self._init_client()
+        self.cost_manager = CostManager(token_costs=OPENAI_TOKEN_COSTS)  # Unify is OpenAI-compatible, so OpenAI token costs are reused
+
+    def _init_client(self):
+        self.model = self.config.model
+        self.client = Unify(
+            api_key=self.config.api_key,
+            endpoint=self.config.model,  # config.model already holds "model@provider", e.g. "llama-3-8b-chat@together-ai"
+        )
+        self.async_client = AsyncUnify(
+            api_key=self.config.api_key,
+            endpoint=self.config.model,
+        )
+
+    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
+        return {
+            "messages": messages,
+            "max_tokens": self.config.max_token,
+            "temperature": self.config.temperature,
+            "stream": stream,
+        }
+
+    def get_choice_text(self, resp: Union[ChatCompletion, str]) -> str:
+        if isinstance(resp, str):
+            return resp
+        return resp.choices[0].message.content if resp.choices else ""
+
+    def _update_costs(self, usage: dict):
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
+
+    async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
+        try:
+            response = await self.async_client.generate(
+                messages=messages,
+                max_tokens=self.config.max_token,
+                temperature=self.config.temperature,
+                stream=False,
+            )
+            # Construct a ChatCompletion object to match OpenAI's format
+            chat_completion = ChatCompletion(
id="unify_chat_completion", + object="chat.completion", + created=0, # Unify doesn't provide this, so we use 0 + model=self.model, + choices=[{ + "index": 0, + "message": { + "role": "assistant", + "content": response, + }, + "finish_reason": "stop", + }], + usage=CompletionUsage( + prompt_tokens=count_message_tokens(messages, self.model), + completion_tokens=count_message_tokens([{"role": "assistant", "content": response}], self.model), + total_tokens=0, # Will be calculated below + ), + ) + chat_completion.usage.total_tokens = chat_completion.usage.prompt_tokens + chat_completion.usage.completion_tokens + self._update_costs(chat_completion.usage.model_dump()) + return chat_completion + except Exception as e: + logger.error(f"Error in Unify chat completion: {str(e)}") + raise + + async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: + try: + stream = self.client.generate( + messages=messages, + max_tokens=self.config.max_token, + temperature=self.config.temperature, + stream=True, + ) + collected_content = [] + for chunk in stream: + log_llm_stream(chunk) + collected_content.append(chunk) + + full_content = "".join(collected_content) + usage = { + "prompt_tokens": count_message_tokens(messages, self.model), + "completion_tokens": count_message_tokens([{"role": "assistant", "content": full_content}], self.model), + } + self._update_costs(usage) + return full_content + except Exception as e: + logger.error(f"Error in Unify chat completion stream: {str(e)}") + raise + + async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion: + return await self._achat_completion(messages, timeout=timeout) + + async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str: + if stream: + return await self._achat_completion_stream(messages, timeout=timeout) + response = await self._achat_completion(messages, timeout=timeout) + return self.get_choice_text(response) + + def get_model_name(self): + return self.model + + def get_usage(self) -> Optional[Dict[str, int]]: + return self.cost_manager.get_latest_usage() \ No newline at end of file