Skip to content

Commit

Permalink
docs update
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw committed Feb 14, 2024
1 parent d19adad commit f96b5f2
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions keras_nlp/models/causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,43 @@ def __init__(self, *args, **kwargs):
self.generate_function = None

def build_cache(self, batch_size, max_length):
    """Build an empty key/value cache for use with `call_with_cache`.

    Subclasses must override this to allocate a cache tensor sized for
    the model's attention layers.

    Args:
        batch_size: int. The size of the batch for generation.
        max_length: int. The maximum sequence length for the cache.

    Returns:
        A cache Tensor; the exact shape depends on the model.

    Raises:
        NotImplementedError: always, on this base class.
    """
    # Abstract on the base class; concrete models allocate the cache.
    raise NotImplementedError

def call_with_cache(self, token_ids, cache, index):
    """Forward pass with cache for generation.

    `call_with_cache` adds an additional forward pass for the model for
    autoregressive inference. Unlike calling the model directly, this
    method allows caching previous key/value results in multi-head
    attention layers, and avoids recomputing the outputs of seen tokens.

    Args:
        token_ids: a dense int Tensor with shape `(batch_size, n)`,
            where `n` is some sequence length less than or equal to the
            max length of the cache. Usually `n` is either the full
            cache length, to "prefill" the prompt cache values, or `1`,
            to predict a single token id.
        cache: a dense float Tensor. The cache of key and value
            projections used in the attention layers of the model. The
            exact shape will depend on the model.
        index: int, or int Tensor. The index of the first token of
            `token_ids` in the entire generated sequence.

    Returns:
        A `(logits, hidden_states, cache)` tuple, where `logits` is the
        language model logits for the input `token_ids`,
        `hidden_states` is the final hidden representation of the input
        tokens, and `cache` is the updated decoding cache.

    Raises:
        NotImplementedError: always, on this base class.
    """
    # Abstract on the base class; concrete models implement the cached
    # attention forward pass.
    raise NotImplementedError

def compile(
Expand Down

0 comments on commit f96b5f2

Please sign in to comment.