Skip to content

Commit f64a989

Browse files
authored
fix(py/plugins/genai): expose full Gemini config schema (#2540)
1 parent e06b062 commit f64a989

File tree

6 files changed

+31
-8
lines changed

6 files changed

+31
-8
lines changed

py/packages/genkit-ai/src/genkit/ai/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
Role,
4848
Supports,
4949
TextPart,
50+
ToolDefinition,
5051
ToolRequest,
5152
ToolRequestPart,
5253
ToolResponse,
@@ -85,5 +86,6 @@
8586
ToolResponsePart.__name__,
8687
GenerationCommonConfig.__name__,
8788
GenerationUsage.__name__,
89+
ToolDefinition.__name__,
8890
tool_response.__name__,
8991
]

py/packages/genkit-ai/src/genkit/blocks/prompt.py

-2
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,6 @@ def to_generate_action_options(
339339
model = model or registry.default_model
340340
if model is None:
341341
raise Exception('No model configured.')
342-
if not isinstance(config, GenerationCommonConfig | dict | None):
343-
raise AttributeError('Invalid generate config provided')
344342
resolved_msgs: list[Message] = []
345343
if system:
346344
resolved_msgs.append(

py/plugins/google-genai/src/genkit/plugins/google_genai/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
GeminiEmbeddingModels,
2121
VertexEmbeddingModels,
2222
)
23-
from genkit.plugins.google_genai.models.gemini import GeminiVersion
23+
from genkit.plugins.google_genai.models.gemini import (
24+
GeminiConfigSchema,
25+
GeminiVersion,
26+
)
2427

2528

2629
def package_name() -> str:
@@ -40,4 +43,5 @@ def package_name() -> str:
4043
VertexEmbeddingModels.__name__,
4144
GeminiVersion.__name__,
4245
EmbeddingTaskType.__name__,
46+
GeminiConfigSchema.__name__,
4347
]

py/plugins/google-genai/src/genkit/plugins/google_genai/google.py

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131
from genkit.plugins.google_genai.models.gemini import (
3232
GeminiApiOnlyVersion,
33+
GeminiConfigSchema,
3334
GeminiModel,
3435
GeminiVersion,
3536
)
@@ -101,6 +102,7 @@ def initialize(self, ai: GenkitRegistry) -> None:
101102
name=google_genai_name(version),
102103
fn=gemini_model.generate,
103104
metadata=gemini_model.metadata,
105+
config_schema=GeminiConfigSchema,
104106
)
105107

106108
embeding_models = (

py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -124,22 +124,27 @@
124124
from google import genai
125125
from google.genai import types as genai_types
126126

127-
from genkit.ai.registry import GenkitRegistry
128-
from genkit.core.action import ActionKind, ActionRunContext
129-
from genkit.core.typing import (
127+
from genkit.ai import (
128+
ActionKind,
129+
ActionRunContext,
130130
GenerateRequest,
131131
GenerateResponse,
132132
GenerateResponseChunk,
133133
GenerationCommonConfig,
134+
GenkitRegistry,
134135
Message,
135136
ModelInfo,
136137
Role,
137138
Supports,
138-
TextPart,
139139
ToolDefinition,
140140
)
141141
from genkit.plugins.google_genai.models.utils import PartConverter
142142

143+
144+
class GeminiConfigSchema(genai_types.GenerateContentConfig):
145+
pass
146+
147+
143148
GEMINI_1_0_PRO = ModelInfo(
144149
label='Google AI - Gemini Pro',
145150
versions=['gemini-pro', 'gemini-1.0-pro-latest', 'gemini-1.0-pro-001'],
@@ -605,6 +610,8 @@ def _genkit_to_googleai_cfg(
605610
temperature=request_config.temperature,
606611
stop_sequences=request_config.stop_sequences,
607612
)
613+
elif isinstance(request_config, GeminiConfigSchema):
614+
cfg = request_config
608615
elif isinstance(request_config, dict):
609616
cfg = genai_types.GenerateContentConfig(**request_config)
610617

py/samples/hello-google-genai/src/hello.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,15 @@
2828
from genkit.plugins.google_ai.models import gemini
2929
from genkit.plugins.google_genai import (
3030
EmbeddingTaskType,
31+
GeminiConfigSchema,
3132
GeminiEmbeddingModels,
3233
GoogleGenai,
3334
google_genai_name,
3435
)
3536

3637
ai = Genkit(
3738
plugins=[GoogleGenai()],
38-
model=google_genai_name('gemini-2.0-flash'),
39+
model=google_genai_name('gemini-2.0-flash-exp'),
3940
)
4041

4142

@@ -191,6 +192,15 @@ async def generate_character_unconstrained(name: str, ctx):
191192
return result.output
192193

193194

195+
@ai.flow()
196+
async def generate_images(name: str, ctx):
197+
result = await ai.generate(
198+
prompt=f'tell me a about the Eifel Tower with photos',
199+
config=GeminiConfigSchema(response_modalities=['text', 'image']),
200+
)
201+
return result
202+
203+
194204
async def main() -> None:
195205
print(await say_hi(', tell me a joke'))
196206

0 commit comments

Comments
 (0)