Update our sampler documentation to reflect usage #1444

Merged · 1 commit · Feb 20, 2024
59 changes: 9 additions & 50 deletions keras_nlp/samplers/beam_sampler.py
@@ -18,11 +18,8 @@
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import ops
from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.sampler import call_args_docstring
from keras_nlp.utils.python_utils import format_docstring


@format_docstring(call_args=call_args_docstring)
@keras_nlp_export("keras_nlp.samplers.BeamSampler")
class BeamSampler(Sampler):
"""Beam Sampler class.
@@ -42,55 +39,17 @@ class BeamSampler(Sampler):
{{call_args}}

Examples:
Return only the beam with the highest accumulated probability.
```python
# Use a simple alphabet of lowercase characters with ids in range [0, 25].
int_lookup = {i: chr(i + ord('a')) for i in range(26)}
char_lookup = {v: k for k, v in int_lookup.items()}
batch_size, length, vocab_size = 1, 12, len(int_lookup)

def next(prompt, cache, index):
    prompt_batch_size = tf.shape(prompt)[0]
    hidden_states = np.ones((prompt_batch_size, 10))
    # A uniform distribution over our alphabet.
    logits = np.ones((prompt_batch_size, vocab_size))
    return logits, hidden_states, cache

output = keras_nlp.samplers.BeamSampler()(
    next=next,
    prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"),
    index=5,
)
print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
# >>> ['zzzzzeeeeeee']
```

Return all beams and their probabilities.
```python
# Use a simple alphabet of lowercase characters with ids in range [0, 25].
int_lookup = {i: chr(i + ord('a')) for i in range(26)}
char_lookup = {v: k for k, v in int_lookup.items()}
batch_size, length, vocab_size = 1, 8, len(int_lookup)

def next(prompt, cache, index):
    prompt_batch_size = tf.shape(prompt)[0]
    hidden_states = np.ones((prompt_batch_size, 10))
    # A uniform distribution over our alphabet.
    logits = np.ones((prompt_batch_size, vocab_size))
    return logits, hidden_states, cache

beams, probs = keras_nlp.samplers.BeamSampler(return_all_beams=True)(
    next=next,
    prompt=np.full((batch_size, length), char_lookup['z'], dtype="int32"),
    index=5,
)

print(beams.shape)
# >>> (1, 5, 8)
print(probs.shape)
# >>> (1, 5)
print(["".join([int_lookup[i] for i in s]) for s in beams[0].numpy()])
# >>> ['zzzzzeee', 'zzzzzeed', 'zzzzzeec', 'zzzzzeea', 'zzzzzeeb']
```

```python
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="beam")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_nlp.samplers.BeamSampler(num_beams=5)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])
```
"""

35 changes: 10 additions & 25 deletions keras_nlp/samplers/contrastive_sampler.py
@@ -17,11 +17,8 @@
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import ops
from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.sampler import call_args_docstring
from keras_nlp.utils.python_utils import format_docstring


@format_docstring(call_args=call_args_docstring)
@keras_nlp_export("keras_nlp.samplers.ContrastiveSampler")
class ContrastiveSampler(Sampler):
"""Contrastive Sampler class.
@@ -44,28 +41,16 @@ class ContrastiveSampler(Sampler):

Examples:
```python
# Use a simple alphabet of lowercase characters with ids in range [0, 26).
int_lookup = {i: chr(i + ord("a")) for i in range(26)}
char_lookup = {v: k for k, v in int_lookup.items()}
batch_size, length, vocab_size = 1, 12, len(int_lookup)
hidden_size = 5
index = 5

def next(prompt, cache, index):
    prompt_batch_size = tf.shape(prompt)[0]
    hidden_states = np.ones((prompt_batch_size, hidden_size))
    # A uniform distribution over our alphabet.
    logits = np.ones((prompt_batch_size, vocab_size))
    return logits, hidden_states, cache

output = keras_nlp.samplers.ContrastiveSampler()(
    next=next,
    prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"),
    index=index,
    hidden_states=np.ones([batch_size, index, hidden_size]),
)
print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
# >>> "zzzzzeeeeeee"
```

```python
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="contrastive")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_nlp.samplers.ContrastiveSampler(k=5)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])
```
"""

