Support more rotary embedding options for Llama (#285)
jonatanklosko authored Nov 17, 2023
1 parent 5178a6d commit 5773ccf
Showing 3 changed files with 111 additions and 13 deletions.
67 changes: 58 additions & 9 deletions lib/bumblebee/layers.ex
@@ -973,31 +973,80 @@ defmodule Bumblebee.Layers do
   @doc """
   Adds a rotary embedding layer to the network.
   """
-  def rotary_embedding(query, key, position_ids, size, opts \\ []) do
-    opts = Keyword.validate!(opts, [:name, max_positions: 2048, base: 10_000])
+  def rotary_embedding(query, key, position_ids, attention_mask, size, opts \\ []) do
+    opts = Keyword.validate!(opts, [:name, :scaling_strategy, max_positions: 2048, base: 10_000])
 
     output =
-      Axon.layer(&apply_rotary_embedding/4, [query, key, position_ids], [size: size] ++ opts)
+      Axon.layer(
+        &apply_rotary_embedding/5,
+        [query, key, position_ids, Axon.optional(attention_mask)],
+        [size: size] ++ opts
+      )
 
     unwrap_tuple(output, 2)
   end
 
-  deftransformp create_sinusoidal_positions(max_positions, size, base) do
+  deftransformp create_sinusoidal_positions(
+                  sequence_length,
+                  max_positions,
+                  size,
+                  base,
+                  scaling_strategy
+                ) do
+    position = Nx.iota({sequence_length})
+
+    {base, position} =
+      case scaling_strategy do
+        %{type: :linear, factor: factor} ->
+          {base, Nx.divide(position, factor)}
+
+        %{type: :dynamic, factor: factor} when sequence_length > max_positions ->
+          base =
+            base
+            |> Nx.multiply(factor * sequence_length / max_positions - (factor - 1))
+            |> Nx.pow(size / (size - 2))
+
+          {base, position}
+
+        _other ->
+          {base, position}
+      end
+
     range = Nx.iota({div(size, 2)}) |> Nx.multiply(2) |> Nx.divide(size)
     inv_frequency = Nx.divide(1.0, Nx.pow(base, range))
 
-    position = Nx.iota({max_positions})
     angle = Nx.outer(position, inv_frequency)
 
     angle = Nx.concatenate([angle, angle], axis: -1)
 
     {Nx.cos(angle), Nx.sin(angle)}
   end
 
-  defnp apply_rotary_embedding(query, key, position_ids, opts \\ []) do
-    opts = keyword!(opts, [:size, mode: :inference, max_positions: 2048, base: 10_000])
+  defnp apply_rotary_embedding(query, key, position_ids, attention_mask, opts \\ []) do
+    opts =
+      keyword!(opts, [
+        :size,
+        :scaling_strategy,
+        mode: :inference,
+        max_positions: 2048,
+        base: 10_000
+      ])
 
-    {cos, sin} = create_sinusoidal_positions(opts[:max_positions], opts[:size], opts[:base])
+    # When decoding with cache position_ids may be a partial sequence,
+    # but in that case we always have full-length attention mask
+    sequence_length =
+      case attention_mask do
+        %Axon.None{} -> Nx.axis_size(position_ids, 1)
+        _other -> Nx.axis_size(attention_mask, 1)
+      end
+
+    {cos, sin} =
+      create_sinusoidal_positions(
+        sequence_length,
+        opts[:max_positions],
+        opts[:size],
+        opts[:base],
+        opts[:scaling_strategy]
+      )
 
     position_ids = Nx.as_type(position_ids, :s64)
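
For intuition, here is a small standalone sketch (not part of the commit) of what the two scaling strategies do to the sinusoidal tables: `:linear` divides the position indices by `factor`, while `:dynamic` (NTK-style) enlarges the frequency base once the sequence length exceeds `max_positions`. It mirrors the `create_sinusoidal_positions/5` logic above outside of `defn`; the `Mix.install` setup, Nx version, and module name are illustrative.

Mix.install([{:nx, "~> 0.6"}])

defmodule RopeScalingSketch do
  # Mirrors create_sinusoidal_positions/5 above, as a plain function so the
  # effect of each strategy on the cos/sin tables can be inspected directly.
  def tables(sequence_length, max_positions, size, base, scaling_strategy) do
    position = Nx.iota({sequence_length})

    {base, position} =
      case scaling_strategy do
        %{type: :linear, factor: factor} ->
          # Positions are compressed back into the trained range.
          {base, Nx.divide(position, factor)}

        %{type: :dynamic, factor: factor} when sequence_length > max_positions ->
          # The base grows with the sequence length, stretching the low frequencies.
          base =
            base
            |> Kernel.*(factor * sequence_length / max_positions - (factor - 1))
            |> :math.pow(size / (size - 2))

          {base, position}

        _other ->
          {base, position}
      end

    range = Nx.iota({div(size, 2)}) |> Nx.multiply(2) |> Nx.divide(size)
    inv_frequency = Nx.divide(1.0, Nx.pow(base, range))
    angle = Nx.outer(position, inv_frequency)
    angle = Nx.concatenate([angle, angle], axis: -1)
    {Nx.cos(angle), Nx.sin(angle)}
  end
end

# No scaling at the trained length vs. dynamic scaling at 4x the trained length:
{_, sin_plain} = RopeScalingSketch.tables(2048, 2048, 64, 10_000, nil)
{_, sin_dynamic} = RopeScalingSketch.tables(8192, 2048, 64, 10_000, %{type: :dynamic, factor: 2.0})
IO.inspect({Nx.shape(sin_plain), Nx.shape(sin_dynamic)})
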
19 changes: 16 additions & 3 deletions lib/bumblebee/layers/transformer.ex
@@ -798,7 +798,13 @@ defmodule Bumblebee.Layers.Transformer do
     validate_required_keys!(opts, [:position_ids])
 
     opts =
-      Keyword.validate!(opts, [:position_ids, :max_positions, base: 10_000, percentage: 1.0])
+      Keyword.validate!(opts, [
+        :position_ids,
+        :max_positions,
+        :scaling_strategy,
+        base: 10_000,
+        percentage: 1.0
+      ])
 
     {position_ids, opts} = Keyword.pop(opts, :position_ids)
     {percentage, opts} = Keyword.pop(opts, :percentage)
@@ -808,7 +814,7 @@
     rotary_opts = [name: join(name, "rotary_embedding")] ++ opts
 
     if size == attention_head_size do
-      Layers.rotary_embedding(query, key, position_ids, size, rotary_opts)
+      Layers.rotary_embedding(query, key, position_ids, attention_mask, size, rotary_opts)
     else
       query_rotary = Axon.nx(query, & &1[[.., .., .., 0..(size - 1)//1]])
       query_pass = Axon.nx(query, & &1[[.., .., .., size..-1//1]])
@@ -817,7 +823,14 @@
       key_pass = Axon.nx(key, & &1[[.., .., .., size..-1//1]])
 
       {query_rotary, key_rotary} =
-        Layers.rotary_embedding(query_rotary, key_rotary, position_ids, size, rotary_opts)
+        Layers.rotary_embedding(
+          query_rotary,
+          key_rotary,
+          position_ids,
+          attention_mask,
+          size,
+          rotary_opts
+        )
 
       {Axon.concatenate([query_rotary, query_pass], axis: -1),
        Axon.concatenate([key_rotary, key_pass], axis: -1)}
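
The partial-rotary branch above (taken when `percentage < 1.0`) is unchanged apart from threading `attention_mask` through. For clarity, a toy sketch of the slicing it performs on the head dimension (sizes are made up; assumes Nx is available):

Mix.install([{:nx, "~> 0.6"}])

# Head size 4 with rotary applied to the first 2 features only.
query = Nx.iota({1, 1, 1, 4}, type: :f32)
size = 2

query_rotary = query[[.., .., .., 0..(size - 1)//1]]
query_pass = query[[.., .., .., size..-1//1]]

# Rotary embedding would be applied to query_rotary here, then the two
# halves are concatenated back along the last axis.
IO.inspect(Nx.concatenate([query_rotary, query_pass], axis: -1))
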
38 changes: 37 additions & 1 deletion lib/bumblebee/text/llama.ex
@@ -42,6 +42,22 @@ defmodule Bumblebee.Text.Llama do
         default: :silu,
         doc: "the activation function"
       ],
+      rotary_embedding_base: [
+        default: 10_000,
+        doc: "base for computing rotary embedding frequency"
+      ],
+      rotary_embedding_scaling_strategy: [
+        default: nil,
+        doc: """
+        scaling configuration for rotary embedding. Currently the supported values are:
+
+          * `%{type: :linear, factor: number()}`
+
+          * `%{type: :dynamic, factor: number()}`
+
+        For more details see https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases
+        """
+      ],
       layer_norm_epsilon: [
         default: 1.0e-12,
         doc: "the epsilon used by RMS normalization layers"
@@ -317,7 +333,12 @@ defmodule Bumblebee.Text.Llama do
         ),
       block_type: :norm_first,
       causal?: true,
-      rotary_embedding: [position_ids: position_ids, max_positions: spec.max_positions],
+      rotary_embedding: [
+        position_ids: position_ids,
+        max_positions: spec.max_positions,
+        base: spec.rotary_embedding_base,
+        scaling_strategy: spec.rotary_embedding_scaling_strategy
+      ],
       query_use_bias: false,
       key_use_bias: false,
       value_use_bias: false,
@@ -363,6 +384,19 @@ defmodule Bumblebee.Text.Llama do
     def load(spec, data) do
       import Shared.Converters
 
+      scaling_strategy_converter = fn name, value ->
+        case value do
+          %{"type" => "linear", "factor" => factor} when is_number(factor) ->
+            {:ok, %{type: :linear, factor: factor}}
+
+          %{"type" => "dynamic", "factor" => factor} when is_number(factor) ->
+            {:ok, %{type: :dynamic, factor: factor}}
+
+          _other ->
+            {:error, "invalid format for #{inspect(name)}, got: #{inspect(value)}"}
+        end
+      end
+
       opts =
         convert!(data,
           vocab_size: {"vocab_size", number()},
@@ -373,6 +407,8 @@
           num_key_value_heads: {"num_key_value_heads", number()},
           intermediate_size: {"intermediate_size", number()},
           activation: {"hidden_act", atom()},
+          rotary_embedding_base: {"rope_theta", number()},
+          rotary_embedding_scaling_strategy: {"rope_scaling", scaling_strategy_converter},
           initializer_scale: {"initializer_range", number()},
           layer_norm_epsilon: {"rms_norm_eps", number()}
         ) ++ Shared.common_options_from_transformers(data, spec)
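
The converter above maps the `rope_scaling` entry of a Hugging Face `config.json` onto the new spec option. A quick standalone illustration of its behaviour (the converter is re-declared here purely for the example; input values are made up):

converter = fn name, value ->
  case value do
    %{"type" => "linear", "factor" => factor} when is_number(factor) ->
      {:ok, %{type: :linear, factor: factor}}

    %{"type" => "dynamic", "factor" => factor} when is_number(factor) ->
      {:ok, %{type: :dynamic, factor: factor}}

    _other ->
      {:error, "invalid format for #{inspect(name)}, got: #{inspect(value)}"}
  end
end

# {:ok, %{type: :linear, factor: 2.0}}
IO.inspect(converter.("rope_scaling", %{"type" => "linear", "factor" => 2.0}))

# {:error, "invalid format for \"rope_scaling\", got: %{\"type\" => \"yarn\"}"}
IO.inspect(converter.("rope_scaling", %{"type" => "yarn"}))
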
