Skip to content

Commit 2efc184

Browse files
sararobcopybara-github
authored andcommitted
chore: Add support for abstract types in AFC
PiperOrigin-RevId: 831873580
1 parent 6bb0b74 commit 2efc184

File tree

3 files changed

+219
-21
lines changed

3 files changed

+219
-21
lines changed

src/google/adk/tools/_automatic_function_calling_util.py

Lines changed: 84 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -296,20 +296,59 @@ def from_function_with_options(
296296
) -> 'types.FunctionDeclaration':
297297

298298
parameters_properties = {}
299-
for name, param in inspect.signature(func).parameters.items():
300-
if param.kind in (
301-
inspect.Parameter.POSITIONAL_OR_KEYWORD,
302-
inspect.Parameter.KEYWORD_ONLY,
303-
inspect.Parameter.POSITIONAL_ONLY,
304-
):
305-
# This snippet catches the case when type hints are stored as strings
306-
if isinstance(param.annotation, str):
307-
param = param.replace(annotation=typing.get_type_hints(func)[name])
308-
309-
schema = _function_parameter_parse_util._parse_schema_from_parameter(
310-
variant, param, func.__name__
311-
)
312-
parameters_properties[name] = schema
299+
parameters_json_schema = {}
300+
try:
301+
annotation_under_future = typing.get_type_hints(func)
302+
except TypeError:
303+
# This can happen if func is a mock object
304+
annotation_under_future = {}
305+
try:
306+
for name, param in inspect.signature(func).parameters.items():
307+
if param.kind in (
308+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
309+
inspect.Parameter.KEYWORD_ONLY,
310+
inspect.Parameter.POSITIONAL_ONLY,
311+
):
312+
param = _function_parameter_parse_util._handle_params_as_deferred_annotations(
313+
param, annotation_under_future, name
314+
)
315+
316+
schema = _function_parameter_parse_util._parse_schema_from_parameter(
317+
variant, param, func.__name__
318+
)
319+
parameters_properties[name] = schema
320+
except ValueError:
321+
# If the function has complex parameter types that fail in _parse_schema_from_parameter,
322+
# we try to generate a json schema for the parameter using pydantic.TypeAdapter.
323+
parameters_properties = {}
324+
for name, param in inspect.signature(func).parameters.items():
325+
if param.kind in (
326+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
327+
inspect.Parameter.KEYWORD_ONLY,
328+
inspect.Parameter.POSITIONAL_ONLY,
329+
):
330+
try:
331+
if param.annotation == inspect.Parameter.empty:
332+
param = param.replace(annotation=Any)
333+
334+
param = _function_parameter_parse_util._handle_params_as_deferred_annotations(
335+
param, annotation_under_future, name
336+
)
337+
338+
_function_parameter_parse_util._raise_for_invalid_enum_value(param)
339+
340+
json_schema_dict = _function_parameter_parse_util._generate_json_schema_for_parameter(
341+
param
342+
)
343+
344+
parameters_json_schema[name] = types.Schema.model_validate(
345+
json_schema_dict
346+
)
347+
except Exception as e:
348+
_function_parameter_parse_util._raise_for_unsupported_param(
349+
param, func.__name__, e
350+
)
351+
313352
declaration = types.FunctionDeclaration(
314353
name=func.__name__,
315354
description=func.__doc__,
@@ -324,6 +363,12 @@ def from_function_with_options(
324363
declaration.parameters
325364
)
326365
)
366+
elif parameters_json_schema:
367+
declaration.parameters = types.Schema(
368+
type='OBJECT',
369+
properties=parameters_json_schema,
370+
)
371+
327372
if variant == GoogleLLMVariant.GEMINI_API:
328373
return declaration
329374

@@ -372,17 +417,35 @@ def from_function_with_options(
372417
inspect.Parameter.POSITIONAL_OR_KEYWORD,
373418
annotation=return_annotation,
374419
)
375-
# This snippet catches the case when type hints are stored as strings
376420
if isinstance(return_value.annotation, str):
377421
return_value = return_value.replace(
378422
annotation=typing.get_type_hints(func)['return']
379423
)
380424

