|
7 | 7 |
|
8 | 8 | from dataclasses import dataclass |
9 | 9 | from typing import ( |
10 | | - TYPE_CHECKING, |
11 | 10 | Awaitable, |
12 | 11 | Callable, |
13 | 12 | List, |
14 | | - Mapping, |
15 | 13 | MutableSequence, |
16 | 14 | Optional, |
17 | 15 | Sequence, |
|
20 | 18 | Union, |
21 | 19 | ) |
22 | 20 |
|
23 | | -import google.protobuf.internal.containers |
24 | 21 | from typing_extensions import TypeAlias |
25 | 22 |
|
26 | 23 | import temporalio.api.common.v1 |
|
35 | 32 | import temporalio.bridge.temporal_sdk_bridge |
36 | 33 | import temporalio.converter |
37 | 34 | import temporalio.exceptions |
38 | | -from temporalio.api.common.v1.message_pb2 import Payload, Payloads |
39 | | -from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions |
| 35 | +from temporalio.api.common.v1.message_pb2 import Payload |
| 36 | +from temporalio.bridge._visitor import VisitorFunctions |
40 | 37 | from temporalio.bridge.temporal_sdk_bridge import ( |
41 | 38 | CustomSlotSupplier as BridgeCustomSlotSupplier, |
42 | 39 | ) |
43 | 40 | from temporalio.bridge.temporal_sdk_bridge import PollShutdownError # type: ignore |
| 41 | +from temporalio.worker._command_aware_visitor import CommandAwarePayloadVisitor |
44 | 42 |
|
45 | 43 |
|
46 | 44 | @dataclass |
@@ -299,22 +297,22 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: |
299 | 297 |
|
300 | 298 |
|
301 | 299 | async def decode_activation( |
302 | | - act: temporalio.bridge.proto.workflow_activation.WorkflowActivation, |
| 300 | + activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation, |
303 | 301 | codec: temporalio.converter.PayloadCodec, |
304 | 302 | decode_headers: bool, |
305 | 303 | ) -> None: |
306 | | - """Decode the given activation with the codec.""" |
307 | | - await PayloadVisitor( |
| 304 | + """Decode all payloads in the activation.""" |
| 305 | + await CommandAwarePayloadVisitor( |
308 | 306 | skip_search_attributes=True, skip_headers=not decode_headers |
309 | | - ).visit(_Visitor(codec.decode), act) |
| 307 | + ).visit(_Visitor(codec.decode), activation) |
310 | 308 |
|
311 | 309 |
|
312 | 310 | async def encode_completion( |
313 | | - comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion, |
| 311 | + completion: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion, |
314 | 312 | codec: temporalio.converter.PayloadCodec, |
315 | 313 | encode_headers: bool, |
316 | 314 | ) -> None: |
317 | | - """Recursively encode the given completion with the codec.""" |
318 | | - await PayloadVisitor( |
| 315 | + """Encode all payloads in the completion.""" |
| 316 | + await CommandAwarePayloadVisitor( |
319 | 317 | skip_search_attributes=True, skip_headers=not encode_headers |
320 | | - ).visit(_Visitor(codec.encode), comp) |
| 318 | + ).visit(_Visitor(codec.encode), completion) |
0 commit comments