Skip to content

Commit

Permalink
Add :no_repeat_ngram_length option to text generation (#182)
Browse files · Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Mar 28, 2023
1 parent 48c26db commit faa5474
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 6 deletions.
6 changes: 4 additions & 2 deletions lib/bumblebee/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ defmodule Bumblebee.Shared do
Keyword.validate!(defaults,
forced_bos_token_id: nil,
forced_eos_token_id: nil,
forced_token_ids: nil
forced_token_ids: nil,
no_repeat_ngram_length: nil
)

for {key, default} <- defaults do
Expand Down Expand Up @@ -113,7 +114,8 @@ defmodule Bumblebee.Shared do
# Generation
forced_bos_token_id: {"forced_bos_token_id", number()},
forced_eos_token_id: {"forced_eos_token_id", number()},
forced_token_ids: {"forced_decoder_ids", list(tuple([number(), number()]))}
forced_token_ids: {"forced_decoder_ids", list(tuple([number(), number()]))},
no_repeat_ngram_length: {"no_repeat_ngram_size", number()}
]

converters =
Expand Down
58 changes: 55 additions & 3 deletions lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ defmodule Bumblebee.Text.Generation do
pad_token_id: Map.get(spec, :pad_token_id),
forced_bos_token_id: Map.get(spec, :forced_bos_token_id),
forced_eos_token_id: Map.get(spec, :forced_eos_token_id),
forced_token_ids: Map.get(spec, :forced_token_ids)
forced_token_ids: Map.get(spec, :forced_token_ids),
no_repeat_ngram_length: Map.get(spec, :no_repeat_ngram_length)
)

decoder_start_token_id = opts[:decoder_start_token_id] || opts[:bos_token_id]
Expand All @@ -164,6 +165,7 @@ defmodule Bumblebee.Text.Generation do
forced_bos_token_id = opts[:forced_bos_token_id]
forced_eos_token_id = opts[:forced_eos_token_id]
forced_token_ids = opts[:forced_token_ids]
no_repeat_ngram_length = opts[:no_repeat_ngram_length]

{max_length_fun, min_length_fun} = lazy_lengths_from_opts(opts)

Expand All @@ -178,7 +180,8 @@ defmodule Bumblebee.Text.Generation do
eos_token_id,
forced_bos_token_id,
forced_eos_token_id,
forced_token_ids
forced_token_ids,
no_repeat_ngram_length
)

&generate_impl(
Expand Down Expand Up @@ -339,9 +342,13 @@ defmodule Bumblebee.Text.Generation do
eos_token_id,
forced_bos_token_id,
forced_eos_token_id,
forced_token_ids
forced_token_ids,
no_repeat_ngram_length
) do
processors = [
if no_repeat_ngram_length do
&no_repeat_ngram_logits_processor(&1, &2, ngram_length: no_repeat_ngram_length)
end,
if min_length_fun && eos_token_id do
&min_length_logits_processor(&1, &2,
min_length_fun: min_length_fun,
Expand Down Expand Up @@ -566,6 +573,51 @@ defmodule Bumblebee.Text.Generation do
end
end

# Logits processor that prevents any n-gram of `:ngram_length` tokens
# from occurring twice in a generated sequence: tokens that would
# complete an already-seen n-gram get their logits forced to -infinity.
# (`defnp` cannot carry a @doc, so the contract is documented here.)
defnp no_repeat_ngram_logits_processor(logits, context, opts \\ []) do
opts = keyword!(opts, [:ngram_length])
ngram_length = opts[:ngram_length]

# The +1 accounts for the token about to be chosen; until the sequence
# is long enough to contain a full n-gram, nothing can repeat yet.
if context.length + 1 < ngram_length do
logits
else
# Given the sequence of the last {ngram_length - 1} tokens, we look
# for prior occurrences of that sequence and we want to make the
# subsequent token ignored. This way the n-gram is not repeated
# this time around

ngram_but_one_length = ngram_length - 1

# The trailing (ngram_length - 1) tokens of every sequence in the
# batch, i.e. the prefix of the n-gram we must not complete.
last_ngram_but_one =
Nx.slice_along_axis(
context.sequences,
context.length - ngram_but_one_length,
ngram_but_one_length,
axis: 1
)

# Slide a window of (ngram_length - 1) tokens over all earlier
# positions; wherever it matches the trailing prefix, ban the token
# that followed that occurrence. All loop-carried tensors must be
# threaded through the while tuple explicitly.
{_, _, _, _, logits} =
while {i = 0, last_ngram_but_one, sequences = context.sequences, length = context.length,
logits},
i + ngram_but_one_length < length do
ngram_but_one = Nx.slice_along_axis(sequences, i, ngram_but_one_length, axis: 1)

batch_size = Nx.axis_size(logits, 0)

# Token that followed this window occurrence, one per batch row;
# indices pairs each row index with that token's column in logits.
token_id = sequences[[.., i + ngram_but_one_length]]
indices = Nx.stack([Nx.iota({batch_size}), token_id], axis: -1)

# match? is true for batch rows where the window equals the
# trailing prefix; adding -infinity there bans the continuation
# while adding 0 leaves non-matching rows untouched.
match? = Nx.all(ngram_but_one == last_ngram_but_one, axes: [1])
updates = Nx.select(match?, Nx.Constants.neg_infinity(), 0)

logits = Nx.indexed_add(logits, indices, updates)

{i + 1, last_ngram_but_one, sequences, length, logits}
end

logits
end
end

defnp force_token_id(logits, opts \\ []) do
token_id = opts[:token_id]

Expand Down
24 changes: 23 additions & 1 deletion test/bumblebee/text/generation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,31 @@ defmodule Bumblebee.Text.GenerationTest do
which were expected to last through at least midday tomorrow.
"""

serving = Bumblebee.Text.generation(model_info, tokenizer, max_new_tokens: 8)
serving =
Bumblebee.Text.generation(model_info, tokenizer,
max_new_tokens: 8,
defn_options: [compiler: EXLA]
)

assert %{results: [%{text: "PG&E scheduled the black"}]} = Nx.Serving.run(serving, article)
end

# Exercises generation with n-gram repetition blocking enabled.
# NOTE(review): loads GPT-2 from the Hugging Face Hub, so this test
# requires network access on first run.
test "with :no_repeat_ngram_length" do
{:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})

# :no_repeat_ngram_length of 2 means no bigram may appear twice in
# the generated sequence.
serving =
Bumblebee.Text.generation(model_info, tokenizer,
max_new_tokens: 12,
no_repeat_ngram_length: 2,
defn_options: [compiler: EXLA]
)

# Without :no_repeat_ngram_length we get
# %{results: [%{text: "I was going to say, 'Well, I'm going to say,"}]}

# With the option, "going to" cannot repeat, so generation diverges
# to "going back" after the first occurrence.
assert %{results: [%{text: "I was going to say, 'Well, I'm going back to the"}]} =
Nx.Serving.run(serving, "I was going")
end
end
end

0 comments on commit faa5474

Please sign in to comment.