Skip to content

Commit 5adbf95

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: stop updating write mode in the global settings during tool execution
Two tools - detect_anomalies and analyze_contribution are modifying the settings passed to them, which is not right as the settings are held and passed by the top level, which means several tools share the same settings. PiperOrigin-RevId: 832081738
1 parent 23ad40b commit 5adbf95

File tree

2 files changed

+102
-20
lines changed

2 files changed

+102
-20
lines changed

src/google/adk/tools/bigquery/query_tool.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,21 +1093,23 @@ def analyze_contribution(
10931093
"""
10941094

10951095
# Create a session and run the create model query.
1096-
original_write_mode = settings.write_mode
10971096
try:
1098-
if original_write_mode == WriteMode.BLOCKED:
1097+
execute_sql_settings = settings
1098+
if execute_sql_settings.write_mode == WriteMode.BLOCKED:
10991099
raise ValueError("analyze_contribution is not allowed in this session.")
1100-
elif original_write_mode != WriteMode.PROTECTED:
1100+
elif execute_sql_settings.write_mode != WriteMode.PROTECTED:
11011101
# Running create temp model requires a session. So we set the write mode
11021102
# to PROTECTED to run the create model query and job query in the same
11031103
# session.
1104-
settings.write_mode = WriteMode.PROTECTED
1104+
execute_sql_settings = settings.model_copy(
1105+
update={"write_mode": WriteMode.PROTECTED}
1106+
)
11051107

11061108
result = _execute_sql(
11071109
project_id=project_id,
11081110
query=create_model_query,
11091111
credentials=credentials,
1110-
settings=settings,
1112+
settings=execute_sql_settings,
11111113
tool_context=tool_context,
11121114
caller_id="analyze_contribution",
11131115
)
@@ -1118,18 +1120,15 @@ def analyze_contribution(
11181120
project_id=project_id,
11191121
query=get_insights_query,
11201122
credentials=credentials,
1121-
settings=settings,
1123+
settings=execute_sql_settings,
11221124
tool_context=tool_context,
11231125
caller_id="analyze_contribution",
11241126
)
11251127
except Exception as ex: # pylint: disable=broad-except
11261128
return {
11271129
"status": "ERROR",
1128-
"error_details": f"Error during analyze_contribution: {str(ex)}",
1130+
"error_details": f"Error during analyze_contribution: {repr(ex)}",
11291131
}
1130-
finally:
1131-
# Restore the original write mode.
1132-
settings.write_mode == original_write_mode
11331132

11341133
return result
11351134

@@ -1327,21 +1326,23 @@ def detect_anomalies(
13271326
"""
13281327

13291328
# Create a session and run the create model query.
1330-
original_write_mode = settings.write_mode
13311329
try:
1332-
if settings.write_mode == WriteMode.BLOCKED:
1330+
execute_sql_settings = settings
1331+
if execute_sql_settings.write_mode == WriteMode.BLOCKED:
13331332
raise ValueError("anomaly detection is not allowed in this session.")
1334-
elif original_write_mode != WriteMode.PROTECTED:
1333+
elif execute_sql_settings.write_mode != WriteMode.PROTECTED:
13351334
# Running create temp model requires a session. So we set the write mode
13361335
# to PROTECTED to run the create model query and job query in the same
13371336
# session.
1338-
settings.write_mode = WriteMode.PROTECTED
1337+
execute_sql_settings = settings.model_copy(
1338+
update={"write_mode": WriteMode.PROTECTED}
1339+
)
13391340

13401341
result = _execute_sql(
13411342
project_id=project_id,
13421343
query=create_model_query,
13431344
credentials=credentials,
1344-
settings=settings,
1345+
settings=execute_sql_settings,
13451346
tool_context=tool_context,
13461347
caller_id="detect_anomalies",
13471348
)
@@ -1352,17 +1353,14 @@ def detect_anomalies(
13521353
project_id=project_id,
13531354
query=anomaly_detection_query,
13541355
credentials=credentials,
1355-
settings=settings,
1356+
settings=execute_sql_settings,
13561357
tool_context=tool_context,
13571358
caller_id="detect_anomalies",
13581359
)
13591360
except Exception as ex: # pylint: disable=broad-except
13601361
return {
13611362
"status": "ERROR",
1362-
"error_details": f"Error during anomaly detection: {str(ex)}",
1363+
"error_details": f"Error during anomaly detection: {repr(ex)}",
13631364
}
1364-
finally:
1365-
# Restore the original write mode.
1366-
settings.write_mode == original_write_mode
13671365

13681366
return result

tests/unittests/tools/bigquery/test_bigquery_query_tool.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1849,3 +1849,87 @@ def test_execute_sql_maximum_bytes_billed_config():
18491849
bq_client.query_and_wait.assert_called_once()
18501850
call_args = bq_client.query_and_wait.call_args
18511851
assert call_args.kwargs["job_config"].maximum_bytes_billed == 11_000_000
1852+
1853+
1854+
@pytest.mark.parametrize(
1855+
("tool_call",),
1856+
[
1857+
pytest.param(
1858+
lambda settings, tool_context: execute_sql(
1859+
project_id="test-project",
1860+
query="SELECT * FROM `test-dataset.test-table`",
1861+
credentials=mock.create_autospec(Credentials, instance=True),
1862+
settings=settings,
1863+
tool_context=tool_context,
1864+
),
1865+
id="execute-sql",
1866+
),
1867+
pytest.param(
1868+
lambda settings, tool_context: forecast(
1869+
project_id="test-project",
1870+
history_data="SELECT * FROM `test-dataset.test-table`",
1871+
timestamp_col="ts_col",
1872+
data_col="data_col",
1873+
credentials=mock.create_autospec(Credentials, instance=True),
1874+
settings=settings,
1875+
tool_context=tool_context,
1876+
),
1877+
id="forecast",
1878+
),
1879+
pytest.param(
1880+
lambda settings, tool_context: analyze_contribution(
1881+
project_id="test-project",
1882+
input_data="test-dataset.test-table",
1883+
dimension_id_cols=["dim1", "dim2"],
1884+
contribution_metric="SUM(metric)",
1885+
is_test_col="is_test",
1886+
credentials=mock.create_autospec(Credentials, instance=True),
1887+
settings=settings,
1888+
tool_context=tool_context,
1889+
),
1890+
id="analyze-contribution",
1891+
),
1892+
pytest.param(
1893+
lambda settings, tool_context: detect_anomalies(
1894+
project_id="test-project",
1895+
history_data="SELECT * FROM `test-dataset.test-table`",
1896+
times_series_timestamp_col="ts_timestamp",
1897+
times_series_data_col="ts_data",
1898+
credentials=mock.create_autospec(Credentials, instance=True),
1899+
settings=settings,
1900+
tool_context=tool_context,
1901+
),
1902+
id="detect-anomalies",
1903+
),
1904+
],
1905+
)
1906+
def test_tool_call_doesnt_change_global_settings(tool_call):
1907+
"""Test query tools don't change global settings."""
1908+
settings = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
1909+
tool_context = mock.create_autospec(ToolContext, instance=True)
1910+
tool_context.state.get.return_value = (
1911+
"test-bq-session-id",
1912+
"_anonymous_dataset",
1913+
)
1914+
1915+
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
1916+
# The mock instance
1917+
bq_client = Client.return_value
1918+
1919+
# Simulate the result of query API
1920+
query_job = mock.create_autospec(bigquery.QueryJob)
1921+
query_job.destination.dataset_id = "_anonymous_dataset"
1922+
bq_client.query.return_value = query_job
1923+
bq_client.query_and_wait.return_value = []
1924+
1925+
# Test settings write mode before
1926+
assert settings.write_mode == WriteMode.ALLOWED
1927+
1928+
# Call the tool
1929+
result = tool_call(settings, tool_context)
1930+
1931+
# Test successfull executeion of the tool
1932+
assert result == {"status": "SUCCESS", "rows": []}
1933+
1934+
# Test settings write mode after
1935+
assert settings.write_mode == WriteMode.ALLOWED

0 commit comments

Comments
 (0)