File tree Expand file tree Collapse file tree 1 file changed +4
-5
lines changed Expand file tree Collapse file tree 1 file changed +4
-5
lines changed Original file line number Diff line number Diff line change @@ -388,9 +388,9 @@ def process_inputs(
388388
389389 eos_token_id = self .input_preprocessor .get_eos_token_id ()
390390
391- self ._validate_model_inputs (processed_inputs )
392-
393391 encoder_inputs , decoder_inputs = split_enc_dec_inputs (processed_inputs )
392+ self ._validate_model_inputs (encoder_inputs , decoder_inputs )
393+
394394 # Mypy does not always properly infer the types of some elements of
395395 # discriminated unions of TypedDicts, because of how it handles
396396 # inheritance of TypedDict. If we explicitly extract the items we want
@@ -458,9 +458,8 @@ def process_inputs(
458458 trace_headers = trace_headers ,
459459 )
460460
461- def _validate_model_inputs (self , inputs : ProcessorInputs ):
462- encoder_inputs , decoder_inputs = split_enc_dec_inputs (inputs )
463-
461+ def _validate_model_inputs (self , encoder_inputs : Optional [SingletonInputs ],
462+ decoder_inputs : SingletonInputs ):
464463 if encoder_inputs is not None :
465464 self ._validate_model_input (encoder_inputs , prompt_type = "encoder" )
466465
You can’t perform that action at this time.
0 commit comments