34 changes: 10 additions & 24 deletions keras_nlp/samplers/greedy_sampler.py
@@ -15,41 +15,27 @@
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import ops
from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.sampler import call_args_docstring
from keras_nlp.utils.python_utils import format_docstring


@format_docstring(call_args=call_args_docstring)
@keras_nlp_export("keras_nlp.samplers.GreedySampler")
class GreedySampler(Sampler):
"""Greedy sampler class.

This sampler implements greedy search, i.e., it always picks the token
with the highest probability as the next token.

Call arguments:
{{call_args}}

Examples:
```python
# Use a simple alphabet of lowercase characters with ids in range [0, 25].
int_lookup = {i: chr(i + ord('a')) for i in range(26)}
char_lookup = {v: k for k, v in int_lookup.items()}
batch_size, length, vocab_size = 1, 12, len(int_lookup)

def next(prompt, cache, index):
    hidden_states = np.ones((batch_size, 10))
    # A uniform distribution over our alphabet.
    logits = np.ones((batch_size, vocab_size))
    return logits, hidden_states, cache

output = keras_nlp.samplers.GreedySampler()(
    next=next,
    prompt=np.full((batch_size, length), char_lookup['z'], dtype="int32"),
    index=5,
)
print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
# >>> ['zzzzzaaaaaaa']
```

```python
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="greedy")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_nlp.samplers.GreedySampler()
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])
```
"""

27 changes: 8 additions & 19 deletions keras_nlp/samplers/random_sampler.py
@@ -16,11 +16,8 @@
from keras_nlp.backend import ops
from keras_nlp.backend import random
from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.sampler import call_args_docstring
from keras_nlp.utils.python_utils import format_docstring


@format_docstring(call_args=call_args_docstring)
@keras_nlp_export("keras_nlp.samplers.RandomSampler")
class RandomSampler(Sampler):
"""Random Sampler class.
@@ -37,24 +34,16 @@ class RandomSampler(Sampler):

Examples:
```python
# Use a simple alphabet of lowercase characters with ids in range [0, 25].
int_lookup = {i: chr(i + ord('a')) for i in range(26)}
char_lookup = {v: k for k, v in int_lookup.items()}
batch_size, length, vocab_size = 1, 12, len(int_lookup)

def next(prompt, state, index):
    hidden_states = np.ones((batch_size, 10))
    # A uniform distribution over our alphabet.
    logits = np.ones((batch_size, vocab_size))
    return logits, hidden_states, state

output = keras_nlp.samplers.RandomSampler()(
    next=next,
    prompt=np.full((batch_size, length), char_lookup['z'], dtype="int32"),
    index=5,
)
print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
# >>> ['zzzzzcpnjqij']
```

```python
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="random")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_nlp.samplers.RandomSampler(temperature=0.7)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])
```
"""

76 changes: 24 additions & 52 deletions keras_nlp/samplers/sampler.py
@@ -17,33 +17,8 @@
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.backend import random
from keras_nlp.utils.python_utils import format_docstring

call_args_docstring = """next: A function which takes in the
`prompt, cache, index` of the current generation loop, and outputs
a tuple `(logits, hidden_states, cache)` with `logits` being the
logits of the next token, `hidden_states` being the representation of
the next token, and `cache` for the next iteration.
prompt: A 2D integer tensor with shape `(batch_size, max_length)`. This
tensor will be iteratively updated column by column with new sampled
values, starting at `index`.
cache: Optional. A tensor or nested structure of tensors that will be
updated by each call to `next`. This can be used to cache
computations from early iterations of the generative loop.
index: Optional. The first index of `prompt` to start sampling at.
Usually this is set as the length of the shortest non-padded
sequence in `prompt`.
mask: Optional. A 2D integer tensor with the same shape as `prompt`.
Locations which are `True` in the mask are never updated during
sampling. Usually used to mark all locations in the dense prompt
tensor which were present in a user input.
end_token_id: Optional. The token marking the end of the sequence. If
specified, sampling will stop as soon as all sequences in the prompt
produce an `end_token_id` in a location where `mask` is `False`.
"""


