From eff2e740e2c71022e82809f5fbd00ae6b5b1125c Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 17 Nov 2023 13:35:38 -0300 Subject: [PATCH 1/6] test: Nx.Random --- test/random_test.exs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 test/random_test.exs diff --git a/test/random_test.exs b/test/random_test.exs new file mode 100644 index 0000000..569f48a --- /dev/null +++ b/test/random_test.exs @@ -0,0 +1,32 @@ +defmodule Candlex.RandomTest do + use Nx.Case, async: true + + test "key/1" do + Nx.Random.key(42) + |> assert_equal(Nx.tensor([0, 42])) + end + + test "uniform/1" do + {normal, new_key} = + Nx.Random.key(42) + |> Nx.Random.uniform() + + normal + |> assert_close(Nx.tensor(0.9145736694335938)) + + new_key + |> assert_equal(Nx.tensor([2_465_931_498, 3_679_230_171])) + end + + test "normal/1" do + {normal, new_key} = + Nx.Random.key(42) + |> Nx.Random.normal() + + normal + |> assert_close(Nx.tensor(1.3694695234298706)) + + new_key + |> assert_equal(Nx.tensor([2_465_931_498, 3_679_230_171])) + end +end From 243db6418b548bb518b5f871f4faa0c85c3a3754 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 17 Nov 2023 13:48:57 -0300 Subject: [PATCH 2/6] cargo --release --- lib/candlex/native.ex | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/candlex/native.ex b/lib/candlex/native.ex index d6fddd7..8f26c81 100644 --- a/lib/candlex/native.ex +++ b/lib/candlex/native.ex @@ -4,14 +4,14 @@ defmodule Candlex.Native do mix_config = Mix.Project.config() version = mix_config[:version] source_url = mix_config[:package][:links]["GitHub"] - mode = if Mix.env() in [:dev, :test], do: :debug, else: :release + # mode = if Mix.env() in [:dev, :test], do: :debug, else: :release use RustlerPrecompiled, otp_app: :candlex, features: if(Application.compile_env(:candlex, :use_cuda), do: [:cuda], else: []), base_url: "#{source_url}/releases/download/v#{version}", force_build: System.get_env("CANDLEX_NIF_BUILD") in ["1", "true"], - mode: mode, + # mode: mode, version: version, nif_versions: ["2.16"], targets: [ From c13ab7ec9c0b2cdcf7420b430f3bcb617af737cf Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 17 Nov 2023 13:57:24 -0300 Subject: [PATCH 3/6] test: defn_while --- test/defn_test.exs | 86 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/test/defn_test.exs b/test/defn_test.exs index 58e6714..c4abbe4 100644 --- a/test/defn_test.exs +++ b/test/defn_test.exs @@ -21,4 +21,90 @@ defmodule Candlex.DefnTest do |> TG.tanh_grad() |> assert_close(Nx.tensor(0.41997432708740234)) end + + describe "while/3" do + defmodule Mod do + import Nx.Defn + + defn upto10(x) do + while x, Nx.less(x, 10) do + x + 1 + end + end + + defn factorial_tuple(x) do + factorial = Nx.tensor(1, type: Nx.type(x)) + + {factorial, _} = + while {factorial, x}, Nx.greater(x, 1) do + {factorial * x, x - 1} + end + + factorial + end + + defn factorial_map(x) do + factorial = Nx.tensor(1, type: Nx.type(x)) + + %{factorial: factorial} = + while map = %{factorial: factorial, x: x}, Nx.greater(map.x, 1) do + %{map | factorial: map.factorial * map.x, x: map.x - 1} + end + + factorial + end + + defn factorial_map_input(map) do + %{factorial: factorial} = + while map, Nx.greater(map.x, 1) do + %{map | factorial: map.factorial * map.x, x: map.x - 1} + end + + factorial + end + + defn tensor_generator_sum() do + while x = 0, r <- Nx.tensor([0, 1, 2]) do + x + r + end + end + end + + test "simple" do + Mod.upto10(0) + |> assert_equal(Nx.tensor(10)) + + Mod.upto10(5) + |> assert_equal(Nx.tensor(10)) + end + + test "factorial tuple" do + Mod.factorial_tuple(5) + |> assert_equal(Nx.tensor(120)) + + Mod.factorial_tuple(10.0) + |> assert_equal(Nx.tensor(3_628_800.0)) + end + + test "factorial map" do + Mod.factorial_map(5) + |> assert_equal(Nx.tensor(120)) + + Mod.factorial_map(10.0) + |> assert_equal(Nx.tensor(3_628_800.0)) + end + + test "factorial map input" do + Mod.factorial_map_input(%{factorial: 1, x: 5}) + |> assert_equal(Nx.tensor(120)) + + Mod.factorial_map_input(%{factorial: 1.0, x: 10.0}) + |> assert_equal(Nx.tensor(3_628_800.0)) + end + + test "tensor generator sum" do + Mod.tensor_generator_sum() + |> assert_equal(Nx.tensor(3)) + end + end end From 9677568e4e23cf06879891ade49251ee512826dc Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 17 Nov 2023 14:03:37 -0300 Subject: [PATCH 4/6] test: Nx.slice/3 with tensors as start_indices --- test/candlex_test.exs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/candlex_test.exs b/test/candlex_test.exs index c130e59..fc48c66 100644 --- a/test/candlex_test.exs +++ b/test/candlex_test.exs @@ -1824,6 +1824,18 @@ defmodule CandlexTest do # [5, 6] # ] # )) + + t([0, 1]) + |> Nx.slice([t(0)], [1]) + |> assert_equal(t([0])) + + t([0, 1]) + |> Nx.slice([t(1)], [1]) + |> assert_equal(t([1])) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.slice([t(0), t(1)], [1, 1]) + |> assert_equal(t([[2]])) end test "squeeze" do From 03e249e4b6257ec8867622e91d8b6d123e374fac Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 17 Nov 2023 14:07:50 -0300 Subject: [PATCH 5/6] better comment --- lib/candlex/native.ex | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/candlex/native.ex b/lib/candlex/native.ex index 8f26c81..3b01f1c 100644 --- a/lib/candlex/native.ex +++ b/lib/candlex/native.ex @@ -4,14 +4,17 @@ defmodule Candlex.Native do mix_config = Mix.Project.config() version = mix_config[:version] source_url = mix_config[:package][:links]["GitHub"] + # We can't run on :debug mode until we find a workaround to + # ignore integer overflows when running Nx.Random Threefry PRNG. # mode = if Mix.env() in [:dev, :test], do: :debug, else: :release + mode = :release use RustlerPrecompiled, otp_app: :candlex, features: if(Application.compile_env(:candlex, :use_cuda), do: [:cuda], else: []), base_url: "#{source_url}/releases/download/v#{version}", force_build: System.get_env("CANDLEX_NIF_BUILD") in ["1", "true"], - # mode: mode, + mode: mode, version: version, nif_versions: ["2.16"], targets: [ From 36349fd89c3c2ae7623dfe6b48d432ff4726b47a Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 17 Nov 2023 14:13:02 -0300 Subject: [PATCH 6/6] fix: fixes bug in Nx.slice when receiving Nx.Tensor as start_indices --- lib/candlex/backend.ex | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/candlex/backend.ex b/lib/candlex/backend.ex index d783e0b..f7d1c31 100644 --- a/lib/candlex/backend.ex +++ b/lib/candlex/backend.ex @@ -728,7 +728,11 @@ defmodule Candlex.Backend do defp narrow(t, [start | starts], [length | lengths], axis, shape) do dim = elem(shape, axis) - start = min(start, dim - length) + + start = + start + |> Nx.to_number() + |> min(dim - length) if start == 0 and length == dim do # Nothing to narrow at this step