@@ -53,7 +53,7 @@ class Application(ABC):
5353 # This is needed to identify Application class across different environments (e.g. by `runpy.run_path()`).
5454 _class_id : str = "monai.application"
5555
56- _env : "ApplicationEnv" = None
56+ _env : Optional [ "ApplicationEnv" ] = None
5757
5858 def __init__ (
5959 self ,
@@ -179,7 +179,7 @@ def add_operator(self, operator: Operator):
179179 self ._graph .add_operator (operator )
180180
181181 def add_flow (
182- self , upstream_op : Operator , downstream_op : Operator , io_map : Optional [Dict [str , Union [str , Set ]]] = None
182+ self , upstream_op : Operator , downstream_op : Operator , io_map : Optional [Dict [str , Union [str , Set [ str ] ]]] = None
183183 ):
184184 """Adds a flow from upstream to downstream.
185185
@@ -189,8 +189,8 @@ def add_flow(
189189 Args:
190190 upstream_op (Operator): An instance of the upstream operator of type Operator.
191191 downstream_op (Operator): An instance of the downstream operator of type Operator.
192- io_map (Optional[Dict[str, str]] ): A dictionary of mapping from the source operator's label to the
193- destination operator's label.
192+ io_map (Optional[Dict[str, Union[ str, Set[str]]]] ): A dictionary of mapping from the source operator's label
193+ to the destination operator's label(s) .
194194 """
195195
196196 # Ensure that the upstream and downstream operators are valid
@@ -213,12 +213,13 @@ def add_flow(
213213 io_map = {"" : {"" }}
214214
215215 # Convert io_map's values to the set of strings.
216+ io_maps : Dict [str , Set [str ]] = io_map # type: ignore
216217 for k , v in io_map .items ():
217218 if isinstance (v , str ):
218- io_map [k ] = {v }
219+ io_maps [k ] = {v }
219220
220221 # Verify that the upstream & downstream operator have the input and output ports specified by the io_map
221- output_labels = list (io_map .keys ())
222+ output_labels = list (io_maps .keys ())
222223
223224 if len (op_output_labels ) == 1 and len (output_labels ) != 1 :
224225 raise IOMappingError (
@@ -231,17 +232,17 @@ def add_flow(
231232 if output_label not in op_output_labels :
232233 if len (op_output_labels ) == 1 and len (output_labels ) == 1 and output_label == "" :
233234 # Set the default output port label.
234- io_map [next (iter (op_output_labels ))] = io_map [output_label ]
235- del io_map [output_label ]
235+ io_maps [next (iter (op_output_labels ))] = io_maps [output_label ]
236+ del io_maps [output_label ]
236237 break
237238 raise IOMappingError (
238239 f"The upstream operator({ upstream_op .name } ) has no output port with label '{ output_label } '. "
239240 f"It should be one of ({ ', ' .join (op_output_labels )} )."
240241 )
241242
242- output_labels = list (io_map .keys ()) # re-evaluate output_labels
243+ output_labels = list (io_maps .keys ()) # re-evaluate output_labels
243244 for output_label in output_labels :
244- input_labels = io_map [output_label ]
245+ input_labels = io_maps [output_label ]
245246
246247 if len (op_input_labels ) == 1 and len (input_labels ) != 1 :
247248 raise IOMappingError (
@@ -262,7 +263,7 @@ def add_flow(
262263 f"It should be one of ({ ', ' .join (op_input_labels )} )."
263264 )
264265
265- self ._graph .add_flow (upstream_op , downstream_op , io_map )
266+ self ._graph .add_flow (upstream_op , downstream_op , io_maps )
266267
267268 def get_package_info (self , model_path : Union [str , Path ] = "" ) -> Dict :
268269 """Returns the package information of this application.
0 commit comments