Skip to content

Commit db64d41

Browse files
committed
Setup scaffolding for XGoogRequestIdHeader checks
1 parent 37a6ab5 commit db64d41

File tree

3 files changed

+95
-1
lines changed

3 files changed

+95
-1
lines changed

google/cloud/spanner_v1/testing/interceptors.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,16 @@ def intercept(self, method, request_or_iterator, call_details):
9292
self._unary_req_segments.append(x_goog_request_id)
9393

9494
return method(request_or_iterator, call_details)
95+
96+
@property
97+
def unary_request_ids(self):
98+
return self._unary_req_segments
99+
100+
@property
101+
def stream_request_ids(self):
102+
return self._stream_req_segments
103+
104+
def reset(self):
105+
self._stream_req_segments.clear()
106+
self._unary_req_segments.clear()
107+
pass

tests/mockserver_tests/mock_server_test_base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def __init__(self, *args, **kwargs):
118118
self._client = None
119119
self._instance = None
120120
self._database = None
121+
self._interceptors = None
121122

122123
@classmethod
123124
def setup_class(cls):
@@ -146,11 +147,19 @@ def teardown_method(self, *args, **kwargs):
146147
@property
147148
def client(self) -> Client:
148149
if self._client is None:
150+
api_endpoint = "localhost:" + str(MockServerTestBase.port)
151+
channel = grpc.insecure_channel(api_endpoint)
152+
transport = None
153+
if self._interceptors and len(self._interceptors) > 0:
154+
channel = grpc.intercept_channel(channel, *self._interceptors)
155+
transport = SpannerGrpcTransport(channel=channel)
156+
149157
self._client = Client(
150158
project="p",
151159
credentials=AnonymousCredentials(),
152160
client_options=ClientOptions(
153-
api_endpoint="localhost:" + str(MockServerTestBase.port),
161+
transport=transport,
162+
api_endpoint=api_endpoint if transport is None else None,
154163
),
155164
)
156165
return self._client
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from tests.mockserver_tests.mock_server_test_base import (
16+
MockServerTestBase,
17+
add_select1_result,
18+
)
19+
from google.cloud.spanner_v1.testing.interceptors import XGoogRequestIDHeaderInterceptor
20+
21+
class TestRequestIDHeader(MockServerTestBase):
22+
# Firstly inject in the XGoogRequestIdHeader interceptor.
23+
x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
24+
MockServerTestBase._interceptors = [x_goog_request_id_interceptor]
25+
26+
def tearDown(self):
27+
x_goog_request_id_interceptor.reset()
28+
29+
def test_snapshot_read(self):
30+
add_select1_result()
31+
with self.database.snapshot() as snapshot:
32+
results = snapshot.execute_sql("select 1")
33+
result_list = []
34+
for result in results:
35+
result_list.append(result)
36+
self.assertEqual(1, row[0])
37+
self.assertEqual(1, len(result_list))
38+
39+
requests = self.spanner_service.requests
40+
self.assertEqual(2, len(requests), msg=requests)
41+
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
42+
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
43+
44+
# Now ensure monotonicity of the received request-id segments.
45+
stream_segments, unary_segments = self.canonicalize_request_id_headers()
46+
assert len(unary_segments) > 1
47+
assert len(stream_segments) == 0
48+
49+
def canonicalize_request_id_headers(self):
50+
src = x_goog_request_id_interceptor
51+
stream_segments = [
52+
parse_request_id(req_id) for req_id in src._stream_req_segments
53+
]
54+
unary_segments = [
55+
parse_request_id(req_id) for req_id in src._unary_req_segments
56+
]
57+
return stream_segments, unary_segments
58+
59+
60+
def parse_request_id(request_id_str):
61+
splits = request_id_str.split(".")
62+
version, rand_process_id, client_id, channel_id, nth_request, nth_attempt = list(
63+
map(lambda v: int(v), splits)
64+
)
65+
return (
66+
version,
67+
rand_process_id,
68+
client_id,
69+
channel_id,
70+
nth_request,
71+
nth_attempt,
72+
)

0 commit comments

Comments
 (0)