Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Llama #199

Merged
merged 23 commits into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ defmodule Bumblebee do
"LayoutLMForTokenClassification" =>
{Bumblebee.Multimodal.LayoutLm, :for_token_classification},
"LayoutLMModel" => {Bumblebee.Multimodal.LayoutLm, :base},
"LlamaModel" => {Bumblebee.Text.Llama, :base},
"LlamaForCausalLM" => {Bumblebee.Text.Llama, :for_causal_language_modeling},
"LLaMAForCausalLM" => {Bumblebee.Text.Llama, :for_causal_language_modeling},
"LlamaForSequenceClassification" => {Bumblebee.Text.Llama, :for_sequence_classification},
"MBartForCausalLM" => {Bumblebee.Text.Mbart, :for_causal_language_modeling},
"MBartForConditionalGeneration" => {Bumblebee.Text.Mbart, :for_conditional_generation},
"MBartForQuestionAnswering" => {Bumblebee.Text.Mbart, :for_question_answering},
Expand Down Expand Up @@ -194,6 +198,7 @@ defmodule Bumblebee do
"clip" => Bumblebee.Text.ClipTokenizer,
"gpt2" => Bumblebee.Text.Gpt2Tokenizer,
"layoutlm" => Bumblebee.Text.LayoutLmTokenizer,
"llama" => Bumblebee.Text.LlamaTokenizer,
"mbart" => Bumblebee.Text.MbartTokenizer,
"roberta" => Bumblebee.Text.RobertaTokenizer,
"t5" => Bumblebee.Text.T5Tokenizer,
Expand Down
83 changes: 83 additions & 0 deletions lib/bumblebee/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -965,4 +965,87 @@ defmodule Bumblebee.Layers do

x * weight
end

@doc """
Adds a rotary embedding layer to the network.
"""
def rotary_embedding(query, key, value, position_ids, dim, opts \\ []) do
opts = Keyword.validate!(opts, [:name, max_position_embeddings: 2048, base: 10_000])

out = Axon.layer(&rotary_embedding_impl/2, [value], [{:dim, dim} | opts])
{sin, cos} = {Axon.nx(out, &elem(&1, 0)), Axon.nx(out, &elem(&1, 1))}

out = Axon.layer(&apply_rotary_embedding/6, [query, key, cos, sin, position_ids])
{Axon.nx(out, &elem(&1, 0)), Axon.nx(out, &elem(&1, 1))}
end

defnp rotary_embedding_impl(value, opts \\ []) do
opts = keyword!(opts, [:dim, mode: :inference, max_position_embeddings: 2048, base: 10_000])
base = opts[:base]
dim = opts[:dim]

seq_len = Nx.axis_size(value, 1)

inv_freq = compute_inv_freq(base, dim)

t = Nx.iota({opts[:max_position_embeddings]})
freqs = Nx.outer(t, inv_freq)

emb = Nx.concatenate([freqs, freqs], axis: -1)

cos =
emb
|> Nx.cos()
|> Nx.new_axis(0)
|> Nx.new_axis(0)

sin =
emb
|> Nx.sin()
|> Nx.new_axis(0)
|> Nx.new_axis(0)

{cos[[.., .., 0..(seq_len - 1), ..]], sin[[.., .., 0..(seq_len - 1), ..]]}
end

deftransformp compute_inv_freq(base, dim) do
dim = div(dim, 2)
range = Nx.multiply(Nx.iota({dim}), 2)
Nx.divide(1.0, Nx.pow(base, range))
end

defnp apply_rotary_embedding(query, key, cos, sin, position_ids, _opts) do
{bsz, seq} = Nx.shape(position_ids)

query = Nx.transpose(query, axes: [0, 2, 1, 3])
key = Nx.transpose(key, axes: [0, 2, 1, 3])

gather_indices =
position_ids
|> Nx.reshape({bsz, 1, seq, 1})
|> Nx.broadcast({bsz, Nx.axis_size(cos, 1), seq, Nx.axis_size(cos, 3)})
|> Nx.as_type(:s64)

cos =
cos
|> Nx.broadcast({bsz, Nx.axis_size(cos, 1), Nx.axis_size(cos, 2), Nx.axis_size(cos, 3)})
|> Nx.take_along_axis(gather_indices, axis: 2)

sin =
sin
|> Nx.broadcast({bsz, Nx.axis_size(sin, 1), Nx.axis_size(sin, 2), Nx.axis_size(sin, 3)})
|> Nx.take_along_axis(gather_indices, axis: 2)

q_embed = query * cos + rotate_half(query) * sin
k_embed = key * cos + rotate_half(key) * sin

{Nx.transpose(q_embed, axes: [0, 2, 1, 3]), Nx.transpose(k_embed, axes: [0, 2, 1, 3])}
end

