Skip to content

Commit 6050771

Browse files
committed
Moving dts logic into its own module
Signed-off-by: Ryan Lettieri <ryanLettieri@microsoft.com>
1 parent 0de338d commit 6050771

File tree

6 files changed

+114
-28
lines changed

6 files changed

+114
-28
lines changed

durabletask/worker.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import durabletask.internal.orchestrator_service_pb2 as pb
1717
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
1818
import durabletask.internal.shared as shared
19-
from durabletask.accessTokenManager import AccessTokenManager
19+
2020
from durabletask import task
2121

2222
TInput = TypeVar('TInput')
@@ -89,36 +89,15 @@ def __init__(self, *,
8989
metadata: Optional[list[tuple[str, str]]] = None,
9090
log_handler=None,
9191
log_formatter: Optional[logging.Formatter] = None,
92-
secure_channel: bool = False,
93-
access_token_manager: AccessTokenManager = None):
92+
secure_channel: bool = False):
9493
self._registry = _Registry()
9594
self._host_address = host_address if host_address else shared.get_default_host_address()
9695
self._metadata = metadata
9796
self._logger = shared.get_logger("worker", log_handler, log_formatter)
9897
self._shutdown = Event()
9998
self._is_running = False
10099
self._secure_channel = secure_channel
101-
self._access_token_manager = access_token_manager
102-
self.__update_metadata_with_token()
103-
104-
def __update_metadata_with_token(self):
105-
"""
106-
Add or update the `authorization` key in the metadata with the current access token.
107-
"""
108-
if self._access_token_manager is not None:
109-
token = self._access_token_manager.get_access_token()
110-
111-
# Check if "authorization" already exists in the metadata
112-
updated = False
113-
for i, (key, _) in enumerate(self._metadata):
114-
if key == "authorization":
115-
self._metadata[i] = ("authorization", token)
116-
updated = True
117-
break
118-
119-
# If not updated, add a new entry
120-
if not updated:
121-
self._metadata.append(("authorization", token))
100+
122101

123102
def __enter__(self):
124103
return self
@@ -153,7 +132,6 @@ def run_loop():
153132
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
154133
while not self._shutdown.is_set():
155134
try:
156-
self.__update_metadata_with_token()
157135
# send a "Hello" message to the sidecar to ensure that it's listening
158136
stub.Hello(empty_pb2.Empty())
159137

examples/dts/dts_activity_sequence.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44
"""End-to-end sample that demonstrates how to configure an orchestrator
55
that calls an activity function in a sequence and prints the outputs."""
6-
from durabletask import client, task, worker
7-
from durabletask.accessTokenManager import AccessTokenManager
6+
from durabletask import client, task
7+
from externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker
8+
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
89

910
def hello(ctx: task.ActivityContext, name: str) -> str:
1011
"""Activity function that returns a greeting"""
@@ -56,7 +57,7 @@ def sequence(ctx: task.OrchestrationContext, _):
5657
]
5758

5859
# configure and start the worker
59-
with worker.TaskHubGrpcWorker(host_address=endpoint, metadata=metaData, secure_channel=True, access_token_manager=token_manager) as w:
60+
with DurableTaskSchedulerWorker(host_address=endpoint, metadata=metaData, secure_channel=True, access_token_manager=token_manager) as w:
6061
w.add_orchestrator(sequence)
6162
w.add_activity(hello)
6263
w.start()
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""Durable Task SDK for Python"""
5+
6+
7+
PACKAGE_NAME = "durabletaskscheduler"
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from durabletask import TaskHubGrpcClient
2+
3+
class DurableTaskSchedulerClient(TaskHubGrpcClient):
4+
def __init__(self, *args, **kwargs):
5+
# Initialize the base class
6+
super().__init__(*args, **kwargs)
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import concurrent.futures
2+
from threading import Thread
3+
from google.protobuf import empty_pb2
4+
import grpc
5+
import durabletask.internal.orchestrator_service_pb2 as pb
6+
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
7+
import durabletask.internal.shared as shared
8+
9+
from durabletask.worker import TaskHubGrpcWorker
10+
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
11+
12+
class DurableTaskSchedulerWorker(TaskHubGrpcWorker):
13+
def __init__(self, *args, access_token_manager: AccessTokenManager = None, **kwargs):
14+
# Initialize the base class
15+
super().__init__(*args, **kwargs)
16+
self._access_token_manager = access_token_manager
17+
self.__update_metadata_with_token()
18+
19+
def __update_metadata_with_token(self):
20+
"""
21+
Add or update the `authorization` key in the metadata with the current access token.
22+
"""
23+
if self._access_token_manager is not None:
24+
token = self._access_token_manager.get_access_token()
25+
26+
# Check if "authorization" already exists in the metadata
27+
updated = False
28+
for i, (key, _) in enumerate(self._metadata):
29+
if key == "authorization":
30+
self._metadata[i] = ("authorization", token)
31+
updated = True
32+
break
33+
34+
# If not updated, add a new entry
35+
if not updated:
36+
self._metadata.append(("authorization", token))
37+
38+
def start(self):
39+
"""Starts the worker on a background thread and begins listening for work items."""
40+
channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel)
41+
stub = stubs.TaskHubSidecarServiceStub(channel)
42+
43+
if self._is_running:
44+
raise RuntimeError('The worker is already running.')
45+
46+
def run_loop():
47+
# TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity
48+
# functions. We'd need to know ahead of time whether a function is async or not.
49+
# TODO: Max concurrency configuration settings
50+
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
51+
while not self._shutdown.is_set():
52+
try:
53+
self.__update_metadata_with_token()
54+
# send a "Hello" message to the sidecar to ensure that it's listening
55+
stub.Hello(empty_pb2.Empty())
56+
57+
# stream work items
58+
self._response_stream = stub.GetWorkItems(pb.GetWorkItemsRequest())
59+
self._logger.info(f'Successfully connected to {self._host_address}. Waiting for work items...')
60+
61+
# The stream blocks until either a work item is received or the stream is canceled
62+
# by another thread (see the stop() method).
63+
for work_item in self._response_stream: # type: ignore
64+
request_type = work_item.WhichOneof('request')
65+
self._logger.debug(f'Received "{request_type}" work item')
66+
if work_item.HasField('orchestratorRequest'):
67+
executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken)
68+
elif work_item.HasField('activityRequest'):
69+
executor.submit(self._execute_activity, work_item.activityRequest, stub, work_item.completionToken)
70+
elif work_item.HasField('healthPing'):
71+
pass # no-op
72+
else:
73+
self._logger.warning(f'Unexpected work item type: {request_type}')
74+
75+
except grpc.RpcError as rpc_error:
76+
if rpc_error.code() == grpc.StatusCode.CANCELLED: # type: ignore
77+
self._logger.info(f'Disconnected from {self._host_address}')
78+
elif rpc_error.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore
79+
self._logger.warning(
80+
f'The sidecar at address {self._host_address} is unavailable - will continue retrying')
81+
else:
82+
self._logger.warning(f'Unexpected error: {rpc_error}')
83+
except Exception as ex:
84+
self._logger.warning(f'Unexpected error: {ex}')
85+
86+
# CONSIDER: exponential backoff
87+
self._shutdown.wait(5)
88+
self._logger.info("No longer listening for work items")
89+
return
90+
91+
self._logger.info(f"Starting gRPC worker that connects to {self._host_address}")
92+
self._runLoop = Thread(target=run_loop)
93+
self._runLoop.start()
94+
self._is_running = True

0 commit comments

Comments
 (0)