33import uuid
44from datetime import datetime , timedelta , timezone
55from time import monotonic
6- from typing import Optional , Union
6+ from typing import Any , List , Optional , Union
77
88import pytest
99
1010from temporalio import activity , workflow
11- from temporalio .client import Client , WorkflowFailureError
11+ from temporalio .client import (
12+ Client ,
13+ Interceptor ,
14+ OutboundInterceptor ,
15+ StartWorkflowInput ,
16+ WorkflowFailureError ,
17+ WorkflowHandle ,
18+ )
1219from temporalio .common import RetryPolicy
1320from temporalio .exceptions import (
1421 ActivityError ,
@@ -176,7 +183,36 @@ def some_signal(self) -> None:
176183 assert "foo" == "bar"
177184
178185
186+ class SimpleClientInterceptor (Interceptor ):
187+ def __init__ (self ) -> None :
188+ self .events : List [str ] = []
189+
190+ def intercept_client (self , next : OutboundInterceptor ) -> OutboundInterceptor :
191+ return SimpleClientOutboundInterceptor (self , super ().intercept_client (next ))
192+
193+
194+ class SimpleClientOutboundInterceptor (OutboundInterceptor ):
195+ def __init__ (
196+ self , root : SimpleClientInterceptor , next : OutboundInterceptor
197+ ) -> None :
198+ super ().__init__ (next )
199+ self .root = root
200+
201+ async def start_workflow (
202+ self , input : StartWorkflowInput
203+ ) -> WorkflowHandle [Any , Any ]:
204+ self .root .events .append (f"start: { input .workflow } " )
205+ return await super ().start_workflow (input )
206+
207+
179208async def test_workflow_env_assert (client : Client ):
209+ # Set the interceptor on the client. This used to fail for being
210+ # accidentally overridden.
211+ client_config = client .config ()
212+ interceptor = SimpleClientInterceptor ()
213+ client_config ["interceptors" ] = [interceptor ]
214+ client = Client (** client_config )
215+
180216 def assert_proper_error (err : Optional [BaseException ]) -> None :
181217 assert isinstance (err , ApplicationError )
182218 # In unsandboxed workflows, this message has extra diff info appended
@@ -195,6 +231,7 @@ def assert_proper_error(err: Optional[BaseException]) -> None:
195231 task_queue = worker .task_queue ,
196232 )
197233 assert_proper_error (err .value .cause )
234+ assert interceptor .events
198235
199236 # Start a new one and check signal
200237 handle = await env .client .start_workflow (
0 commit comments