defnp rotate_half(x) do
i = div(Nx.axis_size(x, -1), 2)
x1 = x[[.., .., .., 0..(i - 1)//1]]
x2 = x[[.., .., .., i..-1//1]]
Nx.concatenate([-x2, x1], axis: -1)
end
end
39 changes: 36 additions & 3 deletions lib/bumblebee/layers/transformer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ defmodule Bumblebee.Layers.Transformer do
:layer_norm,
:norm_placement,
:output_shortcut,
:scale_query?
:scale_query?,
:use_rotary_embedding?
]

opts =
Expand All @@ -66,6 +67,7 @@ defmodule Bumblebee.Layers.Transformer do
[
:name,
:num_blocks,
position_ids: Layers.none(),
attention_mask: Layers.none(),
attention_head_mask: Layers.none(),
attention_relative_bias: nil,
Expand All @@ -84,6 +86,7 @@ defmodule Bumblebee.Layers.Transformer do
output_hidden_states = opts[:output_hidden_states]
output_attentions = opts[:output_attentions]

position_ids = opts[:position_ids]
attention_mask = opts[:attention_mask]
attention_head_mask = opts[:attention_head_mask]
cross_hidden_state = opts[:cross_hidden_state]
Expand Down Expand Up @@ -123,6 +126,7 @@ defmodule Bumblebee.Layers.Transformer do
block(
state.hidden_state,
[
position_ids: position_ids,
attention_mask: attention_mask,
attention_head_mask: block_attention_head_mask,
attention_relative_bias: attention_relative_bias,
Expand Down Expand Up @@ -260,6 +264,9 @@ defmodule Bumblebee.Layers.Transformer do
* `:scale_query?` - whether to scale query in the traditional style of
multi-headed attention. Defaults to `true`

* `:use_rotary_embedding?` - whether or not to use rotary position embedding
in multi-headed attention. Defaults to `false`

* `:name` - the prefix for layer names

## References
Expand All @@ -276,6 +283,7 @@ defmodule Bumblebee.Layers.Transformer do
:num_attention_heads,
:hidden_size,
:ffn,
position_ids: Layers.none(),
attention_mask: Layers.none(),
attention_head_mask: Layers.none(),
attention_relative_bias: Layers.none(),
Expand All @@ -296,7 +304,8 @@ defmodule Bumblebee.Layers.Transformer do
norm_placement: :last,
layer_norm: [],
output_shortcut: true,
scale_query?: true
scale_query?: true,
use_rotary_embedding?: false
])

name = opts[:name]
Expand All @@ -312,6 +321,7 @@ defmodule Bumblebee.Layers.Transformer do
key_use_bias = opts[:key_use_bias]
value_use_bias = opts[:value_use_bias]
output_use_bias = opts[:output_use_bias]
position_ids = opts[:position_ids]
attention_mask = opts[:attention_mask]
attention_head_mask = opts[:attention_head_mask]
attention_relative_bias = opts[:attention_relative_bias]
Expand All @@ -324,6 +334,7 @@ defmodule Bumblebee.Layers.Transformer do
norm_placement = opts[:norm_placement]
output_shortcut = opts[:output_shortcut]
scale_query? = opts[:scale_query?]
use_rotary_embedding? = opts[:use_rotary_embedding?]

ffn_fun =
case ffn do
Expand Down Expand Up @@ -368,6 +379,7 @@ defmodule Bumblebee.Layers.Transformer do

{hidden_state, attention, self_attention_cache, attention_relative_bias} =
multi_head_attention(hidden_state, hidden_state, hidden_state,
position_ids: position_ids,
attention_mask: attention_mask,
attention_head_mask: attention_head_mask,
attention_relative_bias: attention_relative_bias,
Expand All @@ -384,6 +396,7 @@ defmodule Bumblebee.Layers.Transformer do
value_use_bias: value_use_bias,
output_use_bias: output_use_bias,
scale_query?: scale_query?,
use_rotary_embedding?: use_rotary_embedding?,
name: join(name, "self_attention")
)

Expand All @@ -410,6 +423,7 @@ defmodule Bumblebee.Layers.Transformer do

{hidden_state, cross_attention, cross_attention_cache, _cross_attention_relative_bias} =
multi_head_attention(hidden_state, cross_hidden_state, cross_hidden_state,
position_ids: position_ids,
attention_mask: cross_attention_mask,
attention_head_mask: cross_attention_head_mask,
attention_cache: cross_attention_cache,
Expand All @@ -424,6 +438,7 @@ defmodule Bumblebee.Layers.Transformer do
value_use_bias: value_use_bias,
output_use_bias: output_use_bias,
scale_query?: scale_query?,
use_rotary_embedding?: use_rotary_embedding?,
name: join(name, "cross_attention")
)

Expand Down Expand Up @@ -569,6 +584,7 @@ defmodule Bumblebee.Layers.Transformer do
attention_relative_bias: Layers.none(),
attention_cache: Layers.none(),
offset: Layers.none(),
position_ids: Layers.none(),
causal?: false,
scale_query?: true,
kernel_initializer: :glorot_uniform,
Expand All @@ -577,9 +593,11 @@ defmodule Bumblebee.Layers.Transformer do
query_use_bias: true,
key_use_bias: true,
value_use_bias: true,
output_use_bias: true
output_use_bias: true,
use_rotary_embedding?: false
])

position_ids = opts[:position_ids]
attention_mask = opts[:attention_mask]
attention_head_mask = opts[:attention_head_mask]
attention_cache = opts[:attention_cache]
Expand All @@ -592,6 +610,7 @@ defmodule Bumblebee.Layers.Transformer do
causal? = opts[:causal?]
scale_query? = opts[:scale_query?]
dropout_rate = opts[:dropout_rate]
use_rotary_embedding? = opts[:use_rotary_embedding?]

query_use_bias = opts[:query_use_bias]
key_use_bias = opts[:key_use_bias]
Expand Down Expand Up @@ -634,6 +653,20 @@ defmodule Bumblebee.Layers.Transformer do
)
|> Layers.split_heads(num_heads)

{query, key} =
if use_rotary_embedding? do
Layers.rotary_embedding(
query,
key,
value,
position_ids,
div(hidden_size, num_heads),
name: join(name, "rotary_embedding")
)
else
{query, key}
end

{key, value, attention_cache} =
Layers.Decoder.cached_attention_key_values(key, value, attention_cache, offset)

Expand Down
Loading