Skip to content

Commit 81a53b5

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: add table metadata info into Spanner tool get_table_schema and fix the key usage info
This can help to provide more context and information about the table, like parent-child relationship, and row deletion policy etc. PiperOrigin-RevId: 797562858
1 parent 52a3d6c commit 81a53b5

File tree

3 files changed

+327
-17
lines changed

3 files changed

+327
-17
lines changed

contributing/samples/spanner/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ distributed via the `google.adk.tools.spanner` module. These tools include:
2323

2424
1. `get_table_schema`
2525

26-
Fetches Spanner database table schema.
26+
Fetches Spanner database table schema and metadata information.
2727

2828
1. `execute_sql`
2929

src/google/adk/tools/spanner/metadata_tool.py

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def get_table_schema(
8181
credentials: Credentials,
8282
named_schema: str = "",
8383
) -> dict:
84-
"""Get schema information about a Spanner table.
84+
"""Get schema and metadata information about a Spanner table.
8585
8686
Args:
8787
project_id (str): The Google Cloud project id.
@@ -102,7 +102,8 @@ def get_table_schema(
102102
"status": "SUCCESS",
103103
"results":
104104
{
105-
'colA': {
105+
"schema": {
106+
'colA': {
106107
'SPANNER_TYPE': 'STRING(1024)',
107108
'TABLE_SCHEMA': '',
108109
'ORDINAL_POSITION': 1,
@@ -111,14 +112,31 @@ def get_table_schema(
111112
'IS_GENERATED': 'NEVER',
112113
'GENERATION_EXPRESSION': None,
113114
'IS_STORED': None,
114-
'KEY_COLUMN_USAGE': { # This part is added if it's a key column
115-
'CONSTRAINT_NAME': 'PK_Table1',
116-
'ORDINAL_POSITION': 1,
117-
'POSITION_IN_UNIQUE_CONSTRAINT': None
118-
}
115+
'KEY_COLUMN_USAGE': [
116+
# This part is added if it's a key column
117+
{
118+
'CONSTRAINT_NAME': 'PK_Table1',
119+
'ORDINAL_POSITION': 1,
120+
'POSITION_IN_UNIQUE_CONSTRAINT': None
121+
}
122+
]
123+
},
124+
'colB': { ... },
125+
...
119126
},
120-
'colB': { ... },
121-
...
127+
"metadata": [
128+
{
129+
'TABLE_SCHEMA': '',
130+
'TABLE_NAME': 'MyTable',
131+
'TABLE_TYPE': 'BASE TABLE',
132+
'PARENT_TABLE_NAME': NULL,
133+
'ON_DELETE_ACTION': NULL,
134+
'SPANNER_STATE': 'COMMITTED',
135+
'INTERLEAVE_TYPE': NULL,
136+
'ROW_DELETION_POLICY_EXPRESSION':
137+
'OLDER_THAN(CreatedAt, INTERVAL 1 DAY)',
138+
}
139+
]
122140
}
123141
"""
124142

@@ -160,7 +178,24 @@ def get_table_schema(
160178
"named_schema": spanner_param_types.STRING,
161179
}
162180

163-
schema = {}
181+
table_metadata_query = """
182+
SELECT
183+
TABLE_SCHEMA,
184+
TABLE_NAME,
185+
TABLE_TYPE,
186+
PARENT_TABLE_NAME,
187+
ON_DELETE_ACTION,
188+
SPANNER_STATE,
189+
INTERLEAVE_TYPE,
190+
ROW_DELETION_POLICY_EXPRESSION
191+
FROM
192+
INFORMATION_SCHEMA.TABLES
193+
WHERE
194+
TABLE_NAME = @table_name
195+
AND TABLE_SCHEMA = @named_schema;
196+
"""
197+
198+
results = {"schema": {}, "metadata": []}
164199
try:
165200
spanner_client = client.get_spanner_client(
166201
project=project_id, credentials=credentials
@@ -200,7 +235,7 @@ def get_table_schema(
200235
"GENERATION_EXPRESSION": generation_expression,
201236
"IS_STORED": is_stored,
202237
}
203-
schema[column_name] = column_metadata
238+
results["schema"][column_name] = column_metadata
204239

205240
key_column_result_set = snapshot.execute_sql(
206241
key_column_usage_query, params=params, param_types=param_types
@@ -219,15 +254,33 @@ def get_table_schema(
219254
"POSITION_IN_UNIQUE_CONSTRAINT": position_in_unique_constraint,
220255
}
221256
# Attach key column info to the existing column schema entry
222-
if column_name in schema:
223-
schema[column_name]["KEY_COLUMN_USAGE"] = key_column_properties
257+
if column_name in results["schema"]:
258+
results["schema"][column_name].setdefault(
259+
"KEY_COLUMN_USAGE", []
260+
).append(key_column_properties)
261+
262+
table_metadata_result_set = snapshot.execute_sql(
263+
table_metadata_query, params=params, param_types=param_types
264+
)
265+
for row in table_metadata_result_set:
266+
metadata_result = {
267+
"TABLE_SCHEMA": row[0],
268+
"TABLE_NAME": row[1],
269+
"TABLE_TYPE": row[2],
270+
"PARENT_TABLE_NAME": row[3],
271+
"ON_DELETE_ACTION": row[4],
272+
"SPANNER_STATE": row[5],
273+
"INTERLEAVE_TYPE": row[6],
274+
"ROW_DELETION_POLICY_EXPRESSION": row[7],
275+
}
276+
results["metadata"].append(metadata_result)
224277

225278
try:
226-
json.dumps(schema)
279+
json.dumps(results)
227280
except:
228-
schema = str(schema)
281+
results = str(results)
229282

230-
return {"status": "SUCCESS", "results": schema}
283+
return {"status": "SUCCESS", "results": results}
231284
except Exception as ex:
232285
return {
233286
"status": "ERROR",
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest.mock import MagicMock
16+
from unittest.mock import patch
17+
18+
from google.adk.tools.spanner import metadata_tool
19+
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
20+
import pytest
21+
22+
23+
@pytest.fixture
24+
def mock_credentials():
25+
return MagicMock()
26+
27+
28+
@pytest.fixture
29+
def mock_spanner_ids():
30+
return {
31+
"project_id": "test-project",
32+
"instance_id": "test-instance",
33+
"database_id": "test-database",
34+
"table_name": "test-table",
35+
}
36+
37+
38+
@patch("google.adk.tools.spanner.client.get_spanner_client")
39+
def test_list_table_names_success(
40+
mock_get_spanner_client, mock_spanner_ids, mock_credentials
41+
):
42+
"""Test list_table_names function with success."""
43+
mock_spanner_client = MagicMock()
44+
mock_instance = MagicMock()
45+
mock_database = MagicMock()
46+
mock_table = MagicMock()
47+
mock_table.table_id = "table1"
48+
mock_database.list_tables.return_value = [mock_table]
49+
mock_instance.database.return_value = mock_database
50+
mock_spanner_client.instance.return_value = mock_instance
51+
mock_get_spanner_client.return_value = mock_spanner_client
52+
53+
result = metadata_tool.list_table_names(
54+
mock_spanner_ids["project_id"],
55+
mock_spanner_ids["instance_id"],
56+
mock_spanner_ids["database_id"],
57+
mock_credentials,
58+
)
59+
assert result["status"] == "SUCCESS"
60+
assert result["results"] == ["table1"]
61+
62+
63+
@patch("google.adk.tools.spanner.client.get_spanner_client")
64+
def test_list_table_names_error(
65+
mock_get_spanner_client, mock_spanner_ids, mock_credentials
66+
):
67+
"""Test list_table_names function with error."""
68+
mock_get_spanner_client.side_effect = Exception("Test Exception")
69+
result = metadata_tool.list_table_names(
70+
mock_spanner_ids["project_id"],
71+
mock_spanner_ids["instance_id"],
72+
mock_spanner_ids["database_id"],
73+
mock_credentials,
74+
)
75+
assert result["status"] == "ERROR"
76+
assert result["error_details"] == "Test Exception"
77+
78+
79+
@patch("google.adk.tools.spanner.client.get_spanner_client")
80+
def test_get_table_schema_success(
81+
mock_get_spanner_client, mock_spanner_ids, mock_credentials
82+
):
83+
"""Test get_table_schema function with success."""
84+
mock_spanner_client = MagicMock()
85+
mock_instance = MagicMock()
86+
mock_database = MagicMock()
87+
mock_snapshot = MagicMock()
88+
89+
mock_columns_result = [(
90+
"col1", # COLUMN_NAME
91+
"", # TABLE_SCHEMA
92+
"STRING(MAX)", # SPANNER_TYPE
93+
1, # ORDINAL_POSITION
94+
None, # COLUMN_DEFAULT
95+
"NO", # IS_NULLABLE
96+
"NEVER", # IS_GENERATED
97+
None, # GENERATION_EXPRESSION
98+
None, # IS_STORED
99+
)]
100+
101+
mock_key_columns_result = [(
102+
"col1", # COLUMN_NAME
103+
"PK_Table", # CONSTRAINT_NAME
104+
1, # ORDINAL_POSITION
105+
None, # POSITION_IN_UNIQUE_CONSTRAINT
106+
)]
107+
108+
mock_table_metadata_result = [(
109+
"", # TABLE_SCHEMA
110+
"test_table", # TABLE_NAME
111+
"BASE TABLE", # TABLE_TYPE
112+
None, # PARENT_TABLE_NAME
113+
None, # ON_DELETE_ACTION
114+
"COMMITTED", # SPANNER_STATE
115+
None, # INTERLEAVE_TYPE
116+
"OLDER_THAN(CreatedAt, INTERVAL 1 DAY)", # ROW_DELETION_POLICY_EXPRESSION
117+
)]
118+
119+
mock_snapshot.execute_sql.side_effect = [
120+
mock_columns_result,
121+
mock_key_columns_result,
122+
mock_table_metadata_result,
123+
]
124+
125+
mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
126+
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
127+
mock_instance.database.return_value = mock_database
128+
mock_spanner_client.instance.return_value = mock_instance
129+
mock_get_spanner_client.return_value = mock_spanner_client
130+
131+
result = metadata_tool.get_table_schema(
132+
mock_spanner_ids["project_id"],
133+
mock_spanner_ids["instance_id"],
134+
mock_spanner_ids["database_id"],
135+
mock_spanner_ids["table_name"],
136+
mock_credentials,
137+
)
138+
139+
assert result["status"] == "SUCCESS"
140+
assert "col1" in result["results"]["schema"]
141+
assert result["results"]["schema"]["col1"]["SPANNER_TYPE"] == "STRING(MAX)"
142+
assert "KEY_COLUMN_USAGE" in result["results"]["schema"]["col1"]
143+
assert (
144+
result["results"]["schema"]["col1"]["KEY_COLUMN_USAGE"][0][
145+
"CONSTRAINT_NAME"
146+
]
147+
== "PK_Table"
148+
)
149+
assert "metadata" in result["results"]
150+
assert result["results"]["metadata"][0]["TABLE_NAME"] == "test_table"
151+
assert (
152+
result["results"]["metadata"][0]["ROW_DELETION_POLICY_EXPRESSION"]
153+
== "OLDER_THAN(CreatedAt, INTERVAL 1 DAY)"
154+
)
155+
156+
157+
@patch("google.adk.tools.spanner.client.get_spanner_client")
158+
def test_list_table_indexes_success(
159+
mock_get_spanner_client, mock_spanner_ids, mock_credentials
160+
):
161+
"""Test list_table_indexes function with success."""
162+
mock_spanner_client = MagicMock()
163+
mock_instance = MagicMock()
164+
mock_database = MagicMock()
165+
mock_snapshot = MagicMock()
166+
mock_result_set = MagicMock()
167+
mock_result_set.__iter__.return_value = iter([(
168+
"PRIMARY_KEY",
169+
"",
170+
"PRIMARY_KEY",
171+
"",
172+
True,
173+
False,
174+
None,
175+
)])
176+
mock_snapshot.execute_sql.return_value = mock_result_set
177+
mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
178+
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
179+
mock_instance.database.return_value = mock_database
180+
mock_spanner_client.instance.return_value = mock_instance
181+
mock_get_spanner_client.return_value = mock_spanner_client
182+
183+
result = metadata_tool.list_table_indexes(
184+
mock_spanner_ids["project_id"],
185+
mock_spanner_ids["instance_id"],
186+
mock_spanner_ids["database_id"],
187+
mock_spanner_ids["table_name"],
188+
mock_credentials,
189+
)
190+
assert result["status"] == "SUCCESS"
191+
assert len(result["results"]) == 1
192+
assert result["results"][0]["INDEX_NAME"] == "PRIMARY_KEY"
193+
194+
195+
@patch("google.adk.tools.spanner.client.get_spanner_client")
196+
def test_list_table_index_columns_success(
197+
mock_get_spanner_client, mock_spanner_ids, mock_credentials
198+
):
199+
"""Test list_table_index_columns function with success."""
200+
mock_spanner_client = MagicMock()
201+
mock_instance = MagicMock()
202+
mock_database = MagicMock()
203+
mock_snapshot = MagicMock()
204+
mock_result_set = MagicMock()
205+
mock_result_set.__iter__.return_value = iter([(
206+
"PRIMARY_KEY",
207+
"",
208+
"col1",
209+
1,
210+
"NO",
211+
"STRING(MAX)",
212+
)])
213+
mock_snapshot.execute_sql.return_value = mock_result_set
214+
mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
215+
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
216+
mock_instance.database.return_value = mock_database
217+
mock_spanner_client.instance.return_value = mock_instance
218+
mock_get_spanner_client.return_value = mock_spanner_client
219+
220+
result = metadata_tool.list_table_index_columns(
221+
mock_spanner_ids["project_id"],
222+
mock_spanner_ids["instance_id"],
223+
mock_spanner_ids["database_id"],
224+
mock_spanner_ids["table_name"],
225+
mock_credentials,
226+
)
227+
assert result["status"] == "SUCCESS"
228+
assert len(result["results"]) == 1
229+
assert result["results"][0]["COLUMN_NAME"] == "col1"
230+
231+
232+
@patch("google.adk.tools.spanner.client.get_spanner_client")
233+
def test_list_named_schemas_success(
234+
mock_get_spanner_client, mock_spanner_ids, mock_credentials
235+
):
236+
"""Test list_named_schemas function with success."""
237+
mock_spanner_client = MagicMock()
238+
mock_instance = MagicMock()
239+
mock_database = MagicMock()
240+
mock_snapshot = MagicMock()
241+
mock_result_set = MagicMock()
242+
mock_result_set.__iter__.return_value = iter([("schema1",), ("schema2",)])
243+
mock_snapshot.execute_sql.return_value = mock_result_set
244+
mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
245+
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
246+
mock_instance.database.return_value = mock_database
247+
mock_spanner_client.instance.return_value = mock_instance
248+
mock_get_spanner_client.return_value = mock_spanner_client
249+
250+
result = metadata_tool.list_named_schemas(
251+
mock_spanner_ids["project_id"],
252+
mock_spanner_ids["instance_id"],
253+
mock_spanner_ids["database_id"],
254+
mock_credentials,
255+
)
256+
assert result["status"] == "SUCCESS"
257+
assert result["results"] == ["schema1", "schema2"]

0 commit comments

Comments
 (0)