@format_docstring(call_args=call_args_docstring)


@keras_nlp_export("keras_nlp.samplers.Sampler")
class Sampler:
"""Base sampler class.
@@ -57,35 +32,32 @@ class Sampler:
{{call_args}}

This base class can be extended to implement different auto-regressive
sampling methods. Subclasses can either:

- Override the `get_next_token()` method, which computes the next token
based on a probability distribution over all possible vocab entries.
- Override `__call__`, if the sampling method needs additional information
beyond the next token's probability distribution to sample a sequence.

Please check available subclass samplers for examples.
sampling methods. To do so, override the `get_next_token()` method, which
computes the next token based on a probability distribution over all
possible vocab entries.

Examples:

```python
# Use a simple alphabet of lowercase characters with ids in range [0, 25].
int_lookup = {i: chr(i + ord('a')) for i in range(26)}
char_lookup = {v: k for k, v in int_lookup.items()}
batch_size, length, vocab_size = 1, 12, len(int_lookup)

def next(prompt, cache, index):
    # return a uniform distribution over our alphabet.
    logits = ops.ones((batch_size, vocab_size))
    return logits, None, cache

output = keras_nlp.samplers.GreedySampler()(
    next=next,
    prompt=ops.fill((batch_size, length), char_lookup['z']),
    index=5,
)
print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
# >>> ['zzzzzaaaaaaa']
```

```python
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Greedy search with some tokens forbidden.
class CustomSampler(keras_nlp.samplers.Sampler):
    def __init__(self, forbidden_tokens, **kwargs):
        super().__init__(**kwargs)
        self.forbidden_tokens = forbidden_tokens

    def get_next_token(self, probs):
        batch_size, vocab_size = keras.ops.shape(probs)
        for id in self.forbidden_tokens:
            update = keras.ops.zeros((batch_size, 1))
            probs = keras.ops.slice_update(probs, (0, id), update)
        return keras.ops.argmax(probs, axis=-1)

# 257 = "a" with a leading space, 262 = "the" with a leading space.
causal_lm.compile(sampler=CustomSampler(forbidden_tokens=[257, 262]))
causal_lm.summary()
causal_lm.generate(["That's strange"])
```
"""

27 changes: 8 additions & 19 deletions keras_nlp/samplers/top_k_sampler.py
@@ -16,11 +16,8 @@
from keras_nlp.backend import ops
from keras_nlp.backend import random
from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.sampler import call_args_docstring
from keras_nlp.utils.python_utils import format_docstring


@format_docstring(call_args=call_args_docstring)
@keras_nlp_export("keras_nlp.samplers.TopKSampler")
class TopKSampler(Sampler):
"""Top-K Sampler class.
@@ -38,24 +35,16 @@ class TopKSampler(Sampler):

Examples:
```python
# Use a simple alphabet of lowercase characters with ids in range [0, 25].
int_lookup = {i: chr(i + ord('a')) for i in range(26)}
char_lookup = {v: k for k, v in int_lookup.items()}
batch_size, length, vocab_size = 1, 12, len(int_lookup)

def next(prompt, cache, index):
    hidden_states = np.ones((batch_size, 10))
    # A uniform distribution over our alphabet.
    logits = np.ones((batch_size, vocab_size))
    return logits, hidden_states, cache

output = keras_nlp.samplers.TopKSampler(k=3)(
    next=next,
    prompt=np.full((batch_size, length), char_lookup['z'], dtype="int32"),
    index=5,
)
print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
# >>> ['zzzzzacbbcaa']
```

```python
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="top_k")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_nlp.samplers.TopKSampler(k=5, temperature=0.7)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])
```
"""
