|
import concurrent.futures
from threading import Thread
from typing import Optional

import grpc
from google.protobuf import empty_pb2

import durabletask.internal.orchestrator_service_pb2 as pb
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
import durabletask.internal.shared as shared
from durabletask.worker import TaskHubGrpcWorker
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
| 11 | + |
class DurableTaskSchedulerWorker(TaskHubGrpcWorker):
    """A ``TaskHubGrpcWorker`` that authenticates its gRPC calls with an access token.

    When an :class:`AccessTokenManager` is supplied, the current token is injected
    into the gRPC metadata under the ``"authorization"`` key, and refreshed before
    every (re)connect attempt in the work-item loop.
    """

    def __init__(self, *args, access_token_manager: Optional[AccessTokenManager] = None, **kwargs):
        """Initialize the worker.

        Args:
            access_token_manager: Optional provider of access tokens. When ``None``,
                no authorization metadata is added and the worker behaves exactly
                like the base ``TaskHubGrpcWorker``.

        All other positional/keyword arguments are forwarded to the base class.
        """
        super().__init__(*args, **kwargs)
        self._access_token_manager = access_token_manager
        self.__update_metadata_with_token()

    def __update_metadata_with_token(self):
        """Add or update the ``authorization`` key in the metadata with the current access token.

        No-op when no access token manager was supplied.
        """
        if self._access_token_manager is None:
            return
        token = self._access_token_manager.get_access_token()

        # Replace an existing "authorization" entry in place if one exists...
        for i, (key, _) in enumerate(self._metadata):
            if key == "authorization":
                self._metadata[i] = ("authorization", token)
                break
        else:
            # ...otherwise append a new entry.
            self._metadata.append(("authorization", token))

    def start(self):
        """Starts the worker on a background thread and begins listening for work items.

        Raises:
            RuntimeError: If the worker is already running.
        """
        # Fail fast BEFORE creating the channel so that a duplicate start()
        # call does not leak a gRPC channel.
        if self._is_running:
            raise RuntimeError('The worker is already running.')

        channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel)
        stub = stubs.TaskHubSidecarServiceStub(channel)

        def run_loop():
            # TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity
            # functions. We'd need to know ahead of time whether a function is async or not.
            # TODO: Max concurrency configuration settings
            with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
                while not self._shutdown.is_set():
                    try:
                        # Refresh the auth token before each (re)connect attempt.
                        # NOTE(review): this mutates self._metadata after the channel
                        # was built — assumes the channel reads the list by reference;
                        # verify against shared.get_grpc_channel.
                        self.__update_metadata_with_token()
                        # send a "Hello" message to the sidecar to ensure that it's listening
                        stub.Hello(empty_pb2.Empty())

                        # stream work items
                        self._response_stream = stub.GetWorkItems(pb.GetWorkItemsRequest())
                        self._logger.info(f'Successfully connected to {self._host_address}. Waiting for work items...')

                        # The stream blocks until either a work item is received or the stream is canceled
                        # by another thread (see the stop() method).
                        for work_item in self._response_stream:  # type: ignore
                            request_type = work_item.WhichOneof('request')
                            self._logger.debug(f'Received "{request_type}" work item')
                            if work_item.HasField('orchestratorRequest'):
                                executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken)
                            elif work_item.HasField('activityRequest'):
                                executor.submit(self._execute_activity, work_item.activityRequest, stub, work_item.completionToken)
                            elif work_item.HasField('healthPing'):
                                pass  # no-op
                            else:
                                self._logger.warning(f'Unexpected work item type: {request_type}')

                    except grpc.RpcError as rpc_error:
                        if rpc_error.code() == grpc.StatusCode.CANCELLED:  # type: ignore
                            self._logger.info(f'Disconnected from {self._host_address}')
                        elif rpc_error.code() == grpc.StatusCode.UNAVAILABLE:  # type: ignore
                            self._logger.warning(
                                f'The sidecar at address {self._host_address} is unavailable - will continue retrying')
                        else:
                            self._logger.warning(f'Unexpected error: {rpc_error}')
                    except Exception as ex:
                        self._logger.warning(f'Unexpected error: {ex}')

                    # CONSIDER: exponential backoff
                    self._shutdown.wait(5)
            self._logger.info("No longer listening for work items")

        self._logger.info(f"Starting gRPC worker that connects to {self._host_address}")
        self._runLoop = Thread(target=run_loop)
        self._runLoop.start()
        self._is_running = True
0 commit comments