@@ -96,12 +96,16 @@ def sampler_output(
9696 # TODO: Remove this branch once DraftModelRunner supports TP>1
9797 # and other restrictions that are part of DraftModelRunner's
9898 # supports_gpu_multi_step(..)
99+ if expanded_request .previous_hidden_states is not None :
100+ self .worker .model_runner .return_hidden_states = True
99101 for _ in range (sample_len ):
100102 model_output : List [SamplerOutput ] = self .worker .execute_model (
101103 execute_model_req = expanded_request )
102104 assert (len (model_output ) == 1
103105 ), "composing multistep workers not supported"
104106 model_output = model_output [0 ]
107+ self ._maybe_update_previous_hidden_states (
108+ model_output , expanded_request )
105109
106110 self ._append_new_tokens (
107111 model_output , expanded_request .seq_group_metadata_list ,
@@ -115,6 +119,19 @@ def sampler_output(
115119 model_outputs , indices_of_seq_with_bonus_tokens )
116120 return filtered_model_outputs , True
117121
122+ @staticmethod
123+ def _maybe_update_previous_hidden_states (
124+ model_output : SamplerOutput ,
125+ expanded_request : ExecuteModelRequest ) -> None :
126+ """
127+ Updates the previous hidden states in an expanded request
128+ in-place with the hidden states from the model output.
129+ """
130+ if expanded_request .previous_hidden_states is not None :
131+ expanded_request .previous_hidden_states = HiddenStates (
132+ model_output .hidden_states ,
133+ expanded_request .seq_group_metadata_list )
134+
118135 @staticmethod
119136 def _expand_execute_model_request (
120137 execute_model_req : ExecuteModelRequest ,
0 commit comments