Add Mistral (#264)
Co-authored-by: Jonatan Kłosko <jonatanklosko@gmail.com>
seanmor5 and jonatanklosko authored Oct 23, 2023
1 parent 8867021 commit 38876bb
Showing 9 changed files with 580 additions and 34 deletions.
4 changes: 4 additions & 0 deletions lib/bumblebee.ex
@@ -148,6 +148,9 @@ defmodule Bumblebee do
"MBartForQuestionAnswering" => {Bumblebee.Text.Mbart, :for_question_answering},
"MBartForSequenceClassification" => {Bumblebee.Text.Mbart, :for_sequence_classification},
"MBartModel" => {Bumblebee.Text.Mbart, :base},
"MistralModel" => {Bumblebee.Text.Mistral, :base},
"MistralForCausalLM" => {Bumblebee.Text.Mistral, :for_causal_language_modeling},
"MistralForSequenceClassification" => {Bumblebee.Text.Mistral, :for_sequence_classification},
"ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification},
"ResNetModel" => {Bumblebee.Vision.ResNet, :base},
"RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
@@ -214,6 +217,7 @@ defmodule Bumblebee do
"gpt2" => Bumblebee.Text.Gpt2Tokenizer,
"layoutlm" => Bumblebee.Text.LayoutLmTokenizer,
"llama" => Bumblebee.Text.LlamaTokenizer,
"mistral" => Bumblebee.Text.LlamaTokenizer,
"mbart" => Bumblebee.Text.MbartTokenizer,
"roberta" => Bumblebee.Text.RobertaTokenizer,
"t5" => Bumblebee.Text.T5Tokenizer,
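
These registry entries are what let the loader resolve Mistral checkpoints: the `Mistral*` architecture names map to `Bumblebee.Text.Mistral`, and the `"mistral"` model type reuses the Llama tokenizer. A minimal usage sketch follows; the `mistralai/Mistral-7B-v0.1` repository name and the generation serving are illustrative assumptions, not part of this diff:

```elixir
repo = {:hf, "mistralai/Mistral-7B-v0.1"}

# "MistralForCausalLM" in the checkpoint config resolves to
# {Bumblebee.Text.Mistral, :for_causal_language_modeling}
{:ok, model_info} = Bumblebee.load_model(repo)

# the "mistral" model type maps onto Bumblebee.Text.LlamaTokenizer
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)

{:ok, generation_config} = Bumblebee.load_generation_config(repo)

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "Elixir is a functional language that")
```
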
23 changes: 23 additions & 0 deletions lib/bumblebee/layers.ex
@@ -1012,4 +1012,27 @@ defmodule Bumblebee.Layers do
x2 = x[[.., .., .., size..-1//1]]
Nx.concatenate([-x2, x1], axis: -1)
end

@doc """
Adds a repeat layer to the network.
## Options
* `:name` - layer name
* `:axis` - the axis to repeat along. Defaults to `-1`
"""
def repeat_interleave(x, times, opts \\ []) do
opts = Keyword.validate!(opts, [:name, axis: -1])

Axon.layer(
fn x, opts ->
axis = Nx.axis_index(x, opts[:axis])
Bumblebee.Utils.Nx.repeat_interleave(x, times, axis: axis)
end,
[x],
opts
)
end
end
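
The new layer duplicates each entry along the given axis in place, as opposed to tiling the whole tensor. A rough illustration of the same semantics on a toy tensor with plain `Nx` (a sketch only, not the `Bumblebee.Utils.Nx.repeat_interleave/3` implementation the layer delegates to):

```elixir
# Repeat each row of a {2, 2} tensor 3 times along axis 0.
t = Nx.tensor([[1, 2], [3, 4]])

t
|> Nx.new_axis(1)      # shape {2, 1, 2}
|> Nx.tile([1, 3, 1])  # shape {2, 3, 2}
|> Nx.reshape({6, 2})
#=> [[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
```

Plain tiling would instead repeat the whole block (`[[1, 2], [3, 4], [1, 2], ...]`). The transformer changes below call the layer with `axis: 2`, so each key/value head is duplicated next to itself and lines up with consecutive query heads.
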
40 changes: 26 additions & 14 deletions lib/bumblebee/layers/transformer.ex
@@ -42,6 +42,7 @@ defmodule Bumblebee.Layers.Transformer do

block_opts_keys = [
:num_attention_heads,
+:num_key_value_heads,
:causal?,
:hidden_size,
:ffn,
@@ -298,6 +299,7 @@
:num_attention_heads,
:hidden_size,
:ffn,
+:num_key_value_heads,
attention_mask: Layers.none(),
attention_head_mask: Layers.none(),
attention_relative_bias: Layers.none(),
@@ -323,6 +325,7 @@

name = opts[:name]
num_attention_heads = opts[:num_attention_heads]
+num_key_value_heads = opts[:num_key_value_heads] || num_attention_heads
hidden_size = opts[:hidden_size]
ffn = opts[:ffn]
causal? = opts[:causal?]
@@ -392,6 +395,7 @@
offset: offset,
causal?: causal?,
num_heads: num_attention_heads,
+num_key_value_heads: num_key_value_heads,
hidden_size: hidden_size,
kernel_initializer: kernel_initializer,
attention_head_size: attention_head_size,
@@ -435,6 +439,7 @@
attention_cache: cross_attention_cache,
offset: offset,
num_heads: num_attention_heads,
+num_key_value_heads: num_key_value_heads,
hidden_size: hidden_size,
kernel_initializer: kernel_initializer,
attention_head_size: attention_head_size,
@@ -716,6 +721,7 @@
:name,
:num_heads,
:hidden_size,
+:num_key_value_heads,
attention_mask: Layers.none(),
attention_head_mask: Layers.none(),
attention_relative_bias: Layers.none(),
@@ -740,6 +746,7 @@

name = opts[:name]
num_heads = opts[:num_heads]
+num_key_value_heads = opts[:num_key_value_heads] || num_heads
hidden_size = opts[:hidden_size]
kernel_initializer = opts[:kernel_initializer]
causal? = opts[:causal?]
@@ -754,14 +761,9 @@

attention_relative_bias = opts[:attention_relative_bias]

-inner_size =
-if attention_head_size = opts[:attention_head_size] do
-num_heads * attention_head_size
-else
-hidden_size
-end
-
-head_size = div(hidden_size, num_heads)
+attention_head_size = opts[:attention_head_size] || div(hidden_size, num_heads)
+inner_size = num_heads * attention_head_size
+inner_kv_size = num_key_value_heads * attention_head_size

query =
query
@@ -774,21 +776,21 @@

key =
key
-|> Axon.dense(inner_size,
+|> Axon.dense(inner_kv_size,
kernel_initializer: kernel_initializer,
name: join(name, "key"),
use_bias: key_use_bias
)
-|> Layers.split_heads(num_heads)
+|> Layers.split_heads(num_key_value_heads)

value =
value
-|> Axon.dense(inner_size,
+|> Axon.dense(inner_kv_size,
kernel_initializer: kernel_initializer,
name: join(name, "value"),
use_bias: value_use_bias
)
-|> Layers.split_heads(num_heads)
+|> Layers.split_heads(num_key_value_heads)

{query, key} =
case rotary_embedding do
@@ -801,11 +803,11 @@
{position_ids, opts} = Keyword.pop(opts, :position_ids)
{percentage, opts} = Keyword.pop(opts, :percentage)

-size = trunc(head_size * percentage)
+size = trunc(attention_head_size * percentage)

rotary_opts = [name: join(name, "rotary_embedding")] ++ opts

-if size == head_size do
+if size == attention_head_size do
Layers.rotary_embedding(query, key, position_ids, size, rotary_opts)
else
query_rotary = Axon.nx(query, & &1[[.., .., .., 0..(size - 1)//1]])
@@ -825,6 +827,10 @@
{query, key}
end

+num_key_value_groups = div(num_heads, num_key_value_heads)
+key = repeat_states(key, num_key_value_groups)
+value = repeat_states(value, num_key_value_groups)

{key, value, attention_cache} =
Layers.Decoder.cached_attention_key_values(key, value, attention_cache, offset)

Expand Down Expand Up @@ -882,6 +888,12 @@ defmodule Bumblebee.Layers.Transformer do
{attention_output, attention_weights, attention_cache, attention_relative_bias}
end

+defp repeat_states(state, 1), do: state
+
+defp repeat_states(state, times) do
+Layers.repeat_interleave(state, times, axis: 2)
+end

defp validate_required_keys!(opts, keys) do
case keys -- Keyword.keys(opts) do
[] -> :ok
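
Taken together, these changes implement grouped-query attention: queries keep `num_heads` projections, keys and values are projected to only `num_key_value_heads` heads, and `repeat_states/2` then duplicates them to match the query heads before the cache and attention weights are computed. A small worked example of the head arithmetic, with sizes that are illustrative and roughly correspond to the published Mistral 7B configuration:

```elixir
hidden_size = 4096
num_heads = 32
num_key_value_heads = 8

attention_head_size = div(hidden_size, num_heads)           # 128
inner_size = num_heads * attention_head_size                # 4096, query Dense size
inner_kv_size = num_key_value_heads * attention_head_size   # 1024, key/value Dense size

num_key_value_groups = div(num_heads, num_key_value_heads)  # 4
# repeat_interleave along the head axis brings keys/values back to 32 heads,
# with each projected key/value head shared by 4 consecutive query heads.
```

When `num_key_value_heads` is not set it defaults to `num_heads`, `num_key_value_groups` is 1, and `repeat_states/2` is a no-op, so existing multi-head attention models are unaffected.
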
6 changes: 6 additions & 0 deletions lib/bumblebee/text/llama.ex
@@ -34,6 +34,10 @@ defmodule Bumblebee.Text.Llama do
default: 32,
doc: "the number of attention heads for each attention layer in the model"
],
+num_key_value_heads: [
+default: nil,
+doc: "the number of key value heads for each attention layer in the model"
+],
activation: [
default: :silu,
doc: "the activation function"
@@ -302,6 +306,7 @@ defmodule Bumblebee.Text.Llama do
cache: cache,
num_blocks: spec.num_blocks,
num_attention_heads: spec.num_attention_heads,
+num_key_value_heads: spec.num_key_value_heads,
hidden_size: spec.hidden_size,
kernel_initializer: kernel_initializer(spec),
layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon),
@@ -365,6 +370,7 @@ defmodule Bumblebee.Text.Llama do
hidden_size: {"hidden_size", number()},
num_blocks: {"num_hidden_layers", number()},
num_attention_heads: {"num_attention_heads", number()},
+num_key_value_heads: {"num_key_value_heads", number()},
intermediate_size: {"intermediate_size", number()},
activation: {"hidden_act", atom()},
initializer_scale: {"initializer_range", number()},
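
The new `num_key_value_heads` option is read straight from the checkpoint's `config.json` and defaults to `nil`, in which case the transformer falls back to ordinary multi-head attention. As a sketch, a grouped-query checkpoint would land on the spec as follows; the values below are the commonly published Llama 2 70B ones, not taken from this diff:

```elixir
# Hypothetical excerpt of a config.json, decoded into a map:
config = %{
  "num_hidden_layers" => 80,
  "hidden_size" => 8192,
  "intermediate_size" => 28672,
  "num_attention_heads" => 64,
  "num_key_value_heads" => 8,
  "hidden_act" => "silu"
}

# Per the mapping above this becomes:
#   num_blocks: 80, hidden_size: 8192, intermediate_size: 28672,
#   num_attention_heads: 64, num_key_value_heads: 8, activation: :silu
# i.e. each of the 8 key/value heads is shared by 8 query heads.
```
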
(Diffs for the remaining 5 changed files are not shown here.)
