
Commit

Merge pull request #1544 from EvensXia/fix_metagpt_from_evensxia
fix ollama_api to add llava support
geekan authored Oct 30, 2024
2 parents 8a46f96 + e209e0e commit 9db0874
Showing 6 changed files with 345 additions and 130 deletions.
4 changes: 2 additions & 2 deletions examples/llm_vision.py
@@ -15,8 +15,8 @@ async def main():
     # check if the configured llm supports llm-vision capacity. If not, it will throw a error
     invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png")
     img_base64 = encode_image(invoice_path)
-    res = await llm.aask(msg="if this is a invoice, just return True else return False", images=[img_base64])
-    assert "true" in res.lower()
+    res = await llm.aask(msg="return `True` if this image might be a invoice, or return `False`", images=[img_base64])
+    assert ("true" in res.lower()) or ("invoice" in res.lower())
 
 
 if __name__ == "__main__":
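
For context, a minimal self-contained sketch of how this vision check can be driven against a llava model served by Ollama. The imports mirror examples/llm_vision.py; the config details (api_type "ollama", model "llava", base_url) are assumptions about a local setup, not part of this diff:

import asyncio
from pathlib import Path

from metagpt.llm import LLM
from metagpt.utils.common import encode_image

async def looks_like_invoice(image_path: Path) -> bool:
    # LLM() picks up the configured provider from config2.yaml,
    # e.g. api_type: "ollama", model: "llava" (assumed local setup)
    llm = LLM()
    img_base64 = encode_image(image_path)
    res = await llm.aask(
        msg="return `True` if this image might be a invoice, or return `False`",
        images=[img_base64],
    )
    # Same lenient check as the updated test: llava's phrasing varies
    # from run to run, so accept either "true" or "invoice".
    return ("true" in res.lower()) or ("invoice" in res.lower())

if __name__ == "__main__":
    print(asyncio.run(looks_like_invoice(Path("tests/data/invoices/invoice-2.png"))))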
8 changes: 6 additions & 2 deletions metagpt/configs/llm_config.py
@@ -26,7 +26,10 @@ class LLMType(Enum):
     GEMINI = "gemini"
     METAGPT = "metagpt"
     AZURE = "azure"
-    OLLAMA = "ollama"
+    OLLAMA = "ollama"  # /chat at ollama api
+    OLLAMA_GENERATE = "ollama.generate"  # /generate at ollama api
+    OLLAMA_EMBEDDINGS = "ollama.embeddings"  # /embeddings at ollama api
+    OLLAMA_EMBED = "ollama.embed"  # /embed at ollama api
     QIANFAN = "qianfan"  # Baidu BCE
     DASHSCOPE = "dashscope"  # Aliyun LingJi DashScope
     MOONSHOT = "moonshot"
@@ -105,7 +108,8 @@ def check_llm_key(cls, v):
             root_config_path = CONFIG_ROOT / "config2.yaml"
             if root_config_path.exists():
                 raise ValueError(
-                    f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \nthe former will overwrite the latter. This may cause unexpected result.\n"
+                    f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \n"
+                    f"the former will overwrite the latter. This may cause unexpected result.\n"
                 )
             elif repo_config_path.exists():
                 raise ValueError(f"Please set your API key in {repo_config_path}")
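
The four OLLAMA members differ only in which Ollama REST endpoint the provider ultimately calls, per the inline comments above. A rough sketch of that mapping, for orientation only; the dict and assert below are illustrative, not MetaGPT's actual dispatch code:

from metagpt.configs.llm_config import LLMType

# Assumed endpoint paths of a stock Ollama server; the enum comments
# above name the same routes.
OLLAMA_ENDPOINTS = {
    LLMType.OLLAMA: "/api/chat",
    LLMType.OLLAMA_GENERATE: "/api/generate",
    LLMType.OLLAMA_EMBEDDINGS: "/api/embeddings",
    LLMType.OLLAMA_EMBED: "/api/embed",
}

# Enum lookup by value resolves an api_type string from config to a member.
assert OLLAMA_ENDPOINTS[LLMType("ollama.generate")] == "/api/generate"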
50 changes: 4 additions & 46 deletions metagpt/provider/general_api_base.py
@@ -13,6 +13,7 @@
 from contextlib import asynccontextmanager
 from enum import Enum
 from typing import (
+    Any,
     AsyncGenerator,
     AsyncIterator,
     Dict,
@@ -121,7 +122,7 @@ def fmt(key, val):
 
 
 class OpenAIResponse:
-    def __init__(self, data, headers):
+    def __init__(self, data: Union[bytes, Any], headers: dict):
         self._headers = headers
         self.data = data
@@ -320,49 +321,6 @@ def request(
         resp, got_stream = self._interpret_response(result, stream)
         return resp, got_stream, self.api_key
 
-    @overload
-    async def arequest(
-        self,
-        method,
-        url,
-        params,
-        headers,
-        files,
-        stream: Literal[True],
-        request_id: Optional[str] = ...,
-        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
-    ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
-        pass
-
-    @overload
-    async def arequest(
-        self,
-        method,
-        url,
-        params=...,
-        headers=...,
-        files=...,
-        *,
-        stream: Literal[True],
-        request_id: Optional[str] = ...,
-        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
-    ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
-        pass
-
-    @overload
-    async def arequest(
-        self,
-        method,
-        url,
-        params=...,
-        headers=...,
-        files=...,
-        stream: Literal[False] = ...,
-        request_id: Optional[str] = ...,
-        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
-    ) -> Tuple[OpenAIResponse, bool, str]:
-        pass
-
     @overload
     async def arequest(
         self,
@@ -438,8 +396,8 @@ def request_headers(self, method: str, extra, request_id: Optional[str]) -> Dict
             "X-LLM-Client-User-Agent": json.dumps(ua),
             "User-Agent": user_agent,
         }
-
-        headers.update(api_key_to_header(self.api_type, self.api_key))
+        if self.api_key:
+            headers.update(api_key_to_header(self.api_type, self.api_key))
 
         if self.organization:
             headers["LLM-Organization"] = self.organization
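
The request_headers change is what lets keyless backends work: a local Ollama server usually has no API key, and calling api_key_to_header(self.api_type, None) unconditionally risks emitting a bogus auth header. A simplified sketch of the guarded behavior; the bearer-token form is an assumption for the common case, since api_key_to_header varies by api_type:

from typing import Dict, Optional

def build_auth_headers(api_key: Optional[str]) -> Dict[str, str]:
    headers: Dict[str, str] = {}
    if api_key:  # skipped entirely for keyless backends such as local Ollama
        headers["Authorization"] = f"Bearer {api_key}"
    return headers

assert build_auth_headers(None) == {}  # no Authorization header at all
assert build_auth_headers("sk-xxx") == {"Authorization": "Bearer sk-xxx"}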
83 changes: 58 additions & 25 deletions metagpt/provider/general_api_requestor.py
@@ -3,25 +3,24 @@
 # @Desc   : General Async API for http-based LLM model
 
 import asyncio
-from typing import AsyncGenerator, Generator, Iterator, Tuple, Union
+from typing import AsyncGenerator, Iterator, Optional, Tuple, Union
 
 import aiohttp
 import requests
 
 from metagpt.logs import logger
-from metagpt.provider.general_api_base import APIRequestor
+from metagpt.provider.general_api_base import APIRequestor, OpenAIResponse
 
 
-def parse_stream_helper(line: bytes) -> Union[bytes, None]:
+def parse_stream_helper(line: bytes) -> Optional[bytes]:
     if line and line.startswith(b"data:"):
         if line.startswith(b"data: "):
-            # SSE event may be valid when it contain whitespace
+            # SSE event may be valid when it contains whitespace
             line = line[len(b"data: ") :]
         else:
            line = line[len(b"data:") :]
         if line.strip() == b"[DONE]":
-            # return here will cause GeneratorExit exception in urllib3
-            # and it will close http connection with TCP Reset
+            # Returning None to indicate end of stream
            return None
         else:
            return line
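
To make the helper's contract concrete, a few doctest-style checks; the sample byte strings are made up, but the behavior follows the function exactly as written above:

from metagpt.provider.general_api_requestor import parse_stream_helper

assert parse_stream_helper(b'data: {"choices": []}') == b'{"choices": []}'  # SSE with space
assert parse_stream_helper(b'data:{"choices": []}') == b'{"choices": []}'   # SSE without space
assert parse_stream_helper(b"data: [DONE]") is None    # OpenAI-style end-of-stream sentinel
assert parse_stream_helper(b"") is None                # blank keep-alive line
assert parse_stream_helper(b'{"done": true}') is None  # non-SSE line (no data: prefix)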
@@ -37,7 +36,7 @@ def parse_stream(rbody: Iterator[bytes]) -> Iterator[bytes]:
 
 class GeneralAPIRequestor(APIRequestor):
     """
-    usage
+    Usage example:
         # full_url = "{base_url}{url}"
         requester = GeneralAPIRequestor(base_url=base_url)
         result, _, api_key = await requester.arequest(
@@ -50,26 +49,47 @@ class GeneralAPIRequestor(APIRequestor):
         )
     """
 
-    def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders, stream: bool) -> bytes:
-        # just do nothing to meet the APIRequestor process and return the raw data
-        # due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases.
-        return rbody
+    def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders: dict, stream: bool) -> OpenAIResponse:
+        """
+        Process and return the response data wrapped in OpenAIResponse.
+
+        Args:
+            rbody (bytes): The response body.
+            rcode (int): The response status code.
+            rheaders (dict): The response headers.
+            stream (bool): Whether the response is a stream.
+
+        Returns:
+            OpenAIResponse: The response data wrapped in OpenAIResponse.
+        """
+        return OpenAIResponse(rbody, rheaders)
 
     def _interpret_response(
         self, result: requests.Response, stream: bool
-    ) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
-        """Returns the response(s) and a bool indicating whether it is a stream."""
+    ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
+        """
+        Interpret a synchronous response.
+
+        Args:
+            result (requests.Response): The response object.
+            stream (bool): Whether the response is a stream.
+
+        Returns:
+            Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]: A tuple containing the response content and a boolean indicating if it is a stream.
+        """
         content_type = result.headers.get("Content-Type", "")
         if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
-            return (
-                self._interpret_response_line(line, result.status_code, result.headers, stream=True)
-                for line in parse_stream(result.iter_lines())
-            ), True
+            return (
+                (
+                    self._interpret_response_line(line, result.status_code, result.headers, stream=True)
+                    for line in parse_stream(result.iter_lines())
+                ),
+                True,
+            )
         else:
             return (
                 self._interpret_response_line(
-                    result.content,  # let the caller to decode the msg
+                    result.content,  # let the caller decode the msg
                     result.status_code,
                     result.headers,
                     stream=False,
@@ -79,26 +99,39 @@ def _interpret_response(
 
     async def _interpret_async_response(
         self, result: aiohttp.ClientResponse, stream: bool
-    ) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
+    ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
+        """
+        Interpret an asynchronous response.
+
+        Args:
+            result (aiohttp.ClientResponse): The response object.
+            stream (bool): Whether the response is a stream.
+
+        Returns:
+            Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: A tuple containing the response content and a boolean indicating if it is a stream.
+        """
         content_type = result.headers.get("Content-Type", "")
         if stream and (
             "text/event-stream" in content_type or "application/x-ndjson" in content_type or content_type == ""
         ):
             # the `Content-Type` of ollama stream resp is "application/x-ndjson"
-            return (
-                self._interpret_response_line(line, result.status, result.headers, stream=True)
-                async for line in result.content
-            ), True
+            return (
+                (
+                    self._interpret_response_line(line, result.status, result.headers, stream=True)
+                    async for line in result.content
+                ),
+                True,
+            )
         else:
             try:
-                await result.read()
+                response_content = await result.read()
             except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
                 raise TimeoutError("Request timed out") from e
             except aiohttp.ClientError as exp:
-                logger.warning(f"response: {result.content}, exp: {exp}")
+                logger.warning(f"response: {result}, exp: {exp}")
+                response_content = b""
             return (
                 self._interpret_response_line(
-                    await result.read(),  # let the caller to decode the msg
+                    response_content,  # let the caller decode the msg
                    result.status,
                    result.headers,
                    stream=False,
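
End to end, callers now receive OpenAIResponse objects instead of raw bytes. A hedged sketch of draining an Ollama /api/chat NDJSON stream through this requestor; the arequest keyword names follow the overloads shown in general_api_base.py, and the Ollama payload shape is assumed from its public API, so treat both as assumptions rather than the exact provider code:

import asyncio
import json

from metagpt.provider.general_api_requestor import GeneralAPIRequestor

async def drain_ollama_chat(base_url: str, model: str, prompt: str) -> str:
    # base_url + url forms the full endpoint, per the class docstring above.
    requester = GeneralAPIRequestor(base_url=base_url)
    result, is_stream, _ = await requester.arequest(
        method="post",
        url="/api/chat",
        params={"model": model, "messages": [{"role": "user", "content": prompt}], "stream": True},
        stream=True,
    )
    chunks = []
    async for resp in result:  # each resp is an OpenAIResponse wrapping one NDJSON line
        data = json.loads(resp.data.decode("utf-8"))
        chunks.append(data.get("message", {}).get("content", ""))
    return "".join(chunks)

# Example (assumes a local server with a llava model pulled):
# print(asyncio.run(drain_ollama_chat("http://localhost:11434", "llava", "describe an invoice")))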