Skip to content

Commit

Permalink
Fix encoding of bf16
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Sep 27, 2024
1 parent d31c33e commit 08e3330
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
7 changes: 6 additions & 1 deletion exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -921,13 +921,14 @@ defmodule EXLA.MLIR.Value do
end
end

defp float_hex(value, {_, size} = type) do
defp float_hex(value, {mod, size} = type) do
data =
case value do
:nan -> type |> Nx.Type.nan_binary() |> native_to_big()
:infinity -> type |> Nx.Type.infinity_binary() |> native_to_big()
:neg_infinity -> type |> Nx.Type.neg_infinity_binary() |> native_to_big()
value when size == 8 -> f8E5M2_to_big(value)
value when mod == :bf and size == 16 -> bf16_to_big(value)
value -> <<value::float-size(size)-big>>
end

Expand All @@ -938,6 +939,10 @@ defmodule EXLA.MLIR.Value do
binary_part(<<x::float-big-16>>, 0, 1)
end

defp bf16_to_big(x) do
binary_part(<<x::float-big-32>>, 0, 2)
end

defp native_to_big(binary) do
size = byte_size(binary) * 8
<<value::size(size)-native>> = binary
Expand Down
22 changes: 13 additions & 9 deletions exla/test/exla/defn/expr_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,23 @@ defmodule EXLA.Defn.ExprTest do
end
end

describe "float8" do
defn return_float8, do: Nx.tensor(1, type: {:f, 8})
describe "types" do
defn return_f8, do: Nx.tensor(1, type: {:f, 8})

test "supports float8 return types" do
assert_equal(return_float8(), Nx.tensor(1, type: {:f, 8}))
test "f8" do
assert_equal(return_f8(), Nx.tensor(1, type: {:f, 8}))
end

defn return_f16, do: Nx.tensor(1, type: {:f, 16})

test "f16" do
assert_equal(return_f16(), Nx.tensor(1, type: {:f, 16}))
end
end

describe "float16" do
defn return_float, do: Nx.tensor(1, type: {:f, 16})
defn return_bf16, do: Nx.tensor(1, type: {:bf, 16})

test "supports float16 return types" do
assert_equal(return_float(), Nx.tensor(1, type: {:f, 16}))
test "bf16" do
assert_equal(return_bf16(), Nx.tensor(1, type: {:bf, 16}))
end
end

Expand Down

0 comments on commit 08e3330

Please sign in to comment.