Skip to content

Commit

Permalink
Fix M2M100 with batched input
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Nov 19, 2024
1 parent 45e7408 commit 29bdfa4
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion lib/bumblebee/text/m2m100.ex
Original file line number Diff line number Diff line change
Expand Up @@ -437,14 +437,18 @@ defmodule Bumblebee.Text.M2m100 do
end

defnp sinusoidal_position_embedding_impl(position_ids, opts \\ []) do
position_ids = Nx.vectorize(position_ids, :batch)

size = opts[:size]

half_size = div(size, 2)
base = 10_000
range = Nx.iota({half_size}) / (half_size - 1)
inv_frequency = 1 / Nx.pow(base, range)
angle = Nx.outer(position_ids, inv_frequency)
Nx.concatenate([Nx.sin(angle), Nx.cos(angle)], axis: -1)
sin_cos = Nx.concatenate([Nx.sin(angle), Nx.cos(angle)], axis: -1)

Nx.devectorize(sin_cos)
end

defp decoder(
Expand Down

0 comments on commit 29bdfa4

Please sign in to comment.