Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed Oct 20, 2023
1 parent 832beb6 commit 143f0d6
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions lib/bumblebee/layers/transformer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ defmodule Bumblebee.Layers.Transformer do

name = opts[:name]
num_heads = opts[:num_heads]
num_key_value_heads = opts[:num_key_value_heads]
num_key_value_heads = opts[:num_key_value_heads] || num_heads
hidden_size = opts[:hidden_size]
kernel_initializer = opts[:kernel_initializer]
causal? = opts[:causal?]
Expand Down Expand Up @@ -839,10 +839,9 @@ defmodule Bumblebee.Layers.Transformer do
{query, key}
end

{key, value} =
num_key_value_groups = div(num_heads, num_key_value_heads)
key = repeat_states(key, num_key_value_groups)
value = repeat_states(value, num_key_value_groups)
num_key_value_groups = div(num_heads, num_key_value_heads)
key = repeat_states(key, num_key_value_groups)
value = repeat_states(value, num_key_value_groups)

{key, value, attention_cache} =
Layers.Decoder.cached_attention_key_values(key, value, attention_cache, offset)
Expand Down

0 comments on commit 143f0d6

Please sign in to comment.