diff --git a/CHANGELOG.md b/CHANGELOG.md index ee736f0..13b0e69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `set_custom_status` orchestrator API ([#31](https://github.com/microsoft/durabletask-python/pull/31)) - contributed by [@famarting](https://github.com/famarting) - Added `purge_orchestration` client API ([#34](https://github.com/microsoft/durabletask-python/pull/34)) - contributed by [@famarting](https://github.com/famarting) +- Added new `durabletask-azuremanaged` package for use with the [Durable Task Scheduler](https://techcommunity.microsoft.com/blog/appsonazureblog/announcing-limited-early-access-of-the-durable-task-scheduler-for-azure-durable-/4286526) - by [@RyanLettieri](https://github.com/RyanLettieri) ### Changes diff --git a/README.md b/README.md index 644635e..87af41d 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,14 @@ -# Durable Task Client SDK for Python +# Durable Task SDK for Python [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![Build Validation](https://github.com/microsoft/durabletask-python/actions/workflows/pr-validation.yml/badge.svg)](https://github.com/microsoft/durabletask-python/actions/workflows/pr-validation.yml) [![PyPI version](https://badge.fury.io/py/durabletask.svg)](https://badge.fury.io/py/durabletask) -This repo contains a Python client SDK for use with the [Durable Task Framework for Go](https://github.com/microsoft/durabletask-go) and [Dapr Workflow](https://docs.dapr.io/developing-applications/building-blocks/workflow/workflow-overview/). With this SDK, you can define, schedule, and manage durable orchestrations using ordinary Python code. +This repo contains a Python SDK for use with the [Azure Durable Task Scheduler](https://techcommunity.microsoft.com/blog/appsonazureblog/announcing-limited-early-access-of-the-durable-task-scheduler-for-azure-durable-/4286526) and the [Durable Task Framework for Go](https://github.com/microsoft/durabletask-go). With this SDK, you can define, schedule, and manage durable orchestrations using ordinary Python code. ⚠️ **This SDK is currently under active development and is not yet ready for production use.** ⚠️ -> Note that this project is **not** currently affiliated with the [Durable Functions](https://docs.microsoft.com/azure/azure-functions/durable/durable-functions-overview) project for Azure Functions. If you are looking for a Python SDK for Durable Functions, please see [this repo](https://github.com/Azure/azure-functions-durable-python). - +> Note that this SDK is **not** currently compatible with [Azure Durable Functions](https://docs.microsoft.com/azure/azure-functions/durable/durable-functions-overview). If you are looking for a Python SDK for Azure Durable Functions, please see [this repo](https://github.com/Azure/azure-functions-durable-python). ## Supported patterns diff --git a/durabletask-azuremanaged/__init__.py b/durabletask-azuremanaged/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/__init__.py b/durabletask-azuremanaged/durabletask/azuremanaged/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py new file mode 100644 index 0000000..f641eae --- /dev/null +++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from azure.core.credentials import TokenCredential + +from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \ + DTSDefaultClientInterceptorImpl +from durabletask.client import TaskHubGrpcClient + + +# Client class used for Durable Task Scheduler (DTS) +class DurableTaskSchedulerClient(TaskHubGrpcClient): + def __init__(self, *, + host_address: str, + taskhub: str, + token_credential: TokenCredential, + secure_channel: bool = True): + + if not taskhub: + raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub") + + interceptors = [DTSDefaultClientInterceptorImpl(token_credential, taskhub)] + + # We pass in None for the metadata so we don't construct an additional interceptor in the parent class + # Since the parent class doesn't use anything metadata for anything else, we can set it as None + super().__init__( + host_address=host_address, + secure_channel=secure_channel, + metadata=None, + interceptors=interceptors) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py b/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py new file mode 100644 index 0000000..f0e7a42 --- /dev/null +++ b/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from datetime import datetime, timedelta, timezone +from typing import Optional + +from azure.core.credentials import AccessToken, TokenCredential + +import durabletask.internal.shared as shared + + +# By default, when there's 10minutes left before the token expires, refresh the token +class AccessTokenManager: + + _token: Optional[AccessToken] + + def __init__(self, token_credential: Optional[TokenCredential], refresh_interval_seconds: int = 600): + self._scope = "https://durabletask.io/.default" + self._refresh_interval_seconds = refresh_interval_seconds + self._logger = shared.get_logger("token_manager") + + self._credential = token_credential + + if self._credential is not None: + self._token = self._credential.get_token(self._scope) + self.expiry_time = datetime.fromtimestamp(self._token.expires_on, tz=timezone.utc) + else: + self._token = None + self.expiry_time = None + + def get_access_token(self) -> Optional[AccessToken]: + if self._token is None or self.is_token_expired(): + self.refresh_token() + return self._token + + # Checks if the token is expired, or if it will expire in the next "refresh_interval_seconds" seconds. + # For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes, + # We will grab a new token when there're 30minutes left on the lifespan of the token + def is_token_expired(self) -> bool: + if self.expiry_time is None: + return True + return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self._refresh_interval_seconds)) + + def refresh_token(self): + if self._credential is not None: + self._token = self._credential.get_token(self._scope) + + # Convert UNIX timestamp to timezone-aware datetime + self.expiry_time = datetime.fromtimestamp(self._token.expires_on, tz=timezone.utc) + self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}") diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py b/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py new file mode 100644 index 0000000..a23cac9 --- /dev/null +++ b/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import grpc +from azure.core.credentials import TokenCredential + +from durabletask.azuremanaged.internal.access_token_manager import \ + AccessTokenManager +from durabletask.internal.grpc_interceptor import ( + DefaultClientInterceptorImpl, _ClientCallDetails) + + +class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): + """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, + StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an + interceptor to add additional headers to all calls as needed.""" + + def __init__(self, token_credential: TokenCredential, taskhub_name: str): + self._metadata = [("taskhub", taskhub_name)] + super().__init__(self._metadata) + + if token_credential is not None: + self._token_credential = token_credential + self._token_manager = AccessTokenManager(token_credential=self._token_credential) + access_token = self._token_manager.get_access_token() + if access_token is not None: + self._metadata.append(("authorization", f"Bearer {access_token.token}")) + + def _intercept_call( + self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: + """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC + call details.""" + # Refresh the auth token if it is present and needed + if self._metadata is not None: + for i, (key, _) in enumerate(self._metadata): + if key.lower() == "authorization": # Ensure case-insensitive comparison + new_token = self._token_manager.get_access_token() # Get the new token + if new_token is not None: + self._metadata[i] = ("authorization", f"Bearer {new_token.token}") # Update the token + + return super()._intercept_call(client_call_details) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py new file mode 100644 index 0000000..d10c2f7 --- /dev/null +++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from azure.core.credentials import TokenCredential + +from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \ + DTSDefaultClientInterceptorImpl +from durabletask.worker import TaskHubGrpcWorker + + +# Worker class used for Durable Task Scheduler (DTS) +class DurableTaskSchedulerWorker(TaskHubGrpcWorker): + def __init__(self, *, + host_address: str, + taskhub: str, + token_credential: TokenCredential, + secure_channel: bool = True): + + if not taskhub: + raise ValueError("The taskhub value cannot be empty.") + + interceptors = [DTSDefaultClientInterceptorImpl(token_credential, taskhub)] + + # We pass in None for the metadata so we don't construct an additional interceptor in the parent class + # Since the parent class doesn't use anything metadata for anything else, we can set it as None + super().__init__( + host_address=host_address, + secure_channel=secure_channel, + metadata=None, + interceptors=interceptors) diff --git a/durabletask-azuremanaged/pyproject.toml b/durabletask-azuremanaged/pyproject.toml new file mode 100644 index 0000000..ac6be6f --- /dev/null +++ b/durabletask-azuremanaged/pyproject.toml @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# For more information on pyproject.toml, see https://peps.python.org/pep-0621/ + +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "durabletask.azuremanaged" +version = "0.1b1" +description = "Extensions for the Durable Task Python SDK for integrating with the Durable Task Scheduler in Azure" +keywords = [ + "durable", + "task", + "workflow", + "azure" +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", +] +requires-python = ">=3.9" +license = {file = "LICENSE"} +readme = "README.md" +dependencies = [ + "durabletask>=0.2.0", + "azure-identity>=1.19.0" +] + +[project.urls] +repository = "https://github.com/microsoft/durabletask-python" +changelog = "https://github.com/microsoft/durabletask-python/blob/main/CHANGELOG.md" + +[tool.setuptools.packages.find] +include = ["durabletask.azuremanaged", "durabletask.azuremanaged.*"] + +[tool.pytest.ini_options] +minversion = "6.0" diff --git a/durabletask/client.py b/durabletask/client.py index 31953ae..60e194f 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Optional, TypeVar, Union +from typing import Any, Optional, Sequence, TypeVar, Union import grpc from google.protobuf import wrappers_pb2 @@ -16,6 +16,7 @@ import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared from durabletask import task +from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') @@ -96,8 +97,25 @@ def __init__(self, *, metadata: Optional[list[tuple[str, str]]] = None, log_handler: Optional[logging.Handler] = None, log_formatter: Optional[logging.Formatter] = None, - secure_channel: bool = False): - channel = shared.get_grpc_channel(host_address, metadata, secure_channel=secure_channel) + secure_channel: bool = False, + interceptors: Optional[Sequence[shared.ClientInterceptor]] = None): + + # If the caller provided metadata, we need to create a new interceptor for it and + # add it to the list of interceptors. + if interceptors is not None: + interceptors = list(interceptors) + if metadata is not None: + interceptors.append(DefaultClientInterceptorImpl(metadata)) + elif metadata is not None: + interceptors = [DefaultClientInterceptorImpl(metadata)] + else: + interceptors = None + + channel = shared.get_grpc_channel( + host_address=host_address, + secure_channel=secure_channel, + interceptors=interceptors + ) self._stub = stubs.TaskHubSidecarServiceStub(channel) self._logger = shared.get_logger("client", log_handler, log_formatter) @@ -116,7 +134,7 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None, version=wrappers_pb2.StringValue(value=""), orchestrationIdReusePolicy=reuse_id_policy, - ) + ) self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.") res: pb.CreateInstanceResponse = self._stub.StartInstance(req) diff --git a/durabletask/internal/grpc_interceptor.py b/durabletask/internal/grpc_interceptor.py index 738fca9..69db3c5 100644 --- a/durabletask/internal/grpc_interceptor.py +++ b/durabletask/internal/grpc_interceptor.py @@ -19,10 +19,10 @@ class _ClientCallDetails( class DefaultClientInterceptorImpl ( - grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): + grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an + StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an interceptor to add additional headers to all calls as needed.""" def __init__(self, metadata: list[tuple[str, str]]): @@ -30,17 +30,17 @@ def __init__(self, metadata: list[tuple[str, str]]): self._metadata = metadata def _intercept_call( - self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: + self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC call details.""" if self._metadata is None: return client_call_details - + if client_call_details.metadata is not None: metadata = list(client_call_details.metadata) else: metadata = [] - + metadata.extend(self._metadata) client_call_details = _ClientCallDetails( client_call_details.method, client_call_details.timeout, metadata, diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index c4f3aa4..1872ad4 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -5,11 +5,16 @@ import json import logging from types import SimpleNamespace -from typing import Any, Optional +from typing import Any, Optional, Sequence, Union import grpc -from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl +ClientInterceptor = Union[ + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor +] # Field name used to indicate that an object was automatically serialized # and should be deserialized as a SimpleNamespace @@ -25,8 +30,9 @@ def get_default_host_address() -> str: def get_grpc_channel( host_address: Optional[str], - metadata: Optional[list[tuple[str, str]]], - secure_channel: bool = False) -> grpc.Channel: + secure_channel: bool = False, + interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc.Channel: + if host_address is None: host_address = get_default_host_address() @@ -44,16 +50,18 @@ def get_grpc_channel( host_address = host_address[len(protocol):] break + # Create the base channel if secure_channel: channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials()) else: channel = grpc.insecure_channel(host_address) - if metadata is not None and len(metadata) > 0: - interceptors = [DefaultClientInterceptorImpl(metadata)] + # Apply interceptors ONLY if they exist + if interceptors: channel = grpc.intercept_channel(channel, *interceptors) return channel + def get_logger( name_suffix: str, log_handler: Optional[logging.Handler] = None, @@ -98,7 +106,7 @@ def default(self, obj): if dataclasses.is_dataclass(obj): # Dataclasses are not serializable by default, so we convert them to a dict and mark them for # automatic deserialization by the receiver - d = dataclasses.asdict(obj) # type: ignore + d = dataclasses.asdict(obj) # type: ignore d[AUTO_SERIALIZED] = True return d elif isinstance(obj, SimpleNamespace): diff --git a/durabletask/task.py b/durabletask/task.py index a40602b..9e8a08a 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -277,6 +277,7 @@ def get_tasks(self) -> list[Task]: def on_child_completed(self, task: Task[T]): pass + class WhenAllTask(CompositeTask[list[T]]): """A task that completes when all of its child tasks complete.""" @@ -333,7 +334,7 @@ class RetryableTask(CompletableTask[T]): """A task that can be retried according to a retry policy.""" def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction, - start_time:datetime, is_sub_orch: bool) -> None: + start_time: datetime, is_sub_orch: bool) -> None: super().__init__() self._action = action self._retry_policy = retry_policy @@ -343,7 +344,7 @@ def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction, def increment_attempt_count(self) -> None: self._attempt_count += 1 - + def compute_next_delay(self) -> Optional[timedelta]: if self._attempt_count >= self._retry_policy.max_number_of_attempts: return None @@ -351,7 +352,7 @@ def compute_next_delay(self) -> Optional[timedelta]: retry_expiration: datetime = datetime.max if self._retry_policy.retry_timeout is not None and self._retry_policy.retry_timeout != datetime.max: retry_expiration = self._start_time + self._retry_policy.retry_timeout - + if self._retry_policy.backoff_coefficient is None: backoff_coefficient = 1.0 else: diff --git a/durabletask/worker.py b/durabletask/worker.py index 75e2e37..2c31e52 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -9,7 +9,7 @@ from typing import Any, Generator, Optional, Sequence, TypeVar, Union import grpc -from google.protobuf import empty_pb2, wrappers_pb2 +from google.protobuf import empty_pb2 import durabletask.internal.helpers as ph import durabletask.internal.helpers as pbh @@ -17,6 +17,7 @@ import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared from durabletask import task +from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') @@ -82,21 +83,32 @@ class ActivityNotRegisteredError(ValueError): class TaskHubGrpcWorker: _response_stream: Optional[grpc.Future] = None + _interceptors: Optional[list[shared.ClientInterceptor]] = None def __init__(self, *, host_address: Optional[str] = None, metadata: Optional[list[tuple[str, str]]] = None, log_handler=None, log_formatter: Optional[logging.Formatter] = None, - secure_channel: bool = False): + secure_channel: bool = False, + interceptors: Optional[Sequence[shared.ClientInterceptor]] = None): self._registry = _Registry() self._host_address = host_address if host_address else shared.get_default_host_address() - self._metadata = metadata self._logger = shared.get_logger("worker", log_handler, log_formatter) self._shutdown = Event() self._is_running = False self._secure_channel = secure_channel + # Determine the interceptors to use + if interceptors is not None: + self._interceptors = list(interceptors) + if metadata: + self._interceptors.append(DefaultClientInterceptorImpl(metadata)) + elif metadata: + self._interceptors = [DefaultClientInterceptorImpl(metadata)] + else: + self._interceptors = None + def __enter__(self): return self @@ -117,7 +129,7 @@ def add_activity(self, fn: task.Activity) -> str: def start(self): """Starts the worker on a background thread and begins listening for work items.""" - channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel) + channel = shared.get_grpc_channel(self._host_address, self._secure_channel, self._interceptors) stub = stubs.TaskHubSidecarServiceStub(channel) if self._is_running: @@ -143,9 +155,11 @@ def run_loop(): request_type = work_item.WhichOneof('request') self._logger.debug(f'Received "{request_type}" work item') if work_item.HasField('orchestratorRequest'): - executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub) + executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken) elif work_item.HasField('activityRequest'): - executor.submit(self._execute_activity, work_item.activityRequest, stub) + executor.submit(self._execute_activity, work_item.activityRequest, stub, work_item.completionToken) + elif work_item.HasField('healthPing'): + pass # no-op else: self._logger.warning(f'Unexpected work item type: {request_type}') @@ -184,26 +198,27 @@ def stop(self): self._logger.info("Worker shutdown completed") self._is_running = False - def _execute_orchestrator(self, req: pb.OrchestratorRequest, stub: stubs.TaskHubSidecarServiceStub): + def _execute_orchestrator(self, req: pb.OrchestratorRequest, stub: stubs.TaskHubSidecarServiceStub, completionToken): try: executor = _OrchestrationExecutor(self._registry, self._logger) result = executor.execute(req.instanceId, req.pastEvents, req.newEvents) res = pb.OrchestratorResponse( instanceId=req.instanceId, actions=result.actions, - customStatus=pbh.get_string_value(result.encoded_custom_status)) + customStatus=pbh.get_string_value(result.encoded_custom_status), + completionToken=completionToken) except Exception as ex: self._logger.exception(f"An error occurred while trying to execute instance '{req.instanceId}': {ex}") failure_details = pbh.new_failure_details(ex) actions = [pbh.new_complete_orchestration_action(-1, pb.ORCHESTRATION_STATUS_FAILED, "", failure_details)] - res = pb.OrchestratorResponse(instanceId=req.instanceId, actions=actions) + res = pb.OrchestratorResponse(instanceId=req.instanceId, actions=actions, completionToken=completionToken) try: stub.CompleteOrchestratorTask(res) except Exception as ex: self._logger.exception(f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}") - def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarServiceStub): + def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarServiceStub, completionToken): instance_id = req.orchestrationInstance.instanceId try: executor = _ActivityExecutor(self._registry, self._logger) @@ -211,12 +226,14 @@ def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarS res = pb.ActivityResponse( instanceId=instance_id, taskId=req.taskId, - result=pbh.get_string_value(result)) + result=pbh.get_string_value(result), + completionToken=completionToken) except Exception as ex: res = pb.ActivityResponse( instanceId=instance_id, taskId=req.taskId, - failureDetails=pbh.new_failure_details(ex)) + failureDetails=pbh.new_failure_details(ex), + completionToken=completionToken) try: stub.CompleteActivityTask(res) @@ -471,6 +488,7 @@ def __init__(self, actions: list[pb.OrchestratorAction], encoded_custom_status: self.actions = actions self.encoded_custom_status = encoded_custom_status + class _OrchestrationExecutor: _generator: Optional[task.Orchestrator] = None diff --git a/examples/README.md b/examples/README.md index ec9088f..7cfbc7a 100644 --- a/examples/README.md +++ b/examples/README.md @@ -8,7 +8,7 @@ All the examples assume that you have a Durable Task-compatible sidecar running 1. Install the latest version of the [Dapr CLI](https://docs.dapr.io/getting-started/install-dapr-cli/), which contains and exposes an embedded version of the Durable Task engine. The setup process (which requires Docker) will configure the workflow engine to store state in a local Redis container. -1. Clone and run the [Durable Task Sidecar](https://github.com/microsoft/durabletask-go) project locally (requires Go 1.18 or higher). Orchestration state will be stored in a local sqlite database. +2. Clone and run the [Durable Task Sidecar](https://github.com/microsoft/durabletask-go) project locally (requires Go 1.18 or higher). Orchestration state will be stored in a local sqlite database. ## Running the examples diff --git a/examples/dts/README.md b/examples/dts/README.md new file mode 100644 index 0000000..9b4a3fd --- /dev/null +++ b/examples/dts/README.md @@ -0,0 +1,55 @@ +# Examples + +This directory contains examples of how to author durable orchestrations using the Durable Task Python SDK in conjunction with the Durable Task Scheduler (DTS). Please note that the installation instructions provided below will use the version of DTS directly from the your branch rather than installing through PyPI. + +## Prerequisites + +All the examples assume that you have a Durable Task Scheduler taskhub created. + +The simplest way to create a taskhub is by using the az cli commands: + +1. Create a scheduler: + az durabletask scheduler create --resource-group --name --location --ip-allowlist "[0.0.0.0/0]" --sku-capacity 1 --sku-name "Dedicated" --tags "{}" + +1. Create your taskhub + + ```bash + az durabletask taskhub create --resource-group --scheduler-name --name + ``` + +1. Retrieve the endpoint for the scheduler. This can be done by locating the taskhub in the portal. + +1. Set the appropriate environment variables for the TASKHUB and ENDPOINT + + ```bash + export TASKHUB= + export ENDPOINT= + ``` + +1. Since the samples rely on azure identity, ensure the package is installed and up-to-date + + ```bash + python3 -m pip install azure-identity + ``` + +1. Install the correct packages from the top level of this repository, i.e. durabletask-python/ + + ```bash + python3 -m pip install . + ``` + +1. Install the DTS specific packages from the durabletask-python/durabletask-azuremanaged directory + + ```bash + pip3 install -e . + ``` + +1. Grant yourself the `Durable Task Data Contributor` role over your scheduler + +## Running the examples + +Now, you can simply execute any of the examples in this directory using `python3`: + +```sh +python3 dts_activity_sequence.py +``` diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py new file mode 100644 index 0000000..2ff3c22 --- /dev/null +++ b/examples/dts/dts_activity_sequence.py @@ -0,0 +1,71 @@ +"""End-to-end sample that demonstrates how to configure an orchestrator +that calls an activity function in a sequence and prints the outputs.""" +import os + +from azure.identity import DefaultAzureCredential + +from durabletask import client, task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + + +def hello(ctx: task.ActivityContext, name: str) -> str: + """Activity function that returns a greeting""" + return f'Hello {name}!' + + +def sequence(ctx: task.OrchestrationContext, _): + """Orchestrator function that calls the 'hello' activity function in a sequence""" + # call "hello" activity function in a sequence + result1 = yield ctx.call_activity(hello, input='Tokyo') + result2 = yield ctx.call_activity(hello, input='Seattle') + result3 = yield ctx.call_activity(hello, input='London') + + # return an array of results + return [result1, result2, result3] + + +# Read the environment variable +taskhub_name = os.getenv("TASKHUB") + +# Check if the variable exists +if taskhub_name: + print(f"The value of TASKHUB is: {taskhub_name}") +else: + print("TASKHUB is not set. Please set the TASKHUB environment variable to the name of the taskhub you wish to use") + print("If you are using windows powershell, run the following: $env:TASKHUB=\"\"") + print("If you are using bash, run the following: export TASKHUB=\"\"") + exit() + +# Read the environment variable +endpoint = os.getenv("ENDPOINT") + +# Check if the variable exists +if endpoint: + print(f"The value of ENDPOINT is: {endpoint}") +else: + print("ENDPOINT is not set. Please set the ENDPOINT environment variable to the endpoint of the scheduler") + print("If you are using windows powershell, run the following: $env:ENDPOINT=\"\"") + print("If you are using bash, run the following: export ENDPOINT=\"\"") + exit() + +# Note that any azure-identity credential type and configuration can be used here as DTS supports various credential +# types such as Managed Identities +credential = DefaultAzureCredential() + +# configure and start the worker +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=credential) as w: + w.add_orchestrator(sequence) + w.add_activity(hello) + w.start() + + # Construct the client and run the orchestrations + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=credential) + instance_id = c.schedule_new_orchestration(sequence) + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! Result: {state.serialized_output}') + elif state: + print(f'Orchestration failed: {state.failure_details}') diff --git a/examples/dts/dts_fanout_fanin.py b/examples/dts/dts_fanout_fanin.py new file mode 100644 index 0000000..8ab68df --- /dev/null +++ b/examples/dts/dts_fanout_fanin.py @@ -0,0 +1,96 @@ +"""End-to-end sample that demonstrates how to configure an orchestrator +that a dynamic number activity functions in parallel, waits for them all +to complete, and prints an aggregate summary of the outputs.""" +import os +import random +import time + +from azure.identity import DefaultAzureCredential + +from durabletask import client, task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + + +def get_work_items(ctx: task.ActivityContext, _) -> list[str]: + """Activity function that returns a list of work items""" + # return a random number of work items + count = random.randint(2, 10) + print(f'generating {count} work items...') + return [f'work item {i}' for i in range(count)] + + +def process_work_item(ctx: task.ActivityContext, item: str) -> int: + """Activity function that returns a result for a given work item""" + print(f'processing work item: {item}') + + # simulate some work that takes a variable amount of time + time.sleep(random.random() * 5) + + # return a result for the given work item, which is also a random number in this case + return random.randint(0, 10) + + +def orchestrator(ctx: task.OrchestrationContext, _): + """Orchestrator function that calls the 'get_work_items' and 'process_work_item' + activity functions in parallel, waits for them all to complete, and prints + an aggregate summary of the outputs""" + + work_items: list[str] = yield ctx.call_activity(get_work_items) + + # execute the work-items in parallel and wait for them all to return + tasks = [ctx.call_activity(process_work_item, input=item) for item in work_items] + results: list[int] = yield task.when_all(tasks) + + # return an aggregate summary of the results + return { + 'work_items': work_items, + 'results': results, + 'total': sum(results), + } + + +# Read the environment variable +taskhub_name = os.getenv("TASKHUB") + +# Check if the variable exists +if taskhub_name: + print(f"The value of TASKHUB is: {taskhub_name}") +else: + print("TASKHUB is not set. Please set the TASKHUB environment variable to the name of the taskhub you wish to use") + print("If you are using windows powershell, run the following: $env:TASKHUB=\"\"") + print("If you are using bash, run the following: export TASKHUB=\"\"") + exit() + +# Read the environment variable +endpoint = os.getenv("ENDPOINT") + +# Check if the variable exists +if endpoint: + print(f"The value of ENDPOINT is: {endpoint}") +else: + print("ENDPOINT is not set. Please set the ENDPOINT environment variable to the endpoint of the scheduler") + print("If you are using windows powershell, run the following: $env:ENDPOINT=\"\"") + print("If you are using bash, run the following: export ENDPOINT=\"\"") + exit() + +credential = DefaultAzureCredential() + +# configure and start the worker +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=credential) as w: + w.add_orchestrator(orchestrator) + w.add_activity(process_work_item) + w.add_activity(get_work_items) + w.start() + + # create a client, start an orchestration, and wait for it to finish + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=credential) + instance_id = c.schedule_new_orchestration(orchestrator) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! Result: {state.serialized_output}') + elif state: + print(f'Orchestration failed: {state.failure_details}') + exit() diff --git a/requirements.txt b/requirements.txt index a31419b..0da7d46 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,5 @@ grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newe protobuf pytest pytest-cov +azure-core +azure-identity \ No newline at end of file diff --git a/tests/test_client.py b/tests/test_client.py index caacf65..64bbec8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,36 +1,36 @@ from unittest.mock import patch, ANY -from durabletask.internal.shared import (DefaultClientInterceptorImpl, - get_default_host_address, +from durabletask.internal.shared import (get_default_host_address, get_grpc_channel) +from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl HOST_ADDRESS = 'localhost:50051' METADATA = [('key1', 'value1'), ('key2', 'value2')] - +INTERCEPTORS = [DefaultClientInterceptorImpl(METADATA)] def test_get_grpc_channel_insecure(): with patch('grpc.insecure_channel') as mock_channel: - get_grpc_channel(HOST_ADDRESS, METADATA, False) + get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) mock_channel.assert_called_once_with(HOST_ADDRESS) def test_get_grpc_channel_secure(): with patch('grpc.secure_channel') as mock_channel, patch( 'grpc.ssl_channel_credentials') as mock_credentials: - get_grpc_channel(HOST_ADDRESS, METADATA, True) + get_grpc_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS) mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value) def test_get_grpc_channel_default_host_address(): with patch('grpc.insecure_channel') as mock_channel: - get_grpc_channel(None, METADATA, False) + get_grpc_channel(None, False, interceptors=INTERCEPTORS) mock_channel.assert_called_once_with(get_default_host_address()) def test_get_grpc_channel_with_metadata(): with patch('grpc.insecure_channel') as mock_channel, patch( 'grpc.intercept_channel') as mock_intercept_channel: - get_grpc_channel(HOST_ADDRESS, METADATA, False) + get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) mock_channel.assert_called_once_with(HOST_ADDRESS) mock_intercept_channel.assert_called_once() @@ -48,41 +48,41 @@ def test_grpc_channel_with_host_name_protocol_stripping(): host_name = "myserver.com:1234" prefix = "grpc://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_insecure_channel.assert_called_with(host_name) prefix = "http://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_insecure_channel.assert_called_with(host_name) prefix = "HTTP://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_insecure_channel.assert_called_with(host_name) prefix = "GRPC://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_insecure_channel.assert_called_with(host_name) prefix = "" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_insecure_channel.assert_called_with(host_name) prefix = "grpcs://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_secure_channel.assert_called_with(host_name, ANY) prefix = "https://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_secure_channel.assert_called_with(host_name, ANY) prefix = "HTTPS://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_secure_channel.assert_called_with(host_name, ANY) prefix = "GRPCS://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_secure_channel.assert_called_with(host_name, ANY) prefix = "" - get_grpc_channel(prefix + host_name, METADATA, True) + get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS) mock_secure_channel.assert_called_with(host_name, ANY) \ No newline at end of file