Skip to content

Commit 71e05fe

Browse files
authored
Merge branch 'main' into cleanup-readme
2 parents cb17781 + e87aa29 commit 71e05fe

File tree

4 files changed

+134
-11
lines changed

4 files changed

+134
-11
lines changed

durabletask-azuremanaged/durabletask/azuremanaged/client.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33

4-
from azure.core.credentials import TokenCredential
54
from typing import Optional
65

7-
from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \
8-
DTSDefaultClientInterceptorImpl
6+
from azure.core.credentials import TokenCredential
7+
8+
from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import (
9+
DTSDefaultClientInterceptorImpl,
10+
)
911
from durabletask.client import TaskHubGrpcClient
1012

1113

durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33

4-
import grpc
4+
from importlib.metadata import version
55
from typing import Optional
66

7+
import grpc
78
from azure.core.credentials import TokenCredential
89

9-
from durabletask.azuremanaged.internal.access_token_manager import \
10-
AccessTokenManager
10+
from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager
1111
from durabletask.internal.grpc_interceptor import (
12-
DefaultClientInterceptorImpl, _ClientCallDetails)
12+
DefaultClientInterceptorImpl,
13+
_ClientCallDetails,
14+
)
1315

1416

1517
class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl):
@@ -18,7 +20,16 @@ class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl):
1820
interceptor to add additional headers to all calls as needed."""
1921

2022
def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: str):
21-
self._metadata = [("taskhub", taskhub_name)]
23+
try:
24+
# Get the version of the azuremanaged package
25+
sdk_version = version('durabletask-azuremanaged')
26+
except Exception:
27+
# Fallback if version cannot be determined
28+
sdk_version = "unknown"
29+
user_agent = f"durabletask-python/{sdk_version}"
30+
self._metadata = [
31+
("taskhub", taskhub_name),
32+
("x-user-agent", user_agent)] # 'user-agent' is a reserved header in grpc, so we use 'x-user-agent' instead
2233
super().__init__(self._metadata)
2334

2435
if token_credential is not None:

durabletask-azuremanaged/durabletask/azuremanaged/worker.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33

4-
from azure.core.credentials import TokenCredential
54
from typing import Optional
65

7-
from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \
8-
DTSDefaultClientInterceptorImpl
6+
from azure.core.credentials import TokenCredential
7+
8+
from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import (
9+
DTSDefaultClientInterceptorImpl,
10+
)
911
from durabletask.worker import TaskHubGrpcWorker
1012

1113

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import threading
5+
import unittest
6+
from concurrent import futures
7+
from importlib.metadata import version
8+
9+
import grpc
10+
11+
from durabletask.azuremanaged.client import DurableTaskSchedulerClient
12+
from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import (
13+
DTSDefaultClientInterceptorImpl,
14+
)
15+
from durabletask.internal import orchestrator_service_pb2 as pb
16+
from durabletask.internal import orchestrator_service_pb2_grpc as stubs
17+
18+
19+
class MockTaskHubSidecarServiceServicer(stubs.TaskHubSidecarServiceServicer):
20+
"""Mock implementation of the TaskHubSidecarService for testing."""
21+
22+
def __init__(self):
23+
self.captured_metadata = {}
24+
self.requests_received = 0
25+
26+
def GetInstance(self, request, context):
27+
"""Implementation of GetInstance that captures the metadata."""
28+
# Store all metadata key-value pairs from the context
29+
for key, value in context.invocation_metadata():
30+
self.captured_metadata[key] = value
31+
32+
self.requests_received += 1
33+
34+
# Return a mock response
35+
response = pb.GetInstanceResponse(exists=False)
36+
return response
37+
38+
39+
class TestDurableTaskGrpcInterceptor(unittest.TestCase):
40+
"""Tests for the DTSDefaultClientInterceptorImpl class."""
41+
42+
@classmethod
43+
def setUpClass(cls):
44+
# Start a real gRPC server on a free port
45+
cls.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
46+
cls.port = cls.server.add_insecure_port('[::]:0') # Bind to a random free port
47+
cls.server_address = f"localhost:{cls.port}"
48+
49+
# Add our mock service implementation to the server
50+
cls.mock_servicer = MockTaskHubSidecarServiceServicer()
51+
stubs.add_TaskHubSidecarServiceServicer_to_server(cls.mock_servicer, cls.server)
52+
53+
# Start the server in a background thread
54+
cls.server.start()
55+
56+
@classmethod
57+
def tearDownClass(cls):
58+
cls.server.stop(grace=None)
59+
60+
def test_user_agent_metadata_passed_in_request(self):
61+
"""Test that the user agent metadata is correctly passed in gRPC requests."""
62+
# Create a client that connects to our mock server
63+
# Note: secure_channel is False and token_credential is None as specified
64+
task_hub_client = DurableTaskSchedulerClient(
65+
host_address=self.server_address,
66+
secure_channel=False,
67+
taskhub="test-taskhub",
68+
token_credential=None
69+
)
70+
71+
# Make a client call that will trigger our interceptor
72+
task_hub_client.get_orchestration_state("test-instance-id")
73+
74+
# Verify the request was received by our mock server
75+
self.assertEqual(1, self.mock_servicer.requests_received, "Expected one request to be received")
76+
77+
# Check if our custom x-user-agent header was correctly set
78+
self.assertIn("x-user-agent", self.mock_servicer.captured_metadata, "x-user-agent header not found")
79+
80+
# Get what we expect our user agent to be
81+
try:
82+
expected_version = version('durabletask-azuremanaged')
83+
except Exception:
84+
expected_version = "unknown"
85+
86+
expected_user_agent = f"durabletask-python/{expected_version}"
87+
self.assertEqual(
88+
expected_user_agent,
89+
self.mock_servicer.captured_metadata["x-user-agent"],
90+
f"Expected x-user-agent header to be '{expected_user_agent}'"
91+
)
92+
93+
# Check if the taskhub header was correctly set
94+
self.assertIn("taskhub", self.mock_servicer.captured_metadata, "taskhub header not found")
95+
self.assertEqual("test-taskhub", self.mock_servicer.captured_metadata["taskhub"])
96+
97+
# Verify the standard gRPC user-agent is different from our custom one
98+
# Note: gRPC automatically adds its own "user-agent" header
99+
self.assertIn("user-agent", self.mock_servicer.captured_metadata, "gRPC user-agent header not found")
100+
self.assertNotEqual(
101+
self.mock_servicer.captured_metadata["user-agent"],
102+
self.mock_servicer.captured_metadata["x-user-agent"],
103+
"gRPC user-agent should be different from our custom x-user-agent"
104+
)
105+
106+
107+
if __name__ == "__main__":
108+
unittest.main()

0 commit comments

Comments
 (0)