feat(adapters): Add Google VertexAI support (#469)
Ref #468
planetf1 authored Mar 6, 2025
1 parent 1b2cd6a commit 782c673
Showing 9 changed files with 198 additions and 3 deletions.
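In use, the new adapter behaves like the other LiteLLM-backed providers. A minimal sketch, assuming `VERTEXAI_PROJECT` is set and Google application-default credentials are available (it mirrors the example file added below):

import asyncio

from beeai_framework.backend.chat import ChatModel
from beeai_framework.backend.message import UserMessage


async def demo() -> None:
    # "vertexai:<model>" resolves to the new VertexAIChatModel adapter
    llm = ChatModel.from_name("vertexai:gemini-2.0-flash-lite-001")
    response = await llm.create(messages=[UserMessage("Hello from Vertex AI!")])
    print(response.get_text_content())


asyncio.run(demo())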
8 changes: 8 additions & 0 deletions python/.env.example
@@ -38,3 +38,11 @@ BEEAI_LOG_LEVEL=INFO

# XAI_API_KEY=your-xai-api-key
# XAI_CHAT_MODEL=grok-2

########################
### Vertex AI specific configuration
########################

# VERTEXAI_CHAT_MODEL=gemini-2.0-flash-lite-001
# VERTEXAI_PROJECT=""
# GOOGLE_APPLICATION_CREDENTIALS=""
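Assuming the application loads this file with python-dotenv (an assumption — the repository ships only the `.env.example` template), the adapter then picks the values up through `os.getenv`:

import os

from dotenv import load_dotenv  # assumption: python-dotenv is installed

load_dotenv()  # copy .env entries into os.environ
print(os.getenv("VERTEXAI_PROJECT"))  # read by VertexAIChatModel's constructor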
15 changes: 15 additions & 0 deletions python/beeai_framework/adapters/vertexai/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2025 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


15 changes: 15 additions & 0 deletions python/beeai_framework/adapters/vertexai/backend/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2025 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


50 changes: 50 additions & 0 deletions python/beeai_framework/adapters/vertexai/backend/chat.py
@@ -0,0 +1,50 @@
# Copyright 2025 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os

from beeai_framework.adapters.litellm.chat import LiteLLMChatModel
from beeai_framework.backend.constants import ProviderName
from beeai_framework.utils.custom_logger import BeeLogger

logger = BeeLogger(__name__)


class VertexAIChatModel(LiteLLMChatModel):
    @property
    def provider_id(self) -> ProviderName:
        return "vertexai"

    def __init__(self, model_id: str | None = None, settings: dict | None = None) -> None:
        _settings = settings.copy() if settings is not None else {}

        vertexai_project = _settings.get("vertexai_project", os.getenv("VERTEXAI_PROJECT"))
        if not vertexai_project:
            raise ValueError(
                "Project ID is required for Vertex AI model. Specify *vertexai_project* "
                "or set the VERTEXAI_PROJECT environment variable"
            )

        # Ensure standard google auth credentials are available
        # Set GOOGLE_APPLICATION_CREDENTIALS / GOOGLE_CREDENTIALS / GOOGLE_APPLICATION_CREDENTIALS_JSON

        super().__init__(
            model_id if model_id else os.getenv("VERTEXAI_CHAT_MODEL", "gemini-2.0-flash-lite-001"),
            provider_id="vertex_ai",
            settings=_settings | {"vertex_project": vertexai_project},
        )
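An explicit `settings` entry takes precedence over the environment, because the constructor only falls back to `os.getenv` when the key is missing. A minimal usage sketch with a hypothetical project ID:

from beeai_framework.adapters.vertexai.backend.chat import VertexAIChatModel

# "vertexai_project" in settings wins over the VERTEXAI_PROJECT environment variable
llm = VertexAIChatModel(
    "gemini-2.0-flash-lite-001",
    settings={"vertexai_project": "my-gcp-project"},  # hypothetical project ID
)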
5 changes: 3 additions & 2 deletions python/beeai_framework/backend/constants.py
@@ -17,8 +17,8 @@

from pydantic import BaseModel

-ProviderName = Literal["ollama", "openai", "watsonx", "groq", "xai"]
-ProviderHumanName = Literal["Ollama", "OpenAI", "Watsonx", "Groq", "XAI"]
+ProviderName = Literal["ollama", "openai", "watsonx", "groq", "xai", "vertexai"]
+ProviderHumanName = Literal["Ollama", "OpenAI", "Watsonx", "Groq", "XAI", "VertexAI"]


class ProviderDef(BaseModel):
@@ -39,4 +39,5 @@ class ProviderModelDef(BaseModel):
"watsonx": ProviderDef(name="Watsonx", module="watsonx", aliases=["watsonx", "ibm"]),
"Groq": ProviderDef(name="Groq", module="groq", aliases=["groq"]),
"xAI": ProviderDef(name="XAI", module="xai", aliases=["xai", "grok"]),
"vertexAI": ProviderDef(name="VertexAI", module="vertexai", aliases=["vertexai", "google"]),
}
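Since the registration above lists `aliases=["vertexai", "google"]`, either prefix resolves to the same adapter — a small sketch (assuming `VERTEXAI_PROJECT` is set):

from beeai_framework.backend.chat import ChatModel

# both aliases map to the same provider module
llm_a = ChatModel.from_name("vertexai:gemini-2.0-flash-lite-001")
llm_b = ChatModel.from_name("google:gemini-2.0-flash-lite-001")
assert type(llm_a) is type(llm_b)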
2 changes: 1 addition & 1 deletion python/docs/backend.md
@@ -48,7 +48,7 @@ The following table depicts supported providers. Each provider requires specific
| `Watsonx` || | `@ibm-cloud/watsonx-ai` | WATSONX_CHAT_MODEL<br/>WATSONX_EMBEDDING_MODEL<br>WATSONX_API_KEY<br/>WATSONX_PROJECT_ID<br/>WATSONX_SPACE_ID<br>WATSONX_VERSION<br>WATSONX_REGION |
| `Groq` || | | GROQ_CHAT_MODEL<br>GROQ_API_KEY |
| `Amazon Bedrock` | | | Coming soon! | AWS_CHAT_MODEL<br>AWS_EMBEDDING_MODEL<br>AWS_ACCESS_KEY_ID<br>AWS_SECRET_ACCESS_KEY<br>AWS_REGION<br>AWS_SESSION_TOKEN |
-| `Google Vertex` | | | Coming soon! | GOOGLE_VERTEX_CHAT_MODEL<br>GOOGLE_VERTEX_EMBEDDING_MODEL<br>GOOGLE_VERTEX_PROJECT<br>GOOGLE_VERTEX_ENDPOINT<br>GOOGLE_VERTEX_LOCATION |
+| `Google Vertex` |✅ | | | VERTEXAI_CHAT_MODEL<br>VERTEXAI_PROJECT<br>GOOGLE_APPLICATION_CREDENTIALS<br>GOOGLE_APPLICATION_CREDENTIALS_JSON<br>GOOGLE_CREDENTIALS |
| `Azure OpenAI` | | | Coming soon! | AZURE_OPENAI_CHAT_MODEL<br>AZURE_OPENAI_EMBEDDING_MODEL<br>AZURE_OPENAI_API_KEY<br>AZURE_OPENAI_API_ENDPOINT<br>AZURE_OPENAI_API_RESOURCE<br>AZURE_OPENAI_API_VERSION |
| `Anthropic` | | | Coming soon! | ANTHROPIC_CHAT_MODEL<br>ANTHROPIC_EMBEDDING_MODEL<br>ANTHROPIC_API_KEY<br>ANTHROPIC_API_BASE_URL<br>ANTHROPIC_API_HEADERS |
| `xAI` || | | XAI_CHAT_MODEL<br>XAI_API_KEY |
98 changes: 98 additions & 0 deletions python/examples/backend/providers/vertexai.py
@@ -0,0 +1,98 @@
import asyncio
from typing import Any

from pydantic import BaseModel, Field

from beeai_framework.adapters.vertexai.backend.chat import VertexAIChatModel
from beeai_framework.backend.chat import ChatModel
from beeai_framework.backend.message import UserMessage
from beeai_framework.cancellation import AbortSignal
from beeai_framework.emitter import EventMeta
from beeai_framework.errors import AbortError
from beeai_framework.parsers.field import ParserField
from beeai_framework.parsers.line_prefix import LinePrefixParser, LinePrefixParserNode


async def vertexai_from_name() -> None:
    llm = ChatModel.from_name("vertexai:gemini-2.0-flash-lite-001")
    user_message = UserMessage("what states are part of New England?")
    response = await llm.create(messages=[user_message])
    print(response.get_text_content())


async def vertexai_sync() -> None:
    llm = VertexAIChatModel("gemini-2.0-flash-lite-001")
    user_message = UserMessage("what is the capital of Massachusetts?")
    response = await llm.create(messages=[user_message])
    print(response.get_text_content())


async def vertexai_stream() -> None:
    llm = VertexAIChatModel("gemini-2.0-flash-lite-001")
    user_message = UserMessage("How many islands make up the country of Cape Verde?")
    response = await llm.create(messages=[user_message], stream=True)
    print(response.get_text_content())


async def vertexai_stream_abort() -> None:
    llm = VertexAIChatModel("gemini-2.0-flash-lite-001")
    user_message = UserMessage("What is the smallest of the Cape Verde islands?")

    try:
        response = await llm.create(messages=[user_message], stream=True, abort_signal=AbortSignal.timeout(0.5))

        if response is not None:
            print(response.get_text_content())
        else:
            print("No response returned.")
    except AbortError as err:
        print(f"Aborted: {err}")


async def vertexai_structure() -> None:
    class TestSchema(BaseModel):
        answer: str = Field(description="your final answer")

    llm = VertexAIChatModel("gemini-2.0-flash-lite-001")
    user_message = UserMessage("How many islands make up the country of Cape Verde?")
    response = await llm.create_structure(schema=TestSchema, messages=[user_message])
    print(response.object)


async def vertexai_stream_parser() -> None:
    llm = VertexAIChatModel("gemini-2.0-flash-lite-001")

    parser = LinePrefixParser(
        nodes={
            "test": LinePrefixParserNode(
                prefix="Prefix: ", field=ParserField.from_type(str), is_start=True, is_end=True
            )
        }
    )

    async def on_new_token(data: dict[str, Any], event: EventMeta) -> None:
        await parser.add(data["value"].get_text_content())

    user_message = UserMessage("Produce 3 lines each starting with 'Prefix: ' followed by a sentence and a new line.")
    await llm.create(messages=[user_message], stream=True).observe(lambda emitter: emitter.on("newToken", on_new_token))
    result = await parser.end()
    print(result)


async def main() -> None:
    print("*" * 10, "vertexai_from_name")
    await vertexai_from_name()
    print("*" * 10, "vertexai_sync")
    await vertexai_sync()
    print("*" * 10, "vertexai_stream")
    await vertexai_stream()
    print("*" * 10, "vertexai_stream_abort")
    await vertexai_stream_abort()
    print("*" * 10, "vertexai_structure")
    await vertexai_structure()
    print("*" * 10, "vertexai_stream_parser")
    await vertexai_stream_parser()


if __name__ == "__main__":
    asyncio.run(main())
6 changes: 6 additions & 0 deletions python/tests/backend/test_chatmodel.py
@@ -22,6 +22,7 @@
 from beeai_framework.adapters.groq.backend.chat import GroqChatModel
 from beeai_framework.adapters.ollama.backend.chat import OllamaChatModel
 from beeai_framework.adapters.openai.backend.chat import OpenAIChatModel
+from beeai_framework.adapters.vertexai.backend.chat import VertexAIChatModel
 from beeai_framework.adapters.watsonx.backend.chat import WatsonxChatModel
 from beeai_framework.adapters.xai.backend.chat import XAIChatModel
 from beeai_framework.backend.chat import (
@@ -183,3 +184,8 @@ def test_chat_model_from(monkeypatch: pytest.MonkeyPatch) -> None:

     xai_chat_model = ChatModel.from_name("xai:grok-2")
     assert isinstance(xai_chat_model, XAIChatModel)
+
+    # VertexAIChatModel requires a project ID, so provide one via the environment
+    monkeypatch.setenv("VERTEXAI_PROJECT", "myproject")
+    vertexai_chat_model = ChatModel.from_name("vertexai:gemini-2.0-flash-lite-001")
+    assert isinstance(vertexai_chat_model, VertexAIChatModel)
2 changes: 2 additions & 0 deletions python/tests/examples/test_examples.py
@@ -34,6 +34,8 @@
"backend/providers/openai_example.py" if os.getenv("OPENAI_API_KEY") is None else None,
"backend/providers/groq.py" if os.getenv("GROQ_API_KEY") is None else None,
"backend/providers/xai.py" if os.getenv("XAI_API_KEY") is None else None,
# Google backend picks up environment variables/google auth credentials directly
"backend/providers/vertexai.py",
# requires Searx instance
"workflows/searx_agent.py",
],
