3434
3535PostFix , _ = optional_import ("monai.utils.enums" , name = "PostFix" ) # For the default meta_key_postfix
3636first , _ = optional_import ("monai.utils.misc" , name = "first" )
37+ ensure_tuple , _ = optional_import ("monai.utils" , name = "ensure_tuple" )
3738Compose_ , _ = optional_import ("monai.transforms" , name = "Compose" )
3839ConfigParser_ , _ = optional_import ("monai.bundle" , name = "ConfigParser" )
3940MapTransform_ , _ = optional_import ("monai.transforms" , name = "MapTransform" )
@@ -236,7 +237,7 @@ def __init__(
236237 output_mapping : List [IOMapping ],
237238 model_name : Optional [str ] = "" ,
238239 bundle_path : Optional [str ] = None ,
239- bundle_config_names : BundleConfigNames = None ,
240+ bundle_config_names : Optional [ BundleConfigNames ] = None ,
240241 * args ,
241242 ** kwargs ,
242243 ):
@@ -261,9 +262,9 @@ def __init__(
261262 self ._input_mapping = input_mapping
262263 self ._output_mapping = output_mapping
263264
264- self ._parser = None # Needs known bundle path, either on init or when compute function is called.
265- self ._inferer = None # Will be set during bundle parsing.
266- self ._init_completed = False
265+ self ._parser : ConfigParser = None # Needs known bundle path, either on init or when compute function is called.
266+ self ._inferer : Any = None # Will be set during bundle parsing.
267+ self ._init_completed : bool = False
267268
268269 # Need to set the operator's input(s) and output(s). Even when the bundle parsing is done in init,
269270 # there is still a need to define what op inputs/outputs map to what keys in the bundle config,
@@ -289,6 +290,9 @@ def __init__(
289290 logging .warn ("Bundle parsing is not completed on init, delayed till this operator is called to execute." )
290291 self ._bundle_path = None
291292
293+ # Lazy init of model network till execution time when the context is fully set.
294+ self ._model_network : Any = None
295+
292296 @property
293297 def model_name (self ) -> str :
294298 return self ._model_name
@@ -390,7 +394,7 @@ def _get_meta_key_postfix(self, compose: Compose, key_name: str = "meta_key_post
390394 post_fix = post_fix [0 ]
391395 break
392396
393- return post_fix
397+ return str ( post_fix )
394398
395399 def _get_io_data_type (self , conf ):
396400 """
@@ -441,28 +445,32 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
441445
442446 # Try to get the Model object and its path from the context.
443447 # If operator is not fully initialized, use model path as bundle path to finish it.
444- # If Model not loaded, but bundle path exists, load model, just in case.
448+ # If Model not loaded, but bundle path exists, load model; edge case for local dev .
445449 #
446450 # `context.models.get(model_name)` returns a model instance if exists.
447451 # If model_name is not specified and only one model exists, it returns that model.
448- model = context .models .get (self ._model_name ) if context .models else None
449- if model :
452+
453+ self ._model_network = context .models .get (self ._model_name ) if context .models else None
454+ if self ._model_network :
450455 if not self ._init_completed :
451456 with self ._lock :
452457 if not self ._init_completed :
453- self ._bundle_path = model .path
458+ self ._bundle_path = self . _model_network .path
454459 self ._init_config (self ._bundle_config_names .config_names )
455460 self ._init_completed
456461 elif self ._bundle_path :
462+ # For the case of local dev/testing when the bundle path is not passed in as an exec cmd arg.
463+ # When run as a MAP docker, the bundle file is expected to be in the context, even if the model
464+ # network is loaded on a remote inference server (when the feature is introduced).
457465 logging .debug (f"Model network not loaded. Trying to load from model path: { self ._bundle_path } " )
458- model = torch .jit .load (self .bundle_path , map_location = self ._device ).eval ()
466+ self . _model_network = torch .jit .load (self .bundle_path , map_location = self ._device ).eval ()
459467 else :
460468 raise IOError ("Model network is not load and model file not found." )
461469
462470 first_input_name , * other_names = list (self ._inputs .keys ())
463471
464472 with torch .no_grad ():
465- inputs = {}
473+ inputs : Any = {} # Use type Any to quiet MyPy type checking complaints.
466474
467475 start = time .time ()
468476 for name in self ._inputs .keys ():
@@ -482,13 +490,13 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
482490 logging .debug (f"Ingest and Pre-processing elapsed time (seconds): { time .time () - start } " )
483491
484492 start = time .time ()
485- outputs = self .predict (data = first_input , network = model , ** other_inputs )
493+ outputs : Any = self .predict (data = first_input , ** other_inputs ) # Use type Any to quiet MyPy complaints.
486494 logging .debug (f"Inference elapsed time (seconds): { time .time () - start } " )
487495
488496 # TODO: Does this work for models where multiple outputs are returned?
489497 # Note that the inputs are needed because the invert transform requires it.
490498 start = time .time ()
491- outputs = self .post_process (outputs [0 ], inputs )
499+ outputs = self .post_process (ensure_tuple ( outputs ) [0 ], preprocessed_inputs = inputs )
492500 logging .debug (f"Post-processing elapsed time (seconds): { time .time () - start } " )
493501 if isinstance (outputs , (tuple , list )):
494502 output_dict = dict (zip (self ._outputs .keys (), outputs ))
@@ -502,19 +510,27 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
502510 # Please see the comments in the called function for the reasons.
503511 self ._send_output (output_dict [name ], name , input_metadata , op_output , context )
504512
505- def predict (self , data : Any , network : Any , * args , ** kwargs ) -> Union [Image , Any ]:
513+ def predict (self , data : Any , * args , ** kwargs ) -> Union [Image , Any , Tuple [ Any , ...], Dict [ Any , Any ] ]:
506514 """Predicts output using the inferer."""
507- return self ._inferer (inputs = data , network = network , * args , ** kwargs )
508515
509- def pre_process (self , data : Any ) -> Union [Image , Any ]:
516+ return self ._inferer (inputs = data , network = self ._model_network , * args , ** kwargs )
517+
518+ def pre_process (self , data : Any , * args , ** kwargs ) -> Union [Image , Any , Tuple [Any , ...], Dict [Any , Any ]]:
510519 """Processes the input dictionary with the stored transform sequence `self._preproc`."""
511520
512521 if is_map_compose (self ._preproc ):
513522 return self ._preproc (data )
514523 return {k : self ._preproc (v ) for k , v in data .items ()}
515524
516- def post_process (self , data : Any , inputs : Dict ) -> Union [Image , Any ]:
517- """Processes the output list/dictionary with the stored transform sequence `self._postproc`."""
525+ def post_process (self , data : Any , * args , ** kwargs ) -> Union [Image , Any , Tuple [Any , ...], Dict [Any , Any ]]:
526+ """Processes the output list/dictionary with the stored transform sequence `self._postproc`.
527+
528+ The "processed_inputs", in fact the metadata in it, need to be passed in so that the
529+ invertible transforms in the post processing can work properly.
530+ """
531+
532+ # Expect the inputs be passed in so that the inversion can work.
533+ inputs = kwargs .get ("preprocessed_inputs" , {})
518534
519535 if is_map_compose (self ._postproc ):
520536 if isinstance (data , (list , tuple )):
@@ -585,7 +601,7 @@ def _receive_input(self, name: str, op_input: InputContext, context: ExecutionCo
585601
586602 return value , metadata
587603
588- def _send_output (self , value , name : str , metadata : Dict , op_output : OutputContext , context : ExecutionContext ):
604+ def _send_output (self , value : Any , name : str , metadata : Dict , op_output : OutputContext , context : ExecutionContext ):
589605 """Send the given output value to the output context."""
590606
591607 logging .debug (f"Setting output { name } " )
@@ -610,7 +626,7 @@ def _send_output(self, value, name: str, metadata: Dict, op_output: OutputContex
610626 raise TypeError ("arg 1 must be of type torch.Tensor or ndarray." )
611627
612628 logging .debug (f"Output { name } numpy image shape: { value .shape } " )
613- result = Image (np .swapaxes (np .squeeze (value , 0 ), 0 , 2 ).astype (np .uint8 ), metadata = metadata )
629+ result : Any = Image (np .swapaxes (np .squeeze (value , 0 ), 0 , 2 ).astype (np .uint8 ), metadata = metadata )
614630 logging .debug (f"Converted Image shape: { result .asnumpy ().shape } " )
615631 elif otype == np .ndarray :
616632 result = np .asarray (value )
0 commit comments