From d6c9820e92961294726efc9dad33257ebf707162 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonatan=20K=C5=82osko?=
Date: Tue, 28 Mar 2023 16:04:39 +0200
Subject: [PATCH 1/2] Add :no_repeat_ngram_length option to text generation

---
 lib/bumblebee/shared.ex                 |  6 ++-
 lib/bumblebee/text/generation.ex        | 58 +++++++++++++++++++++++--
 test/bumblebee/text/generation_test.exs | 24 +++++++++-
 3 files changed, 82 insertions(+), 6 deletions(-)

diff --git a/lib/bumblebee/shared.ex b/lib/bumblebee/shared.ex
index c02a00f7..735977cc 100644
--- a/lib/bumblebee/shared.ex
+++ b/lib/bumblebee/shared.ex
@@ -56,7 +56,8 @@ defmodule Bumblebee.Shared do
       Keyword.validate!(defaults,
         forced_bos_token_id: nil,
         forced_eos_token_id: nil,
-        forced_token_ids: nil
+        forced_token_ids: nil,
+        no_repeat_ngram_length: nil
       )
 
     for {key, default} <- defaults do
@@ -113,7 +114,8 @@ defmodule Bumblebee.Shared do
      # Generation
      forced_bos_token_id: {"forced_bos_token_id", number()},
      forced_eos_token_id: {"forced_eos_token_id", number()},
-     forced_token_ids: {"forced_decoder_ids", list(tuple([number(), number()]))}
+     forced_token_ids: {"forced_decoder_ids", list(tuple([number(), number()]))},
+     no_repeat_ngram_length: {"no_repeat_ngram_size", number()}
    ]
 
    converters =
diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 868c91c5..52d6072d 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -155,7 +155,8 @@ defmodule Bumblebee.Text.Generation do
         pad_token_id: Map.get(spec, :pad_token_id),
         forced_bos_token_id: Map.get(spec, :forced_bos_token_id),
         forced_eos_token_id: Map.get(spec, :forced_eos_token_id),
-        forced_token_ids: Map.get(spec, :forced_token_ids)
+        forced_token_ids: Map.get(spec, :forced_token_ids),
+        no_repeat_ngram_length: Map.get(spec, :no_repeat_ngram_length)
       )
 
     decoder_start_token_id = opts[:decoder_start_token_id] || opts[:bos_token_id]
@@ -164,6 +165,7 @@ defmodule Bumblebee.Text.Generation do
     forced_bos_token_id = opts[:forced_bos_token_id]
     forced_eos_token_id = opts[:forced_eos_token_id]
     forced_token_ids = opts[:forced_token_ids]
+    no_repeat_ngram_length = opts[:no_repeat_ngram_length]
 
     {max_length_fun, min_length_fun} = lazy_lengths_from_opts(opts)
 
@@ -178,7 +180,8 @@ defmodule Bumblebee.Text.Generation do
         eos_token_id,
         forced_bos_token_id,
         forced_eos_token_id,
-        forced_token_ids
+        forced_token_ids,
+        no_repeat_ngram_length
       )
 
     &generate_impl(
@@ -339,9 +342,13 @@ defmodule Bumblebee.Text.Generation do
          eos_token_id,
          forced_bos_token_id,
          forced_eos_token_id,
-         forced_token_ids
+         forced_token_ids,
+         no_repeat_ngram_length
        ) do
     processors = [
+      if no_repeat_ngram_length do
+        &no_repeat_ngram_logits_processor(&1, &2, ngram_length: no_repeat_ngram_length)
+      end,
       if min_length_fun && eos_token_id do
         &min_length_logits_processor(&1, &2,
           min_length_fun: min_length_fun,
           eos_token_id: eos_token_id
@@ -566,6 +573,51 @@ defmodule Bumblebee.Text.Generation do
     end
   end
 
+  defnp no_repeat_ngram_logits_processor(logits, context, opts \\ []) do
+    opts = keyword!(opts, [:ngram_length])
+    ngram_length = opts[:ngram_length]
+
+    if context.length + 1 < ngram_length do
+      logits
+    else
+      # Given a sequence of the last {ngram_length - 1} tokens, we look
+      # for prior occurrences of that sequence and suppress the token
+      # that followed each of them. This way the n-gram is not repeated
+      # this time around
+
+      ngram_but_one_length = ngram_length - 1
+
+      last_ngram_but_one =
+        Nx.slice_along_axis(
+          context.sequences,
+          context.length - ngram_but_one_length,
+          ngram_but_one_length,
+          axis: 1
+        )
+
+      {_, _, _, _, logits} =
+        while {i = 0, last_ngram_but_one, sequences = context.sequences, length = context.length,
+               logits},
+              i + ngram_but_one_length < length do
+          ngram_but_one = Nx.slice_along_axis(sequences, i, ngram_but_one_length, axis: 1)
+
+          batch_size = Nx.axis_size(logits, 0)
+
+          token_ids = sequences[[.., i + ngram_but_one_length]]
+          indices = Nx.stack([Nx.iota({batch_size}), token_ids], axis: -1)
+
+          match? = Nx.all(ngram_but_one == last_ngram_but_one, axes: [1])
+          updates = Nx.select(match?, Nx.Constants.neg_infinity(), 0)
+
+          logits = Nx.indexed_add(logits, indices, updates)
+
+          {i + 1, last_ngram_but_one, sequences, length, logits}
+        end
+
+      logits
+    end
+  end
+
   defnp force_token_id(logits, opts \\ []) do
     token_id = opts[:token_id]
 
diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs
index 8fb7e5da..db4f4308 100644
--- a/test/bumblebee/text/generation_test.exs
+++ b/test/bumblebee/text/generation_test.exs
@@ -17,9 +17,31 @@ defmodule Bumblebee.Text.GenerationTest do
       which were expected to last through at least midday tomorrow.
       """
 
-      serving = Bumblebee.Text.generation(model_info, tokenizer, max_new_tokens: 8)
+      serving =
+        Bumblebee.Text.generation(model_info, tokenizer,
+          max_new_tokens: 8,
+          defn_options: [compiler: EXLA]
+        )
 
       assert %{results: [%{text: "PG&E scheduled the black"}]} = Nx.Serving.run(serving, article)
     end
+
+    test "with :no_repeat_ngram_length" do
+      {:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
+      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
+
+      serving =
+        Bumblebee.Text.generation(model_info, tokenizer,
+          max_new_tokens: 12,
+          no_repeat_ngram_length: 2,
+          defn_options: [compiler: EXLA]
+        )
+
+      # Without :no_repeat_ngram_length we get
+      # %{results: [%{text: "I was going to say, 'Well, I'm going to say,"}]}
+
+      assert %{results: [%{text: "I was going to say, 'Well, I'm going back to the"}]} =
+               Nx.Serving.run(serving, "I was going")
+    end
   end
 end

From c859905d89e50eca3d76a796a1ac7703e5aa3d43 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonatan=20K=C5=82osko?=
Date: Tue, 28 Mar 2023 20:12:52 +0200
Subject: [PATCH 2/2] Update lib/bumblebee/text/generation.ex

---
 lib/bumblebee/text/generation.ex | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 52d6072d..eb331be0 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -603,8 +603,8 @@ defmodule Bumblebee.Text.Generation do
 
           batch_size = Nx.axis_size(logits, 0)
 
-          token_ids = sequences[[.., i + ngram_but_one_length]]
-          indices = Nx.stack([Nx.iota({batch_size}), token_ids], axis: -1)
+          token_id = sequences[[.., i + ngram_but_one_length]]
+          indices = Nx.stack([Nx.iota({batch_size}), token_id], axis: -1)
 
           match? = Nx.all(ngram_but_one == last_ngram_but_one, axes: [1])
           updates = Nx.select(match?, Nx.Constants.neg_infinity(), 0)
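Usage sketch for the option these patches introduce, mirroring the test added in PATCH 1/2. The model ("gpt2"), the prompt, the n-gram size of 3, and the EXLA compiler are illustrative assumptions, not requirements:

    # Illustrative only: serve GPT-2 and forbid repeating any trigram that
    # already occurred in the generated sequence.
    {:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})

    serving =
      Bumblebee.Text.generation(model_info, tokenizer,
        max_new_tokens: 16,
        # Any n-gram of this length may appear at most once in the output.
        no_repeat_ngram_length: 3,
        defn_options: [compiler: EXLA]
      )

    Nx.Serving.run(serving, "I was going")

With no_repeat_ngram_length set to n, the new logits processor adds negative infinity to the logit of every token that would complete an n-gram already present in the sequence, so greedy search (or sampling) cannot reproduce that n-gram; the before/after outputs in the test comment above show the effect for n = 2.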