Merge pull request #556 from lion-agi/update-openai-chat
adding ollama endpoint
ohdearquant authored Jan 29, 2025
2 parents 157d0a0 + b9bd4a1 commit 470cceb
Showing 10 changed files with 472 additions and 47 deletions.
108 changes: 84 additions & 24 deletions lionagi/libs/schema/function_to_schema.py
@@ -5,7 +5,13 @@
import inspect
from typing import Any, Literal

from .extract_docstring import extract_docstring
from pydantic import Field, field_validator

from lionagi.libs.schema.extract_docstring import extract_docstring
from lionagi.libs.validate.common_field_validators import (
validate_model_to_type,
)
from lionagi.operatives.models.schema_model import SchemaModel

py_json_msp = {
"str": "string",
@@ -18,12 +24,60 @@
}


class FunctionSchema(SchemaModel):
name: str
description: str | None = Field(
None,
description=(
"A description of what the function does, used by the "
"model to choose when and how to call the function."
),
)
parameters: dict[str, Any] | None = Field(
None,
description=(
"The parameters the functions accepts, described as a JSON Schema object. "
"See the guide (https://platform.openai.com/docs/guides/function-calling) "
"for examples, and the JSON Schema reference for documentation about the "
"format. Omitting parameters defines a function with an empty parameter list."
),
validation_alias="request_options",
)
strict: bool | None = Field(
None,
description=(
"Whether to enable strict schema adherence when generating the function call. "
"If set to true, the model will follow the exact schema defined in the parameters "
"field. Only a subset of JSON Schema is supported when strict is true."
),
)

@field_validator("parameters", mode="before")
def _validate_parameters(cls, v):
if v is None:
return None
if isinstance(v, dict):
return v
try:
model_type = validate_model_to_type(cls, v)
return model_type.model_json_schema()
except Exception:
raise ValueError(f"Invalid model type: {v}")

def to_dict(self):
dict_ = super().to_dict()
return {"type": "function", "function": dict_}


def function_to_schema(
f_,
style: Literal["google", "rest"] = "google",
*,
request_options: dict[str, Any] | None = None,
strict: bool | None = None,
func_description: str | None = None,
parameter_description: dict[str, str] | None = None,
return_obj: bool = False,
) -> dict:
"""
Generate a schema description for a given function, in OpenAI format.
@@ -78,27 +132,33 @@ def function_to_schema(
"required": [],
}

for name, param in sig.parameters.items():
# Default type to string and update if type hint is available
param_type = "string"
if param.annotation is not inspect.Parameter.empty:
param_type = py_json_msp[param.annotation.__name__]

# Extract parameter description from docstring, if available
param_description = parameter_description.get(name)

# Assuming all parameters are required for simplicity
parameters["required"].append(name)
parameters["properties"][name] = {
"type": param_type,
"description": param_description,
}

return {
"type": "function",
"function": {
"name": func_name,
"description": func_description,
"parameters": parameters,
},
if not request_options:
for name, param in sig.parameters.items():
# Default type to string and update if type hint is available
param_type = "string"
if param.annotation is not inspect.Parameter.empty:
param_type = py_json_msp[param.annotation.__name__]

# Extract parameter description from docstring, if available
param_description = parameter_description.get(name)

# Assuming all parameters are required for simplicity
parameters["required"].append(name)
parameters["properties"][name] = {
"type": param_type,
"description": param_description,
}
else:
parameters = request_options

params = {
"name": func_name,
"description": func_description,
"parameters": parameters,
}
if strict:
params["strict"] = strict

if return_obj:
return FunctionSchema(**params)
return FunctionSchema(**params).to_dict()
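
For the function path, a hedged sketch (multiply is a made-up example; with request_options omitted, the schema is still derived from the signature, and docstring handling lives in the part of the diff not shown):

def multiply(a: int, b: int) -> int:
    """Multiply two integers.

    Args:
        a: First factor.
        b: Second factor.
    """
    return a * b

tool = function_to_schema(multiply, strict=True)
# -> {"type": "function", "function": {"name": "multiply", "strict": True, ...}}

schema_obj = function_to_schema(multiply, return_obj=True)  # FunctionSchema instance
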
26 changes: 21 additions & 5 deletions lionagi/service/endpoints/base.py
@@ -80,6 +80,7 @@ class EndpointConfig(BaseModel):
api_version: str | None = None
allowed_roles: list[str] | None = None
request_options: type | None = Field(None, exclude=True)
invoke_with_endpoint: bool | None = None


class EndPoint(ABC):
@@ -91,19 +92,28 @@ class EndPoint(ABC):
HTTP requests.
"""

def __init__(self, config: dict) -> None:
def __init__(
self, config: dict | EndpointConfig | type[EndpointConfig], **kwargs
) -> None:
"""Initializes the EndPoint with a given configuration.
Args:
config (dict): Configuration data that matches the EndpointConfig
config (dict | EndpointConfig): Configuration data that matches the EndpointConfig
schema.
"""
self.config = EndpointConfig(**config)
if isinstance(config, dict):
self.config = EndpointConfig(**config)
if isinstance(config, EndpointConfig):
self.config = config
if isinstance(config, type) and issubclass(config, EndpointConfig):
self.config = config()
if kwargs:
self.update_config(**kwargs)

def update_config(self, **kwargs):
config = self.config.model_dump()
config.update(kwargs)
self.config = EndpointConfig(**config)
self.config = self.config.model_validate(config)

@property
def name(self) -> str | None:
@@ -354,6 +364,7 @@ class APICalling(Event):
exclude=True,
description="Whether to include token usage information into instruction messages",
)
response_obj: BaseModel | None = Field(None, exclude=True)

@model_validator(mode="after")
def _validate_streaming(self) -> Self:
@@ -648,7 +659,12 @@ async def invoke(self) -> None:
f"API call to {self.endpoint.full_url} failed: {e1}"
)
else:
self.execution.response = response
self.response_obj = response
self.execution.response = (
response.model_dump()
if isinstance(response, BaseModel)
else response
)
self.execution.status = EventStatus.COMPLETED

def __str__(self) -> str:
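
Taken together, the constructor and update_config changes mean a concrete endpoint now accepts its configuration in three forms; a hedged sketch (MyEndPoint and MyConfig are hypothetical stand-ins for an EndPoint subclass and its EndpointConfig, with required config fields elided):

ep = MyEndPoint({"provider": "ollama"})      # plain dict, as before
ep = MyEndPoint(MyConfig())                  # an EndpointConfig instance
ep = MyEndPoint(MyConfig, api_key="ollama")  # a config class plus keyword overrides

# update_config now revalidates through model_validate, so subclassed
# config types survive the round trip:
ep.update_config(api_version="2025-01-29")

Also visible in this file: invoke() now keeps the raw Pydantic response on response_obj while storing a plain dict in execution.response.
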
7 changes: 7 additions & 0 deletions lionagi/service/endpoints/match_endpoint.py
@@ -48,6 +48,13 @@ def match_endpoint(

return OpenRouterChatCompletionEndPoint()

if provider == "ollama":
from ..providers.ollama_.chat_completions import (
OllamaChatCompletionEndPoint,
)

return OllamaChatCompletionEndPoint()

return OpenAIChatCompletionEndPoint(
config={
"provider": provider,
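
The new branch routes the ollama provider to its dedicated endpoint class; a hedged sketch (the rest of the match_endpoint signature is cropped out of this hunk, so the endpoint argument name and value here are assumptions):

endpoint = match_endpoint(provider="ollama", endpoint="chat")
# -> OllamaChatCompletionEndPoint()
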
32 changes: 20 additions & 12 deletions lionagi/service/imodel.py
@@ -51,7 +51,7 @@ def __init__(
interval: float | None = None,
limit_requests: int = None,
limit_tokens: int = None,
invoke_with_endpoint: bool = False,
invoke_with_endpoint: bool = None,
concurrency_limit: int | None = None,
streaming_process_func: Callable = None,
requires_api_key: bool = True,
@@ -95,6 +95,16 @@
Additional keyword arguments, such as `model`, or any other
provider-specific fields.
"""
model = kwargs.get("model", None)
if model:
if not provider:
if "/" in model:
provider = model.split("/")[0]
model = model.replace(provider + "/", "")
kwargs["model"] = model
else:
raise ValueError("Provider must be provided")

if api_key is None:
provider = str(provider or "").strip().lower()
match provider:
@@ -110,6 +120,8 @@
api_key = "GROQ_API_KEY"
case "exa":
api_key = "EXA_API_KEY"
case "ollama":
api_key = "ollama"
case "":
if requires_api_key:
raise ValueError("API key must be provided")
@@ -121,16 +133,6 @@
api_key = os.getenv(api_key)

kwargs["api_key"] = api_key
model = kwargs.get("model", None)
if model:
if not provider:
if "/" in model:
provider = model.split("/")[0]
model = model.replace(provider + "/", "")
kwargs["model"] = model
else:
raise ValueError("Provider must be provided")

if isinstance(endpoint, EndPoint):
self.endpoint = endpoint
else:
@@ -145,7 +147,13 @@
if base_url:
self.endpoint.config.base_url = base_url

self.should_invoke_endpoint = invoke_with_endpoint
if (
invoke_with_endpoint is None
and self.endpoint.config.invoke_with_endpoint is True
):
invoke_with_endpoint = True

self.should_invoke_endpoint = invoke_with_endpoint or False
self.kwargs = kwargs
self.executor = RateLimitedAPIExecutor(
queue_capacity=queue_capacity,
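
Putting the imodel.py changes together: provider inference from a prefixed model name now runs before API-key resolution, ollama gets the placeholder key "ollama" automatically, and invoke_with_endpoint falls back to the endpoint's config when left unset. A hedged usage sketch (the top-level import and model names are assumptions):

from lionagi import iModel

m1 = iModel(provider="ollama", model="llama3.2")  # api_key defaults to "ollama"
m2 = iModel(model="ollama/llama3.2")              # provider parsed from the name
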
3 changes: 3 additions & 0 deletions lionagi/service/providers/ollama_/__init__.py
@@ -0,0 +1,3 @@
# Copyright (c) 2023 - 2025, HaiyangLi <quantocean.li at gmail dot com>
#
# SPDX-License-Identifier: Apache-2.0
(Diffs for the remaining 5 changed files did not load.)