381-
declaration.response = (
382-
_function_parameter_parse_util._parse_schema_from_parameter(
383-
variant,
384-
return_value,
385-
func.__name__,
425+
response_schema: Optional[types.Schema] = None
426+
response_json_schema: Optional[Union[Dict[str, Any], types.Schema]] = None
427+
try:
428+
response_schema = (
429+
_function_parameter_parse_util._parse_schema_from_parameter(
430+
variant,
431+
return_value,
432+
func.__name__,
433+
)
434+
)
435+
except ValueError:
436+
try:
437+
response_json_schema = (
438+
_function_parameter_parse_util._generate_json_schema_for_parameter(
439+
return_value
440+
)
386441
)
387-
)
442+
response_json_schema = types.Schema.model_validate(response_json_schema)
443+
except Exception as e:
444+
_function_parameter_parse_util._raise_for_unsupported_param(
445+
return_value, func.__name__, e
446+
)
447+
if response_schema:
448+
declaration.response = response_schema
449+
elif response_json_schema:
450+
declaration.response = response_json_schema
388451
return declaration

src/google/adk/tools/_function_parameter_parse_util.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,91 @@
4949
logger = logging.getLogger('google_adk.' + __name__)
5050

5151

52+
def _handle_params_as_deferred_annotations(
53+
param: inspect.Parameter, annotation_under_future: dict[str, Any], name: str
54+
) -> inspect.Parameter:
55+
"""Catches the case when type hints are stored as strings."""
56+
if isinstance(param.annotation, str):
57+
param = param.replace(annotation=annotation_under_future[name])
58+
return param
59+
60+
61+
def _add_unevaluated_items_to_fixed_len_tuple_schema(
62+
json_schema: dict[str, Any],
63+
) -> dict[str, Any]:
64+
"""Adds 'unevaluatedItems': False to schemas for fixed-length tuples.
65+
66+
For example, the schema for a parameter of type `tuple[float, float]` would
67+
be:
68+
{
69+
"type": "array",
70+
"prefixItems": [
71+
{
72+
"type": "number"
73+
},
74+
{
75+
"type": "number"
76+
},
77+
],
78+
"minItems": 2,
79+
"maxItems": 2,
80+
"unevaluatedItems": False
81+
}
82+
83+
"""
84+
if (
85+
json_schema.get('maxItems')
86+
and (
87+
json_schema.get('prefixItems')
88+
and len(json_schema['prefixItems']) == json_schema['maxItems']
89+
)
90+
and json_schema.get('type') == 'array'
91+
):
92+
json_schema['unevaluatedItems'] = False
93+
return json_schema
94+
95+
96+
def _raise_for_unsupported_param(
97+
param: inspect.Parameter,
98+
func_name: str,
99+
exception: Exception,
100+
) -> None:
101+
raise ValueError(
102+
f'Failed to parse the parameter {param} of function {func_name} for'
103+
' automatic function calling.Automatic function calling works best with'
104+
' simpler function signature schema, consider manually parsing your'
105+
f' function declaration for function {func_name}.'
106+
) from exception
107+
108+
109+
def _raise_for_invalid_enum_value(param: inspect.Parameter):
110+
"""Raises an error if the default value is not a valid enum value."""
111+
if inspect.isclass(param.annotation) and issubclass(param.annotation, Enum):
112+
if param.default is not inspect.Parameter.empty and param.default not in [
113+
e.value for e in param.annotation
114+
]:
115+
raise ValueError(
116+
f'Default value {param.default} is not a valid enum value for'
117+
f' {param.annotation}.'
118+
)
119+
120+
121+
def _generate_json_schema_for_parameter(
122+
param: inspect.Parameter,
123+
) -> dict[str, Any]:
124+
"""Generates a JSON schema for a parameter using pydantic.TypeAdapter."""
125+
126+
param_schema_adapter = pydantic.TypeAdapter(
127+
param.annotation,
128+
config=pydantic.ConfigDict(arbitrary_types_allowed=True),
129+
)
130+
json_schema_dict = param_schema_adapter.json_schema()
131+
json_schema_dict = _add_unevaluated_items_to_fixed_len_tuple_schema(
132+
json_schema_dict
133+
)
134+
return json_schema_dict
135+
136+
52137
def _is_builtin_primitive_or_compound(
53138
annotation: inspect.Parameter.annotation,
54139
) -> bool:

tests/unittests/tools/test_from_function_with_options.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections.abc import Sequence
1516
from typing import Any
1617
from typing import Dict
1718

@@ -192,3 +193,52 @@ def test_function() -> None:
192193
# VERTEX_AI should have response schema for None return
193194
assert declaration.response is not None
194195
assert declaration.response.type == types.Type.NULL
196+
197+
198+
def test_from_function_with_collections_type_parameter():
199+
"""Test from_function_with_options with collections type parameter."""
200+
201+
def test_function(
202+
artifact_key: str,
203+
input_edit_ids: Sequence[str],
204+
) -> str:
205+
"""Saves a sequence of edit IDs."""
206+
return f'Saved {len(input_edit_ids)} edit IDs for artifact {artifact_key}'
207+
208+
declaration = _automatic_function_calling_util.from_function_with_options(
209+
test_function, GoogleLLMVariant.VERTEX_AI
210+
)
211+
212+
assert declaration.name == 'test_function'
213+
assert declaration.parameters.type == types.Type.OBJECT
214+
assert (
215+
declaration.parameters.properties['artifact_key'].type
216+
== types.Type.STRING
217+
)
218+
assert (
219+
declaration.parameters.properties['input_edit_ids'].type
220+
== types.Type.ARRAY
221+
)
222+
assert (
223+
declaration.parameters.properties['input_edit_ids'].items.type
224+
== types.Type.STRING
225+
)
226+
assert declaration.response.type == types.Type.STRING
227+
228+
229+
def test_from_function_with_collections_return_type():
230+
"""Test from_function_with_options with collections return type."""
231+
232+
def test_function(
233+
names: list[str],
234+
) -> Sequence[str]:
235+
"""Returns a sequence of names."""
236+
return names
237+
238+
declaration = _automatic_function_calling_util.from_function_with_options(
239+
test_function, GoogleLLMVariant.VERTEX_AI
240+
)
241+
242+
assert declaration.name == 'test_function'
243+
assert declaration.response.type == types.Type.ARRAY
244+
assert declaration.response.items.type == types.Type.STRING

0 commit comments

Comments
 (0)