66from dataclasses import dataclass
77from datetime import datetime
88from enum import Enum
9- from typing import Any , List , Tuple , TypeVar , Union
9+ from typing import Any , Optional , TypeVar , Union
1010
1111import grpc
1212from google .protobuf import wrappers_pb2
@@ -42,10 +42,10 @@ class OrchestrationState:
4242 runtime_status : OrchestrationStatus
4343 created_at : datetime
4444 last_updated_at : datetime
45- serialized_input : Union [str , None ]
46- serialized_output : Union [str , None ]
47- serialized_custom_status : Union [str , None ]
48- failure_details : Union [task .FailureDetails , None ]
45+ serialized_input : Optional [str ]
46+ serialized_output : Optional [str ]
47+ serialized_custom_status : Optional [str ]
48+ failure_details : Optional [task .FailureDetails ]
4949
5050 def raise_if_failed (self ):
5151 if self .failure_details is not None :
@@ -64,7 +64,7 @@ def failure_details(self):
6464 return self ._failure_details
6565
6666
67- def new_orchestration_state (instance_id : str , res : pb .GetInstanceResponse ) -> Union [OrchestrationState , None ]:
67+ def new_orchestration_state (instance_id : str , res : pb .GetInstanceResponse ) -> Optional [OrchestrationState ]:
6868 if not res .exists :
6969 return None
7070
@@ -92,20 +92,20 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Un
9292class TaskHubGrpcClient :
9393
9494 def __init__ (self , * ,
95- host_address : Union [str , None ] = None ,
96- metadata : Union [ List [ Tuple [str , str ]], None ] = None ,
97- log_handler = None ,
98- log_formatter : Union [logging .Formatter , None ] = None ,
95+ host_address : Optional [str ] = None ,
96+ metadata : Optional [ list [ tuple [str , str ]]] = None ,
97+ log_handler : Optional [ logging . Handler ] = None ,
98+ log_formatter : Optional [logging .Formatter ] = None ,
9999 secure_channel : bool = False ):
100100 channel = shared .get_grpc_channel (host_address , metadata , secure_channel = secure_channel )
101101 self ._stub = stubs .TaskHubSidecarServiceStub (channel )
102102 self ._logger = shared .get_logger ("client" , log_handler , log_formatter )
103103
104104 def schedule_new_orchestration (self , orchestrator : Union [task .Orchestrator [TInput , TOutput ], str ], * ,
105- input : Union [TInput , None ] = None ,
106- instance_id : Union [str , None ] = None ,
107- start_at : Union [datetime , None ] = None ,
108- reuse_id_policy : Union [pb .OrchestrationIdReusePolicy , None ] = None ) -> str :
105+ input : Optional [TInput ] = None ,
106+ instance_id : Optional [str ] = None ,
107+ start_at : Optional [datetime ] = None ,
108+ reuse_id_policy : Optional [pb .OrchestrationIdReusePolicy ] = None ) -> str :
109109
110110 name = orchestrator if isinstance (orchestrator , str ) else task .get_name (orchestrator )
111111
@@ -122,14 +122,14 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
122122 res : pb .CreateInstanceResponse = self ._stub .StartInstance (req )
123123 return res .instanceId
124124
125- def get_orchestration_state (self , instance_id : str , * , fetch_payloads : bool = True ) -> Union [OrchestrationState , None ]:
125+ def get_orchestration_state (self , instance_id : str , * , fetch_payloads : bool = True ) -> Optional [OrchestrationState ]:
126126 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
127127 res : pb .GetInstanceResponse = self ._stub .GetInstance (req )
128128 return new_orchestration_state (req .instanceId , res )
129129
130130 def wait_for_orchestration_start (self , instance_id : str , * ,
131131 fetch_payloads : bool = False ,
132- timeout : int = 60 ) -> Union [OrchestrationState , None ]:
132+ timeout : int = 60 ) -> Optional [OrchestrationState ]:
133133 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
134134 try :
135135 self ._logger .info (f"Waiting up to { timeout } s for instance '{ instance_id } ' to start." )
@@ -144,7 +144,7 @@ def wait_for_orchestration_start(self, instance_id: str, *,
144144
145145 def wait_for_orchestration_completion (self , instance_id : str , * ,
146146 fetch_payloads : bool = True ,
147- timeout : int = 60 ) -> Union [OrchestrationState , None ]:
147+ timeout : int = 60 ) -> Optional [OrchestrationState ]:
148148 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
149149 try :
150150 self ._logger .info (f"Waiting { timeout } s for instance '{ instance_id } ' to complete." )
@@ -170,7 +170,7 @@ def wait_for_orchestration_completion(self, instance_id: str, *,
170170 raise
171171
172172 def raise_orchestration_event (self , instance_id : str , event_name : str , * ,
173- data : Union [Any , None ] = None ):
173+ data : Optional [Any ] = None ):
174174 req = pb .RaiseEventRequest (
175175 instanceId = instance_id ,
176176 name = event_name ,
@@ -180,7 +180,7 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *,
180180 self ._stub .RaiseEvent (req )
181181
182182 def terminate_orchestration (self , instance_id : str , * ,
183- output : Union [Any , None ] = None ,
183+ output : Optional [Any ] = None ,
184184 recursive : bool = True ):
185185 req = pb .TerminateRequest (
186186 instanceId = instance_id ,
@@ -203,4 +203,4 @@ def resume_orchestration(self, instance_id: str):
203203 def purge_orchestration (self , instance_id : str , recursive : bool = True ):
204204 req = pb .PurgeInstancesRequest (instanceId = instance_id , recursive = recursive )
205205 self ._logger .info (f"Purging instance '{ instance_id } '." )
206- self ._stub .PurgeInstances ()
206+ self ._stub .PurgeInstances (req )
0 commit comments