
Commit 7321905

add non-retryable error, shutdown helpers, additional tests
Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com>
1 parent 7f89f6a commit 7321905

13 files changed: +871 −173 lines changed

README.md

Lines changed: 54 additions & 2 deletions
@@ -126,10 +126,62 @@ Orchestrations can be continued as new using the `continue_as_new` API. This API

Orchestrations can be suspended using the `suspend_orchestration` client API and will remain suspended until resumed using the `resume_orchestration` client API. A suspended orchestration will stop processing new events, but will continue to buffer any that happen to arrive until resumed, ensuring that no data is lost. An orchestration can also be terminated using the `terminate_orchestration` client API. Terminated orchestrations will stop processing new events and will discard any buffered events.

-### Retry policies (TODO)
+### Retry policies

Orchestrations can specify retry policies for activities and sub-orchestrations. These policies control how many times and how frequently an activity or sub-orchestration will be retried in the event of a transient error.

+#### Creating a retry policy
+
+```python
+from datetime import timedelta
+from durabletask import task
+
+retry_policy = task.RetryPolicy(
+    first_retry_interval=timedelta(seconds=1),  # Initial delay before first retry
+    max_number_of_attempts=5,  # Maximum total attempts (includes first attempt)
+    backoff_coefficient=2.0,  # Exponential backoff multiplier (must be >= 1)
+    max_retry_interval=timedelta(seconds=30),  # Cap on retry delay
+    retry_timeout=timedelta(minutes=5),  # Total time limit for all retries (optional)
+)
+```
+
+**Notes:**
+- `max_number_of_attempts` **includes the initial attempt**. For example, `max_number_of_attempts=5` means 1 initial attempt + up to 4 retries.
+- `retry_timeout` is optional. If omitted or set to `None`, retries continue until `max_number_of_attempts` is reached.
+- `backoff_coefficient` controls exponential backoff: delay = `first_retry_interval * (backoff_coefficient ^ retry_number)`, capped by `max_retry_interval`.
+- `non_retryable_error_types` (optional) can specify additional exception types to treat as non-retryable (e.g., `[ValueError, TypeError]`). `NonRetryableError` is always non-retryable regardless of this setting.
+
+#### Using retry policies
+
+Apply retry policies to activities or sub-orchestrations:
+
+```python
+def my_orchestrator(ctx: task.OrchestrationContext, input):
+    # Retry an activity
+    result = yield ctx.call_activity(my_activity, input=data, retry_policy=retry_policy)
+
+    # Retry a sub-orchestration
+    result = yield ctx.call_sub_orchestrator(child_orchestrator, input=data, retry_policy=retry_policy)
+```
+
+#### Non-retryable errors
+
+For errors that should not be retried (e.g., validation failures, permanent errors), raise a `NonRetryableError`:
+
+```python
+from durabletask.task import NonRetryableError
+
+def my_activity(ctx: task.ActivityContext, input):
+    if input is None:
+        # This error will bypass retry logic and fail immediately
+        raise NonRetryableError("Input cannot be None")
+
+    # Transient errors (network, timeouts, etc.) will be retried
+    return call_external_service(input)
+```
+
+Even with a retry policy configured, `NonRetryableError` will fail immediately without retrying.
+
## Getting Started

### Prerequisites
@@ -194,7 +246,7 @@ Certain aspects like multi-app activities require the full dapr runtime to be ru
```shell
dapr init || true

-dapr run --app-id test-app --dapr-grpc-port 4001 --components-path ./examples/components/
+dapr run --app-id test-app --dapr-grpc-port 4001 --resources-path ./examples/components/
```

To run the E2E tests on a specific python version (eg: 3.11), run the following command from the project root:
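
Stepping outside the diff for a moment: the backoff formula documented in the new README notes can be illustrated with a short, self-contained sketch (not part of this commit) that applies `first_retry_interval * (backoff_coefficient ^ retry_number)`, capped by `max_retry_interval`, to the example policy:

```python
from datetime import timedelta

# Illustrative only: compute the delays the README's example policy would produce.
first_retry_interval = timedelta(seconds=1)
backoff_coefficient = 2.0
max_retry_interval = timedelta(seconds=30)
max_number_of_attempts = 5  # 1 initial attempt + up to 4 retries

for retry_number in range(max_number_of_attempts - 1):
    delay = first_retry_interval.total_seconds() * (backoff_coefficient ** retry_number)
    delay = min(delay, max_retry_interval.total_seconds())
    print(f"retry #{retry_number + 1}: wait {delay:g}s")
# Prints: retry #1: 1s, #2: 2s, #3: 4s, #4: 8s (the 30s cap is never reached here).
```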

dev-requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -1 +0,0 @@
-grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python # supports protobuf 6.x and aligns with generated code

durabletask/client.py

Lines changed: 72 additions & 4 deletions
@@ -127,9 +127,28 @@ def __init__(
            interceptors=interceptors,
            options=channel_options,
        )
+        self._channel = channel
        self._stub = stubs.TaskHubSidecarServiceStub(channel)
        self._logger = shared.get_logger("client", log_handler, log_formatter)

+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc, tb):
+        try:
+            self.close()
+        finally:
+            return False
+
+    def close(self) -> None:
+        """Close the underlying gRPC channel."""
+        try:
+            # grpc.Channel.close() is idempotent
+            self._channel.close()
+        except Exception:
+            # Best-effort cleanup
+            pass
+
    def schedule_new_orchestration(
        self,
        orchestrator: Union[task.Orchestrator[TInput, TOutput], str],
@@ -188,10 +207,59 @@ def wait_for_orchestration_completion(
    ) -> Optional[OrchestrationState]:
        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
        try:
-            grpc_timeout = None if timeout == 0 else timeout
-            self._logger.info(
-                f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete."
-            )
+            # gRPC timeout mapping (pytest unit tests may pass None explicitly)
+            grpc_timeout = None if (timeout is None or timeout == 0) else timeout
+
+            # If timeout is None or 0, skip pre-checks/polling and call server-side wait directly
+            if timeout is None or timeout == 0:
+                self._logger.info(
+                    f"Waiting {'indefinitely' if not timeout else f'up to {timeout}s'} for instance '{instance_id}' to complete."
+                )
+                res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(
+                    req, timeout=grpc_timeout
+                )
+                state = new_orchestration_state(req.instanceId, res)
+                return state
+
+            # For positive timeout, best-effort pre-check and short polling to avoid long server waits
+            try:
+                # First check if the orchestration is already completed
+                current_state = self.get_orchestration_state(
+                    instance_id, fetch_payloads=fetch_payloads
+                )
+                if current_state and current_state.runtime_status in [
+                    OrchestrationStatus.COMPLETED,
+                    OrchestrationStatus.FAILED,
+                    OrchestrationStatus.TERMINATED,
+                ]:
+                    return current_state
+
+                # Poll for completion with exponential backoff to handle eventual consistency
+                import time
+
+                poll_timeout = min(timeout, 10)
+                poll_start = time.time()
+                poll_interval = 0.1
+
+                while time.time() - poll_start < poll_timeout:
+                    current_state = self.get_orchestration_state(
+                        instance_id, fetch_payloads=fetch_payloads
+                    )
+
+                    if current_state and current_state.runtime_status in [
+                        OrchestrationStatus.COMPLETED,
+                        OrchestrationStatus.FAILED,
+                        OrchestrationStatus.TERMINATED,
+                    ]:
+                        return current_state
+
+                    time.sleep(poll_interval)
+                    poll_interval = min(poll_interval * 1.5, 1.0) # Exponential backoff, max 1s
+            except Exception:
+                # Ignore pre-check/poll issues (e.g., mocked stubs in unit tests) and fall back
+                pass
+
+            self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to complete.")
            res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(
                req, timeout=grpc_timeout
            )
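
The shutdown helpers and the reworked wait logic above can be exercised together. A hedged usage sketch follows (not part of the commit; `TaskHubGrpcClient`, `host_address`, and the orchestrator name are assumptions based on the existing client API):

```python
from durabletask import client

# The context-manager support added above closes the gRPC channel on exit (__exit__ -> close()).
with client.TaskHubGrpcClient(host_address="localhost:4001") as c:
    instance_id = c.schedule_new_orchestration("my_orchestrator", input={"value": 42})

    # timeout=None or timeout=0: skip the client-side pre-check/polling and wait server-side.
    # timeout > 0: best-effort pre-check and short polling (up to ~10s), then fall back to the
    # server-side wait bounded by the gRPC timeout.
    state = c.wait_for_orchestration_completion(instance_id, timeout=60)
    if state:
        print(state.runtime_status, state.serialized_output)
```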

durabletask/internal/shared.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def get_logger(
    # Add a default log handler if none is provided
    if log_handler is None:
        log_handler = logging.StreamHandler()
-        log_handler.setLevel(logging.INFO)
+        log_handler.setLevel(logging.DEBUG)
    logger.handlers.append(log_handler)

    # Set a default log formatter to our handler if none is provided
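
This raises the default console handler from INFO to DEBUG. A caller that wants quieter output can keep supplying its own handler; a small sketch (not part of the commit), assuming the client forwards a `log_handler` keyword through to `shared.get_logger` as its `__init__` shown earlier does:

```python
import logging

from durabletask import client

# Pass an explicit INFO-level handler so the DEBUG default created by
# get_logger is never installed.
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)

c = client.TaskHubGrpcClient(host_address="localhost:4001", log_handler=handler)
```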

durabletask/task.py

Lines changed: 37 additions & 1 deletion
@@ -233,6 +233,16 @@ class OrchestrationStateError(Exception):
    pass


+class NonRetryableError(Exception):
+    """Exception indicating the operation should not be retried.
+
+    If an activity or sub-orchestration raises this exception, retry logic will be
+    bypassed and the failure will be returned immediately to the orchestrator.
+    """
+
+    pass
+
+
class Task(ABC, Generic[T]):
    """Abstract base class for asynchronous tasks in a durable orchestration."""

@@ -397,7 +407,7 @@ def compute_next_delay(self) -> Optional[timedelta]:
                next_delay_f = min(
                    next_delay_f, self._retry_policy.max_retry_interval.total_seconds()
                )
-                return timedelta(seconds=next_delay_f)
+            return timedelta(seconds=next_delay_f)

        return None

@@ -490,6 +500,7 @@ def __init__(
        backoff_coefficient: Optional[float] = 1.0,
        max_retry_interval: Optional[timedelta] = None,
        retry_timeout: Optional[timedelta] = None,
+        non_retryable_error_types: Optional[list[Union[str, type]]] = None,
    ):
        """Creates a new RetryPolicy instance.

@@ -505,6 +516,11 @@ def __init__(
            The maximum retry interval to use for any retry attempt.
        retry_timeout : Optional[timedelta]
            The maximum amount of time to spend retrying the operation.
+        non_retryable_error_types : Optional[list[Union[str, type]]]
+            A list of exception type names or classes that should not be retried.
+            If a failure's error type matches any of these, the task fails immediately.
+            The built-in NonRetryableError is always treated as non-retryable regardless
+            of this setting.
        """
        # validate inputs
        if first_retry_interval < timedelta(seconds=0):
@@ -523,6 +539,17 @@ def __init__(
        self._backoff_coefficient = backoff_coefficient
        self._max_retry_interval = max_retry_interval
        self._retry_timeout = retry_timeout
+        # Normalize non-retryable error type names to a set of strings
+        names: Optional[set[str]] = None
+        if non_retryable_error_types:
+            names = set()
+            for t in non_retryable_error_types:
+                if isinstance(t, str):
+                    if t:
+                        names.add(t)
+                elif isinstance(t, type):
+                    names.add(t.__name__)
+        self._non_retryable_error_types = names

    @property
    def first_retry_interval(self) -> timedelta:
@@ -549,6 +576,15 @@ def retry_timeout(self) -> Optional[timedelta]:
        """The maximum amount of time to spend retrying the operation."""
        return self._retry_timeout

+    @property
+    def non_retryable_error_types(self) -> Optional[set[str]]:
+        """Set of error type names that should not be retried.
+
+        Comparison is performed against the errorType string provided by the
+        backend (typically the exception class name).
+        """
+        return self._non_retryable_error_types
+

def get_name(fn: Callable) -> str:
    """Returns the name of the provided function"""
