diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex
index 72dc24da..a4247850 100644
--- a/lib/bumblebee.ex
+++ b/lib/bumblebee.ex
@@ -148,6 +148,9 @@ defmodule Bumblebee do
     "MBartForQuestionAnswering" => {Bumblebee.Text.Mbart, :for_question_answering},
     "MBartForSequenceClassification" => {Bumblebee.Text.Mbart, :for_sequence_classification},
     "MBartModel" => {Bumblebee.Text.Mbart, :base},
+    "MistralForCausalLM" => {Bumblebee.Text.Mistral, :for_causal_language_modeling},
+    "MistralForSequenceClassification" => {Bumblebee.Text.Mistral, :for_sequence_classification},
+    "MistralModel" => {Bumblebee.Text.Mistral, :base},
     "ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification},
     "ResNetModel" => {Bumblebee.Vision.ResNet, :base},
     "RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
@@ -214,6 +217,7 @@ defmodule Bumblebee do
     "gpt2" => Bumblebee.Text.Gpt2Tokenizer,
     "layoutlm" => Bumblebee.Text.LayoutLmTokenizer,
     "llama" => Bumblebee.Text.LlamaTokenizer,
     "mbart" => Bumblebee.Text.MbartTokenizer,
+    "mistral" => Bumblebee.Text.LlamaTokenizer,
     "roberta" => Bumblebee.Text.RobertaTokenizer,
     "t5" => Bumblebee.Text.T5Tokenizer,
diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex
index aa5d2051..5d2806f1 100644
--- a/lib/bumblebee/layers.ex
+++ b/lib/bumblebee/layers.ex
@@ -1012,4 +1012,27 @@ defmodule Bumblebee.Layers do
     x2 = x[[.., .., .., size..-1//1]]
     Nx.concatenate([-x2, x1], axis: -1)
   end
+
+  @doc """
+  Adds a layer that repeats each element of the input `times` times along the given axis.
+
+  ## Options
+
+    * `:name` - layer name
+
+    * `:axis` - the axis to repeat along. Defaults to `-1`
+
+  """
+  def repeat_interleave(x, times, opts \\ []) do
+    opts = Keyword.validate!(opts, [:name, axis: -1])
+
+    Axon.layer(
+      fn x, opts ->
+        axis = Nx.axis_index(x, opts[:axis])
+        Bumblebee.Utils.Nx.repeat_interleave(x, times, axis: axis)
+      end,
+      [x],
+      opts
+    )
+  end
 end
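The new `Bumblebee.Layers.repeat_interleave/3` layer wraps `Bumblebee.Utils.Nx.repeat_interleave/3`, which this patch assumes is already available and behaves as a repeat-interleave: every slice along the chosen axis is duplicated `times` times in place, so two key-value heads repeated twice come out ordered `[h0, h0, h1, h1]`. A minimal Nx sketch of that behaviour (illustrative only, not the actual `Bumblebee.Utils.Nx` implementation):

    # Repeat each slice along `axis` `times` times, preserving order.
    repeat_interleave = fn tensor, times, axis ->
      axis = Nx.axis_index(tensor, axis)
      rank = Nx.rank(tensor)

      multiples = List.duplicate(1, axis + 1) ++ [times] ++ List.duplicate(1, rank - axis - 1)

      out_shape =
        tensor
        |> Nx.shape()
        |> Tuple.to_list()
        |> List.update_at(axis, &(&1 * times))
        |> List.to_tuple()

      tensor
      |> Nx.new_axis(axis + 1)
      |> Nx.tile(multiples)
      |> Nx.reshape(out_shape)
    end

    key = Nx.iota({1, 4, 2, 8})              # {batch, seq, kv_heads, head_size}
    Nx.shape(repeat_interleave.(key, 2, 2))  # {1, 4, 4, 8}, heads ordered h0, h0, h1, h1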
diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex
index 58480b05..2b59177c 100644
--- a/lib/bumblebee/layers/transformer.ex
+++ b/lib/bumblebee/layers/transformer.ex
@@ -42,6 +42,7 @@ defmodule Bumblebee.Layers.Transformer do
     block_opts_keys = [
       :num_attention_heads,
+      :num_key_value_heads,
       :causal?,
       :hidden_size,
       :ffn,
@@ -298,6 +299,7 @@ defmodule Bumblebee.Layers.Transformer do
        :num_attention_heads,
        :hidden_size,
        :ffn,
+       :num_key_value_heads,
        attention_mask: Layers.none(),
        attention_head_mask: Layers.none(),
        attention_relative_bias: Layers.none(),
@@ -323,6 +325,7 @@ defmodule Bumblebee.Layers.Transformer do
     name = opts[:name]
 
     num_attention_heads = opts[:num_attention_heads]
+    num_key_value_heads = opts[:num_key_value_heads] || num_attention_heads
     hidden_size = opts[:hidden_size]
     ffn = opts[:ffn]
     causal? = opts[:causal?]
@@ -392,6 +395,7 @@ defmodule Bumblebee.Layers.Transformer do
         offset: offset,
         causal?: causal?,
         num_heads: num_attention_heads,
+        num_key_value_heads: num_key_value_heads,
         hidden_size: hidden_size,
         kernel_initializer: kernel_initializer,
         attention_head_size: attention_head_size,
@@ -435,6 +439,7 @@ defmodule Bumblebee.Layers.Transformer do
         attention_cache: cross_attention_cache,
         offset: offset,
         num_heads: num_attention_heads,
+        num_key_value_heads: num_key_value_heads,
         hidden_size: hidden_size,
         kernel_initializer: kernel_initializer,
         attention_head_size: attention_head_size,
@@ -716,6 +721,7 @@ defmodule Bumblebee.Layers.Transformer do
        :name,
        :num_heads,
        :hidden_size,
+       :num_key_value_heads,
        attention_mask: Layers.none(),
        attention_head_mask: Layers.none(),
        attention_relative_bias: Layers.none(),
@@ -740,6 +746,7 @@ defmodule Bumblebee.Layers.Transformer do
     name = opts[:name]
 
     num_heads = opts[:num_heads]
+    num_key_value_heads = opts[:num_key_value_heads] || num_heads
     hidden_size = opts[:hidden_size]
     kernel_initializer = opts[:kernel_initializer]
     causal? = opts[:causal?]
@@ -754,14 +761,9 @@ defmodule Bumblebee.Layers.Transformer do
     attention_relative_bias = opts[:attention_relative_bias]
 
-    inner_size =
-      if attention_head_size = opts[:attention_head_size] do
-        num_heads * attention_head_size
-      else
-        hidden_size
-      end
-
-    head_size = div(hidden_size, num_heads)
+    attention_head_size = opts[:attention_head_size] || div(hidden_size, num_heads)
+    inner_size = num_heads * attention_head_size
+    inner_kv_size = num_key_value_heads * attention_head_size
 
     query =
       query
@@ -774,21 +776,21 @@ defmodule Bumblebee.Layers.Transformer do
 
     key =
       key
-      |> Axon.dense(inner_size,
+      |> Axon.dense(inner_kv_size,
        kernel_initializer: kernel_initializer,
        name: join(name, "key"),
        use_bias: key_use_bias
      )
-      |> Layers.split_heads(num_heads)
+      |> Layers.split_heads(num_key_value_heads)
 
     value =
       value
-      |> Axon.dense(inner_size,
+      |> Axon.dense(inner_kv_size,
        kernel_initializer: kernel_initializer,
        name: join(name, "value"),
        use_bias: value_use_bias
      )
-      |> Layers.split_heads(num_heads)
+      |> Layers.split_heads(num_key_value_heads)
 
     {query, key} =
       case rotary_embedding do
@@ -801,11 +803,11 @@ defmodule Bumblebee.Layers.Transformer do
           {position_ids, opts} = Keyword.pop(opts, :position_ids)
           {percentage, opts} = Keyword.pop(opts, :percentage)
 
-          size = trunc(head_size * percentage)
+          size = trunc(attention_head_size * percentage)
 
           rotary_opts = [name: join(name, "rotary_embedding")] ++ opts
 
-          if size == head_size do
+          if size == attention_head_size do
             Layers.rotary_embedding(query, key, position_ids, size, rotary_opts)
           else
             query_rotary = Axon.nx(query, & &1[[.., .., .., 0..(size - 1)//1]])
@@ -825,6 +827,10 @@ defmodule Bumblebee.Layers.Transformer do
           {query, key}
       end
 
+    num_key_value_groups = div(num_heads, num_key_value_heads)
+    key = repeat_states(key, num_key_value_groups)
+    value = repeat_states(value, num_key_value_groups)
+
     {key, value, attention_cache} =
       Layers.Decoder.cached_attention_key_values(key, value, attention_cache, offset)
 
@@ -882,6 +888,12 @@ defmodule Bumblebee.Layers.Transformer do
     {attention_output, attention_weights, attention_cache, attention_relative_bias}
   end
 
+  defp repeat_states(state, 1), do: state
+
+  defp repeat_states(state, times) do
+    Layers.repeat_interleave(state, times, axis: 2)
+  end
+
   defp validate_required_keys!(opts, keys) do
     case keys -- Keyword.keys(opts) do
       [] -> :ok
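With grouped-query attention the query projection keeps `num_heads * attention_head_size` output units, the key/value projections shrink to `num_key_value_heads * attention_head_size`, and the key/value heads are then repeated `div(num_heads, num_key_value_heads)` times so attention can still be computed head-for-head. For the Mistral defaults used below, the arithmetic works out as in this sketch (values derived from the config, not measured):

    hidden_size = 4096
    num_heads = 32
    num_key_value_heads = 8

    attention_head_size = div(hidden_size, num_heads)           # 128
    inner_size = num_heads * attention_head_size                # 4096, q_proj output units
    inner_kv_size = num_key_value_heads * attention_head_size   # 1024, k_proj/v_proj output units
    num_key_value_groups = div(num_heads, num_key_value_heads)  # 4 query heads share each KV head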
diff --git a/lib/bumblebee/text/llama.ex b/lib/bumblebee/text/llama.ex
index 87ffc904..e989d4b1 100644
--- a/lib/bumblebee/text/llama.ex
+++ b/lib/bumblebee/text/llama.ex
@@ -34,6 +34,10 @@ defmodule Bumblebee.Text.Llama do
         default: 32,
         doc: "the number of attention heads for each attention layer in the model"
       ],
+      num_key_value_heads: [
+        default: nil,
+        doc: "the number of key-value heads for each attention layer in the model"
+      ],
       activation: [
         default: :silu,
         doc: "the activation function"
       ],
@@ -302,6 +306,7 @@ defmodule Bumblebee.Text.Llama do
       cache: cache,
       num_blocks: spec.num_blocks,
       num_attention_heads: spec.num_attention_heads,
+      num_key_value_heads: spec.num_key_value_heads,
       hidden_size: spec.hidden_size,
       kernel_initializer: kernel_initializer(spec),
       layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon),
@@ -365,6 +370,7 @@ defmodule Bumblebee.Text.Llama do
         hidden_size: {"hidden_size", number()},
         num_blocks: {"num_hidden_layers", number()},
         num_attention_heads: {"num_attention_heads", number()},
+        num_key_value_heads: {"num_key_value_heads", number()},
         intermediate_size: {"intermediate_size", number()},
         activation: {"hidden_act", atom()},
         initializer_scale: {"initializer_range", number()},
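Because the new Llama option defaults to `nil` and the attention layer falls back to `opts[:num_key_value_heads] || num_attention_heads`, configs without a `"num_key_value_heads"` entry keep plain multi-head attention, while GQA checkpoints pick the value up from `config.json`. A sketch of checking this on a loaded spec (the repository name and the numbers in the comments are illustrative, not verified output):

    {:ok, spec} = Bumblebee.load_spec({:hf, "meta-llama/Llama-2-70b-hf"})

    spec.num_attention_heads
    # 64
    spec.num_key_value_heads
    # 8 on a GQA checkpoint; nil when the config has no "num_key_value_heads" entry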
[logits, inputs["input_ids"]] + ) + else + Layers.take_token(logits, axis: 1, index: -1) + end + + Layers.output(%{ + logits: pooled_logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + defp inputs(spec) do + shape = {nil, nil} + hidden_shape = {nil, nil, spec.hidden_size} + + attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads} + + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("input_ids", optional: true, shape: shape), + Axon.input("attention_mask", optional: true, shape: shape), + Axon.input("position_ids", optional: true, shape: shape), + Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape), + Axon.input("input_embeddings", optional: true, shape: hidden_shape), + Axon.input("cache", optional: true) + ]) + end + + defp core(inputs, spec) do + embeddings = + embedder( + inputs["input_ids"], + inputs["input_embeddings"], + spec, + name: "embedder" + ) + + position_ids = + Layers.default inputs["position_ids"] do + Layers.default_position_ids(embeddings) + end + + decoder_outputs = + decoder( + embeddings, + position_ids, + inputs["attention_mask"], + inputs["attention_head_mask"], + inputs["cache"], + spec, + name: "decoder" + ) + + hidden_state = + Layers.rms_norm(decoder_outputs.hidden_state, + name: "output_norm", + epsilon: spec.layer_norm_epsilon + ) + + %{ + hidden_state: hidden_state, + hidden_states: Layers.append(decoder_outputs.hidden_states, hidden_state), + attentions: decoder_outputs.attentions, + cache: decoder_outputs.cache + } + end + + defp embedder(input_ids, input_embeddings, spec, opts) do + name = opts[:name] + + # TODO: Axon needs a way to specify ignoring pad tokens + # in gradient + Layers.default input_embeddings do + Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_embedding") + ) + end + end + + defp decoder( + hidden_state, + position_ids, + attention_mask, + attention_head_mask, + cache, + spec, + opts + ) do + name = opts[:name] + + Layers.Transformer.blocks(hidden_state, + attention_mask: attention_mask, + attention_head_mask: attention_head_mask, + cache: cache, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + num_key_value_heads: spec.num_key_value_heads, + hidden_size: spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon), + ffn: + &gated_ffn(&1, spec.intermediate_size, spec.hidden_size, + name: &2, + activation: spec.activation + ), + block_type: :norm_first, + causal?: true, + rotary_embedding: [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base + ], + query_use_bias: false, + key_use_bias: false, + value_use_bias: false, + output_use_bias: false, + output_hidden_states: spec.output_hidden_states, + output_attentions: spec.output_attentions, + name: join(name, "blocks") + ) + end + + defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do + name = opts[:name] + activation = opts[:activation] + + intermediate = + Axon.dense(hidden_state, intermediate_size, + name: join(name, "intermediate"), + use_bias: false + ) + + gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false) + + hidden_state = Axon.multiply(intermediate, Axon.activation(gate, activation)) + + Axon.dense(hidden_state, output_size, name: join(name, "output"), 
use_bias: false) + end + + defp language_modeling_head(hidden_state, spec, opts) do + name = opts[:name] + + # TODO: Tie lm-head to word embedding as a spec option + Layers.dense_transposed(hidden_state, spec.vocab_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output") + ) + end + + defp kernel_initializer(spec) do + Axon.Initializers.normal(scale: spec.initializer_scale) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import Shared.Converters + + opts = + convert!(data, + vocab_size: {"vocab_size", number()}, + max_positions: {"max_position_embeddings", number()}, + hidden_size: {"hidden_size", number()}, + num_blocks: {"num_hidden_layers", number()}, + num_attention_heads: {"num_attention_heads", number()}, + num_key_value_heads: {"num_key_value_heads", number()}, + intermediate_size: {"intermediate_size", number()}, + activation: {"hidden_act", atom()}, + rotary_embedding_base: {"rope_theta", number()}, + initializer_scale: {"initializer_range", number()}, + layer_norm_epsilon: {"rms_norm_eps", number()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(_spec) do + %{ + "embedder.token_embedding" => "model.embed_tokens", + "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj", + "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj", + "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj", + "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.o_proj", + "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm", + "decoder.blocks.{n}.self_attention.rotary_embedding" => + "model.layers.{n}.self_attn.rotary_emb", + "decoder.blocks.{n}.ffn.gate" => "model.layers.{n}.mlp.gate_proj", + "decoder.blocks.{n}.ffn.intermediate" => "model.layers.{n}.mlp.up_proj", + "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj", + "decoder.blocks.{n}.output_norm" => "model.layers.{n}.post_attention_layernorm", + "output_norm" => "model.norm", + "language_modeling_head.output" => "lm_head", + "sequence_classification_head.output" => "score" + } + end + end +end diff --git a/lib/bumblebee/utils/tokenizers.ex b/lib/bumblebee/utils/tokenizers.ex index 5a0b723e..60d3c66e 100644 --- a/lib/bumblebee/utils/tokenizers.ex +++ b/lib/bumblebee/utils/tokenizers.ex @@ -49,14 +49,13 @@ defmodule Bumblebee.Utils.Tokenizers do encodings = Enum.map(encodings, fn encoding -> - transformations = - [ - Encoding.Transformation.pad(pad_length, - pad_id: pad_id, - pad_token: pad_token, - direction: opts[:pad_direction] - ) - ] + transformations = [ + Encoding.Transformation.pad(pad_length, + pad_id: pad_id, + pad_token: pad_token, + direction: opts[:pad_direction] + ) + ] transformations = transformations ++ diff --git a/test/bumblebee/text/bert_tokenizer_test.exs b/test/bumblebee/text/bert_tokenizer_test.exs index bf61e20b..248b1f46 100644 --- a/test/bumblebee/text/bert_tokenizer_test.exs +++ b/test/bumblebee/text/bert_tokenizer_test.exs @@ -74,8 +74,7 @@ defmodule Bumblebee.Text.BertTokenizerTest do test "encoding with multiple lengths" do assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-cased"}) - inputs = - Bumblebee.apply_tokenizer(tokenizer, "This is short.", length: [8, 16]) + inputs = Bumblebee.apply_tokenizer(tokenizer, "This is short.", length: [8, 16]) assert {1, 8} = 
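The `:for_sequence_classification` model scores a single position per sequence: when `"input_ids"` are given it locates the last non-padding token and takes the logits there with `Bumblebee.Utils.Nx.batched_take`, otherwise it falls back to the last position. A standalone sketch of the index computation mirrored from the layer above (the token values are made up):

    pad_token_id = 0
    input_ids = Nx.tensor([[1, 6312, 28709, 1526, 0, 0]])

    input_ids
    |> Nx.not_equal(pad_token_id)
    |> Nx.sum(axes: [-1])
    |> Nx.subtract(1)
    |> Nx.as_type({:s, 64})
    # a length-1 tensor containing 3, the position of the last non-padding token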
Nx.shape(inputs["input_ids"]) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index edb974f5..5131031a 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -22,8 +22,7 @@ defmodule Bumblebee.Text.GenerationTest do generation_config = Bumblebee.configure(generation_config, max_new_tokens: 8) - serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config) + serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config) assert %{results: [%{text: "PG&E scheduled the black"}]} = Nx.Serving.run(serving, article) end @@ -36,8 +35,7 @@ defmodule Bumblebee.Text.GenerationTest do generation_config = Bumblebee.configure(generation_config, max_new_tokens: 12, no_repeat_ngram_length: 2) - serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config) + serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config) # Without :no_repeat_ngram_length we get # %{results: [%{text: "I was going to say, 'Well, I'm going to say,"}]} @@ -57,8 +55,7 @@ defmodule Bumblebee.Text.GenerationTest do strategy: %{type: :multinomial_sampling} ) - serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config, seed: 0) + serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config, seed: 0) # Note that this is just a snapshot test, we do not use any # reference value, because of PRNG difference @@ -81,8 +78,7 @@ defmodule Bumblebee.Text.GenerationTest do strategy: %{type: :contrastive_search, top_k: 4, alpha: 0.6} ) - serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config) + serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config) assert %{results: [%{text: "I was going to say, 'Well, I don't know what you"}]} = Nx.Serving.run(serving, "I was going") @@ -104,8 +100,7 @@ defmodule Bumblebee.Text.GenerationTest do generation_config = Bumblebee.configure(generation_config, max_new_tokens: 8) - serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config, stream: true) + serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config, stream: true) stream = Nx.Serving.run(serving, article) assert Enum.to_list(stream) == ["PG&E", " scheduled", " the", " black"] diff --git a/test/bumblebee/text/mistral_test.exs b/test/bumblebee/text/mistral_test.exs new file mode 100644 index 00000000..416c493a --- /dev/null +++ b/test/bumblebee/text/mistral_test.exs @@ -0,0 +1,87 @@ +defmodule Bumblebee.Text.MistralTest do + use ExUnit.Case, async: false + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + describe "integration" do + test "base model" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "echarlaix/tiny-random-mistral"}, architecture: :base) + + assert %Bumblebee.Text.Mistral{architecture: :base} = spec + + input_ids = Nx.tensor([[1, 6312, 28709, 1526, 28808]]) + + inputs = %{ + "input_ids" => input_ids + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 5, 32} + + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [ + [-1.1513, -0.3565, -1.3482], + [0.5468, 0.5652, -0.4141], + [-1.2177, -0.7919, -0.7064] + ] + ]), + atol: 1.0e-2 + ) + end + + test "sequence classification model" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "seanmor5/tiny-random-mistral-classification"}) + + assert 
%Bumblebee.Text.Mistral{architecture: :for_sequence_classification} = spec + input_ids = Nx.tensor([[1, 6312, 28709, 1526]]) + + inputs = %{ + "input_ids" => input_ids + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 2} + + assert_all_close( + outputs.logits, + Nx.tensor([[0.0255, 0.0318]]), + atol: 1.0e-4 + ) + end + + test "causal language model" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "echarlaix/tiny-random-mistral"}, + architecture: :for_causal_language_modeling + ) + + assert %Bumblebee.Text.Mistral{architecture: :for_causal_language_modeling} = spec + + input_ids = Nx.tensor([[1, 6312, 28709, 1526]]) + + inputs = %{ + "input_ids" => input_ids + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 4, 32000} + + assert_all_close( + outputs.logits[[.., 1..3, 1..3]], + Nx.tensor([ + [[0.1156, 0.0420, -0.0609], [0.0333, 0.0376, -0.0531], [-0.0507, -0.0097, -0.0039]] + ]), + atol: 1.0e-2 + ) + end + end +end
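With these pieces in place, Mistral checkpoints go through the usual Bumblebee entry points: the tokenizer loads via the `"mistral"` mapping to `Bumblebee.Text.LlamaTokenizer`, and the model via the `MistralFor*` architectures. A usage sketch (the checkpoint name is illustrative and a compiled backend such as EXLA is assumed to be configured):

    repo = {:hf, "mistralai/Mistral-7B-Instruct-v0.1"}

    {:ok, model_info} = Bumblebee.load_model(repo)
    {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
    {:ok, generation_config} = Bumblebee.load_generation_config(repo)

    generation_config = Bumblebee.configure(generation_config, max_new_tokens: 64)

    serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
    Nx.Serving.run(serving, "What is grouped-query attention?")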