Skip to content

Commit 9481f9b

Browse files
committed
feat: Make MAX_DOWNLOADED_QUERY_RESULT_ROWS configurable
- Add max_downloaded_rows field to BigQueryToolConfig with default value 50 - Add optional max_rows parameter to execute_sql function - Parameter takes precedence over config when both provided - Maintain backward compatibility with existing code - Add comprehensive tests for config and parameter functionality - Update truncation logic to use resolved max_downloaded_rows value
1 parent c8f8b4a commit 9481f9b

File tree

4 files changed

+109
-4
lines changed

4 files changed

+109
-4
lines changed

src/google/adk/tools/bigquery/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,10 @@ class BigQueryToolConfig(BaseModel):
5454
By default, the tool will allow only read operations. This behaviour may
5555
change in future versions.
5656
"""
57+
58+
max_downloaded_rows: int = 50
59+
"""Maximum number of rows to download from query results.
60+
61+
By default, limits query results to 50 rows to prevent excessive memory usage.
62+
Set to a higher value if you need more rows returned.
63+
"""

src/google/adk/tools/bigquery/query_tool.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from .config import BigQueryToolConfig
2828
from .config import WriteMode
2929

30-
MAX_DOWNLOADED_QUERY_RESULT_ROWS = 50
3130
BIGQUERY_SESSION_INFO_KEY = "bigquery_session_info"
3231

3332

@@ -37,6 +36,7 @@ def execute_sql(
3736
credentials: Credentials,
3837
config: BigQueryToolConfig,
3938
tool_context: ToolContext,
39+
max_rows: int | None = None,
4040
) -> dict:
4141
"""Run a BigQuery or BigQuery ML SQL query in the project and return the result.
4242
@@ -160,7 +160,7 @@ def execute_sql(
160160
query,
161161
job_config=job_config,
162162
project=project_id,
163-
max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS,
163+
max_results=max_rows or (config.max_downloaded_rows if config else 50),
164164
)
165165
rows = []
166166
for row in row_iterator:
@@ -175,9 +175,10 @@ def execute_sql(
175175
rows.append(row_values)
176176

177177
result = {"status": "SUCCESS", "rows": rows}
178+
max_downloaded_rows = max_rows or (config.max_downloaded_rows if config else 50)
178179
if (
179-
MAX_DOWNLOADED_QUERY_RESULT_ROWS is not None
180-
and len(rows) == MAX_DOWNLOADED_QUERY_RESULT_ROWS
180+
max_downloaded_rows is not None
181+
and len(rows) == max_downloaded_rows
181182
):
182183
result["result_is_likely_truncated"] = True
183184
return result

tests/unittests/tools/bigquery/test_bigquery_query_tool.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,3 +973,86 @@ def test_execute_sql_result_dtype(
973973
# Test the tool worked without invoking default auth
974974
result = execute_sql(project, query, credentials, tool_config, tool_context)
975975
assert result == {"status": "SUCCESS", "rows": tool_result_rows}
976+
977+
978+
def test_execute_sql_max_rows_config():
979+
"""Test execute_sql tool respects max_downloaded_rows from config."""
980+
project = "my_project"
981+
query = "SELECT 123 AS num"
982+
statement_type = "SELECT"
983+
query_result = [{"num": i} for i in range(20)] # 20 rows
984+
credentials = mock.create_autospec(Credentials, instance=True)
985+
tool_config = BigQueryToolConfig(max_downloaded_rows=10)
986+
tool_context = mock.create_autospec(ToolContext, instance=True)
987+
988+
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
989+
bq_client = Client.return_value
990+
query_job = mock.create_autospec(bigquery.QueryJob)
991+
query_job.statement_type = statement_type
992+
bq_client.query.return_value = query_job
993+
bq_client.query_and_wait.return_value = query_result[:10]
994+
995+
result = execute_sql(project, query, credentials, tool_config, tool_context)
996+
997+
# Check that max_results was called with config value
998+
bq_client.query_and_wait.assert_called_once()
999+
call_args = bq_client.query_and_wait.call_args
1000+
assert call_args.kwargs["max_results"] == 10
1001+
1002+
# Check truncation flag is set
1003+
assert result["status"] == "SUCCESS"
1004+
assert result["result_is_likely_truncated"] is True
1005+
1006+
1007+
def test_execute_sql_max_rows_parameter():
1008+
"""Test execute_sql tool respects max_rows parameter."""
1009+
project = "my_project"
1010+
query = "SELECT 123 AS num"
1011+
statement_type = "SELECT"
1012+
query_result = [{"num": i} for i in range(20)] # 20 rows
1013+
credentials = mock.create_autospec(Credentials, instance=True)
1014+
tool_config = BigQueryToolConfig(max_downloaded_rows=10)
1015+
tool_context = mock.create_autospec(ToolContext, instance=True)
1016+
1017+
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
1018+
bq_client = Client.return_value
1019+
query_job = mock.create_autospec(bigquery.QueryJob)
1020+
query_job.statement_type = statement_type
1021+
bq_client.query.return_value = query_job
1022+
bq_client.query_and_wait.return_value = query_result[:5]
1023+
1024+
# Parameter should override config
1025+
result = execute_sql(project, query, credentials, tool_config, tool_context, max_rows=5)
1026+
1027+
# Check that max_results was called with parameter value
1028+
bq_client.query_and_wait.assert_called_once()
1029+
call_args = bq_client.query_and_wait.call_args
1030+
assert call_args.kwargs["max_results"] == 5
1031+
1032+
# Check truncation flag is set
1033+
assert result["status"] == "SUCCESS"
1034+
assert result["result_is_likely_truncated"] is True
1035+
1036+
1037+
def test_execute_sql_no_truncation():
1038+
"""Test execute_sql tool when results are not truncated."""
1039+
project = "my_project"
1040+
query = "SELECT 123 AS num"
1041+
statement_type = "SELECT"
1042+
query_result = [{"num": i} for i in range(3)] # Only 3 rows
1043+
credentials = mock.create_autospec(Credentials, instance=True)
1044+
tool_config = BigQueryToolConfig(max_downloaded_rows=10)
1045+
tool_context = mock.create_autospec(ToolContext, instance=True)
1046+
1047+
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
1048+
bq_client = Client.return_value
1049+
query_job = mock.create_autospec(bigquery.QueryJob)
1050+
query_job.statement_type = statement_type
1051+
bq_client.query.return_value = query_job
1052+
bq_client.query_and_wait.return_value = query_result
1053+
1054+
result = execute_sql(project, query, credentials, tool_config, tool_context)
1055+
1056+
# Check no truncation flag when fewer rows than limit
1057+
assert result["status"] == "SUCCESS"
1058+
assert "result_is_likely_truncated" not in result

tests/unittests/tools/bigquery/test_bigquery_tool_config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,17 @@ def test_bigquery_tool_config_experimental_warning():
2525
match="Config defaults may have breaking change in the future.",
2626
):
2727
BigQueryToolConfig()
28+
29+
30+
def test_bigquery_tool_config_max_downloaded_rows_default():
31+
"""Test BigQueryToolConfig max_downloaded_rows default value."""
32+
with pytest.warns(UserWarning):
33+
config = BigQueryToolConfig()
34+
assert config.max_downloaded_rows == 50
35+
36+
37+
def test_bigquery_tool_config_max_downloaded_rows_custom():
38+
"""Test BigQueryToolConfig max_downloaded_rows custom value."""
39+
with pytest.warns(UserWarning):
40+
config = BigQueryToolConfig(max_downloaded_rows=100)
41+
assert config.max_downloaded_rows == 100

0 commit comments

Comments
 (0)