Skip to content

Commit

Permalink
Add support for generate
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger committed Apr 17, 2023
1 parent 5d89334 commit c483bdf
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,8 @@ def _update_model_kwargs_for_generation(
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
if getattr(outputs, "state", None) is not None:
model_kwargs["state"] = outputs.state

# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
Expand Down
14 changes: 14 additions & 0 deletions src/transformers/models/rwkv/modeling_rwkv.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,20 @@ def get_output_embeddings(self):
def set_output_embeddings(self, new_embeddings):
    """Replace the output projection stored on ``self.head``.

    Args:
        new_embeddings: Object to store as ``self.head``. It is assigned
            as-is; no copy or validation is performed here.
    """
    self.head = new_embeddings

def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
    """Build the keyword inputs for one generation step.

    Args:
        input_ids: Token ids indexed as ``[:, -1]`` below, so at least
            two-dimensional (presumably ``(batch, sequence)`` -- confirm
            against the caller).
        state: Value carried over from a previous step, or ``None`` on the
            first step. It is forwarded untouched under the ``"state"`` key.
        inputs_embeds: Optional embeddings to use instead of ``input_ids``.
            They are only used while ``state`` is ``None``.
        **kwargs: Accepted for call compatibility and ignored.

    Returns:
        A dict holding ``"state"`` plus exactly one of ``"inputs_embeds"``
        or ``"input_ids"``.
    """
    # Only the last token of input_ids is kept when a state is passed along.
    if state is not None:
        input_ids = input_ids[:, -1].unsqueeze(-1)

    # If `inputs_embeds` are passed, we only want to use them in the 1st
    # generation step, i.e. while no state exists yet.
    if inputs_embeds is not None and state is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs["state"] = state
    return model_inputs

@add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
Expand Down

0 comments on commit c483bdf

Please sign in to comment.