fix(py/plugins/genai): fixes for genai plugin response schema and streaming #2538

Merged: 4 commits, Mar 29, 2025

4 changes: 4 additions & 0 deletions py/packages/genkit-ai/src/genkit/core/extract.py
@@ -79,6 +79,10 @@ def extract_json(text: str, throw_on_bad_json: bool = True) -> Any:
>>> extract_json('invalid json', throw_on_bad_json=False)
None
"""

if text.strip() == '':
return None

opening_char = None
closing_char = None
start_pos = None
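For quick reference, the guard added above short-circuits before any delimiter scanning, so empty or whitespace-only input now yields None while normal extraction is unchanged. A minimal sketch, assuming the module is importable as genkit.core.extract (matching the file path above):

```python
# Minimal sketch of the new empty-input guard (import path taken from the file above).
from genkit.core.extract import extract_json

assert extract_json('') is None         # empty input now returns None immediately
assert extract_json('   \n') is None    # whitespace-only input strips to '' and also returns None
assert extract_json('prefix{"a": 1}suffix') == {'a': 1}  # existing extraction is unaffected
```
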
5 changes: 5 additions & 0 deletions py/packages/genkit-ai/tests/genkit/core/extract_test.py
@@ -99,6 +99,11 @@ def test_extract_items(name, steps):
{'text': 'prefix{"a":1}suffix'},
{'expected': {'a': 1}},
),
(
'returns None for empty str',
{'text': ''},
{'expected': None},
),
(
'extracts simple array',
{'text': 'prefix[1,2,3]suffix'},

@@ -122,6 +122,7 @@
from typing import Any

from google import genai
from google.genai import types as genai_types

from genkit.ai.registry import GenkitRegistry
from genkit.core.action import ActionKind, ActionRunContext
@@ -318,7 +319,7 @@ def __init__(
self._client = client
self._registry = registry

def _create_vertexai_tool(self, tool: ToolDefinition) -> genai.types.Tool:
def _create_vertexai_tool(self, tool: ToolDefinition) -> genai_types.Tool:
"""Create a tool that is compatible with VertexAI API.

Args:
@@ -327,15 +328,15 @@ def _create_vertexai_tool(self, tool: ToolDefinition) -> genai.types.Tool:
Returns:
Genai tool compatible with VertexAI API.
"""
function = genai.types.FunctionDeclaration(
function = genai_types.FunctionDeclaration(
name=tool.name,
description=tool.description,
parameters=tool.input_schema,
response=tool.output_schema,
)
return genai.types.Tool(function_declarations=[function])
return genai_types.Tool(function_declarations=[function])

def _create_gemini_tool(self, tool: ToolDefinition) -> genai.types.Tool:
def _create_gemini_tool(self, tool: ToolDefinition) -> genai_types.Tool:
"""Create a tool that is compatible with Gemini API.

Args:
@@ -345,12 +346,12 @@ def _create_gemini_tool(self, tool: ToolDefinition) -> genai.types.Tool:
Genai tool compatible with Gemini API.
"""
params = self._convert_schema_property(tool.input_schema)
function = genai.types.FunctionDeclaration(
function = genai_types.FunctionDeclaration(
name=tool.name, description=tool.description, parameters=params
)
return genai.types.Tool(function_declarations=[function])
return genai_types.Tool(function_declarations=[function])

def _get_tools(self, request: GenerateRequest) -> list[genai.types.Tool]:
def _get_tools(self, request: GenerateRequest) -> list[genai_types.Tool]:
"""Generates VertexAI Gemini compatible tool definitions.

Args:
@@ -372,7 +373,7 @@ def _get_tools(self, request: GenerateRequest) -> list[genai.types.Tool]:

def _convert_schema_property(
self, input_schema: dict[str, Any]
) -> genai.types.Schema | None:
) -> genai_types.Schema | None:
"""Sanitizes a schema to be compatible with Gemini API.

Args:
@@ -384,18 +385,18 @@ def _convert_schema_property(
if not input_schema or 'type' not in input_schema:
return None

schema = genai.types.Schema()
schema = genai_types.Schema()
if input_schema.get('description'):
schema.description = input_schema['description']

if 'type' in input_schema:
schema_type = genai.types.Type(input_schema['type'])
schema_type = genai_types.Type(input_schema['type'])
schema.type = schema_type

if schema_type == genai.types.Type.ARRAY:
if schema_type == genai_types.Type.ARRAY:
schema.items = input_schema['items']

if schema_type == genai.types.Type.OBJECT:
if schema_type == genai_types.Type.OBJECT:
schema.properties = {}
properties = input_schema['properties']
for key in properties:
@@ -406,7 +407,7 @@

return schema

def _call_tool(self, call: genai.types.FunctionCall) -> genai.types.Content:
def _call_tool(self, call: genai_types.FunctionCall) -> genai_types.Content:
"""Calls tool's function from the registry.

Args:
@@ -420,9 +421,9 @@ def _call_tool(self, call: genai.types.FunctionCall) -> genai.types.Content:
)
args = tool_function.input_type.validate_python(call.args)
tool_answer = tool_function.run(args)
return genai.types.Content(
return genai_types.Content(
parts=[
genai.types.Part.from_function_response(
genai_types.Part.from_function_response(
name=call.name,
response={
'content': tool_answer.response,
@@ -456,8 +457,8 @@ async def generate(

async def _generate(
self,
request_contents: list[genai.types.Content],
request_cfg: genai.types.GenerateContentConfig,
request_contents: list[genai_types.Content],
request_cfg: genai_types.GenerateContentConfig,
) -> GenerateResponse:
"""Call google-genai generate.

@@ -483,8 +484,8 @@

async def _streaming_generate(
self,
request_contents: list[genai.types.Content],
request_cfg: genai.types.GenerateContentConfig | None,
request_contents: list[genai_types.Content],
request_cfg: genai_types.GenerateContentConfig | None,
ctx: ActionRunContext,
) -> GenerateResponse:
"""Call google-genai generate for streaming.
@@ -500,9 +501,10 @@ async def _streaming_generate(
generator = self._client.aio.models.generate_content_stream(
model=self._version, contents=request_contents, config=request_cfg
)
accumulated_content = []
async for response_chunk in await generator:
content = self._contents_from_response(response_chunk)

accumulated_content.append(*content)
ctx.send_chunk(
chunk=GenerateResponseChunk(
content=content,
@@ -512,7 +514,7 @@
return GenerateResponse(
message=Message(
role=Role.MODEL,
content=[TextPart(text='')],
content=accumulated_content,
)
)

@@ -540,7 +542,7 @@ def is_multimode(self):

def _build_messages(
self, request: GenerateRequest
) -> list[genai.types.Content]:
) -> list[genai_types.Content]:
"""Build google-genai request contents from Genkit request.

Args:
@@ -549,20 +551,20 @@ def _build_messages(
Returns:
list of google-genai contents.
"""
request_contents: list[genai.types.Content] = []
request_contents: list[genai_types.Content] = []

for msg in request.messages:
content_parts: list[genai.types.Part] = []
content_parts: list[genai_types.Part] = []
for p in msg.content:
content_parts.append(PartConverter.to_gemini(p))
request_contents.append(
genai.types.Content(parts=content_parts, role=msg.role)
genai_types.Content(parts=content_parts, role=msg.role)
)

return request_contents

def _contents_from_response(
self, response: genai.types.GenerateContentResponse
self, response: genai_types.GenerateContentResponse
) -> list:
"""Retrieve contents from google-genai response.

@@ -582,7 +584,7 @@

def _genkit_to_googleai_cfg(
self, request: GenerateRequest
) -> genai.types.GenerateContentConfig | None:
) -> genai_types.GenerateContentConfig | None:
"""Translate GenerationCommonConfig to Google Ai GenerateContentConfig.

Args:
Expand All @@ -596,19 +598,19 @@ def _genkit_to_googleai_cfg(
if request.config:
request_config = request.config
if isinstance(request_config, GenerationCommonConfig):
cfg = genai.types.GenerateContentConfig(
cfg = genai_types.GenerateContentConfig(
max_output_tokens=request_config.max_output_tokens,
top_k=request_config.top_k,
top_p=request_config.top_p,
temperature=request_config.temperature,
stop_sequences=request_config.stop_sequences,
)
elif isinstance(request_config, dict):
cfg = genai.types.GenerateContentConfig(**request_config)
cfg = genai_types.GenerateContentConfig(**request_config)

if request.output:
if not cfg:
cfg = genai.types.GenerateContentConfig()
cfg = genai_types.GenerateContentConfig()

response_mime_type = (
'application/json'
@@ -617,9 +619,14 @@
)
cfg.response_mime_type = response_mime_type

if request.output.schema_ and request.output.constrained:
cfg.response_schema = self._convert_schema_property(
request.output.schema_
)

if request.tools:
if not cfg:
cfg = genai.types.GenerateContentConfig()
cfg = genai_types.GenerateContentConfig()

tools = self._get_tools(request)
cfg.tools = tools
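Beyond the genai_types import cleanup, the diff above makes two behavioural fixes: when a request's output is constrained and carries a schema, the translated GenerateContentConfig now also sets response_schema (built with _convert_schema_property), and the streaming path now returns the accumulated chunk content as the final model message instead of a single empty text part. A rough illustration of the config produced for a constrained JSON request; the field values here are made up, only the shape follows google.genai.types:

```python
# Illustration only: roughly what _genkit_to_googleai_cfg now produces for a request
# whose output is constrained JSON with a simple object schema (values are made up).
from google.genai import types as genai_types

cfg = genai_types.GenerateContentConfig(
    response_mime_type='application/json',
    response_schema=genai_types.Schema(
        type=genai_types.Type.OBJECT,
        properties={
            'name': genai_types.Schema(type=genai_types.Type.STRING),
            'story': genai_types.Schema(type=genai_types.Type.STRING),
        },
    ),
)
```
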
@@ -128,7 +128,7 @@ async def test_generate_stream_text_response(mocker, version):
)
])
assert isinstance(response, GenerateResponse)
assert response.message.content[0].root.text == ''
assert response.message.content == []


@pytest.mark.asyncio
36 changes: 36 additions & 0 deletions py/samples/hello-google-genai/src/hello.py
@@ -155,6 +155,42 @@ async def say_hi_stream(name: str, ctx):
return result


class RpgCharacter(BaseModel):
name: str = Field(description='name of the character')
story: str = Field(description='back story')
weapons: list[str] = Field(description='list of weapons (3-4)')


@ai.flow()
async def generate_character(name: str, ctx):
if ctx.is_streaming:
stream, result = ai.generate_stream(
prompt=f'generate an RPG character named {name}',
output_schema=RpgCharacter,
)
async for data in stream:
ctx.send_chunk(data.output)

return (await result).output
else:
result = await ai.generate(
prompt=f'generate an RPG character named {name}',
output_schema=RpgCharacter,
)
return result.output


@ai.flow()
async def generate_character_unconstrained(name: str, ctx):
result = await ai.generate(
prompt=f'generate an RPG character named {name}',
output_schema=RpgCharacter,
output_constrained=False,
output_instructions=True,
)
return result.output


async def main() -> None:
print(await say_hi(', tell me a joke'))

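As a usage note for the new sample flows, the sketch below shows how generate_character might be driven the same way main() drives say_hi above. It assumes the @ai.flow() decorator injects ctx when the flow is awaited directly; that is an assumption, not something shown in this diff.

```python
# Hypothetical usage sketch for hello.py; 'Thorin' is an example name, and the direct
# await relies on the decorator injecting ctx (an assumption, not shown in the diff).
import asyncio


async def demo() -> None:
    character = await generate_character('Thorin')
    print(character)  # expected to be an RpgCharacter with name, story, and weapons


if __name__ == '__main__':
    asyncio.run(demo())
```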