Skip to content

Commit

Permalink
build(deps): bump nx (#23)
Browse files Browse the repository at this point in the history
* build(deps): bump nx

* bump nx

* commented gather tests with axes
  • Loading branch information
grzuy authored Nov 13, 2023
1 parent 3550e1e commit 08ef7ca
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 10 deletions.
16 changes: 10 additions & 6 deletions lib/candlex/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ defmodule Candlex.Backend do
# Indexed

@impl true
def gather(%T{} = out, %T{shape: {_}} = tensor, %T{} = indices) do
def gather(%T{} = out, %T{shape: {_}} = tensor, %T{} = indices, _opts) do
tensor
|> from_nx()
|> Native.gather(from_nx(Nx.flatten(indices)), 0)
Expand All @@ -391,7 +391,7 @@ defmodule Candlex.Backend do
end

@impl true
def indexed_add(%T{} = out, %T{shape: {_}} = tensor, %T{} = indices, %T{} = updates) do
def indexed_add(%T{} = out, %T{shape: {_}} = tensor, %T{} = indices, %T{} = updates, _opts) do
{tensor, updates} = maybe_upcast(tensor, updates)

tensor
Expand Down Expand Up @@ -887,7 +887,6 @@ defmodule Candlex.Backend do
end

for op <- [
:indexed_put,
:map,
:triangular_solve,
:window_max,
Expand All @@ -901,9 +900,14 @@ defmodule Candlex.Backend do
end
end

@impl true
def reduce(_out, _tensor, _, _, _) do
raise "unsupported Candlex.Backend.reduce function"
for op <- [
:indexed_put,
:reduce
] do
@impl true
def unquote(op)(_out, _tensor, _, _, _) do
raise "unsupported Candlex.Backend.#{unquote(op)} function"
end
end

for op <- [
Expand Down
8 changes: 4 additions & 4 deletions mix.lock
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
%{
"castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"},
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"earmark_parser": {:hex, :earmark_parser, "1.4.37", "2ad73550e27c8946648b06905a57e4d454e4d7229c2dafa72a0348c99d8be5f7", [:mix], [], "hexpm", "6b19783f2802f039806f375610faa22da130b8edc21209d0bff47918bb48360e"},
"earmark_parser": {:hex, :earmark_parser, "1.4.38", "b42252eddf63bda05554ba8be93a1262dc0920c721f1aaf989f5de0f73a2e367", [:mix], [], "hexpm", "2cd0907795aaef0c7e8442e376633c5b3bd6edc8dbbdc539b22f095501c1cdb6"},
"ex_doc": {:hex, :ex_doc, "0.30.9", "d691453495c47434c0f2052b08dd91cc32bc4e1a218f86884563448ee2502dd2", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "d7aaaf21e95dc5cddabf89063327e96867d00013963eadf2c6ad135506a8bc10"},
"jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"},
"makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"},
"makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"},
"makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"},
"nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"},
"nx": {:git, "https://github.com/elixir-nx/nx", "27e7b5658b6d88ca5e9106ef0f09ad173bb0f154", [sparse: "nx"]},
"nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"},
"nx": {:git, "https://github.com/elixir-nx/nx", "e1b776ed2a49498cbf2465862b2fba5a0df6f43b", [sparse: "nx"]},
"rustler": {:hex, :rustler, "0.30.0", "cefc49922132b072853fa9b0ca4dc2ffcb452f68fb73b779042b02d545e097fb", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "9ef1abb6a7dda35c47cfc649e6a5a61663af6cf842a55814a554a84607dee389"},
"rustler_precompiled": {:hex, :rustler_precompiled, "0.7.0", "5d0834fc06dbc76dd1034482f17b1797df0dba9b491cef8bb045fcaca94bcade", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "fdf43a6835f4e4de5bfbc4c019bfb8c46d124bd4635fefa3e20d9a2bbbec1512"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
Expand Down
32 changes: 32 additions & 0 deletions test/candlex_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1639,6 +1639,38 @@ defmodule CandlexTest do
# t([[[1, 2], [11, 12]], [[101, 102], [111, 112]]])
# |> Nx.gather(t([[0, 0, 0], [0, 1, 1], [1, 1, 1]]))
# |> assert_equal(t([1, 12, 112]))

# t([[1, 2, 3], [4, 5, 6]])
# |> Nx.gather(t([[1], [0], [2], [1]]), axes: [1])
# |> assert_equal(t(
# [
# [2, 5],
# [1, 4],
# [3, 6],
# [2, 5]
# ]
# ))

# Nx.iota({2, 1, 3})
# |> Nx.gather(t([[[1], [0], [2]]]), axes: [2])
# |> assert_equal(t(
# [
# [
# [
# [1],
# [4]
# ],
# [
# [0],
# [3]
# ],
# [
# [2],
# [5]
# ]
# ]
# ]
# ))
end

test "indexed_add" do
Expand Down

0 comments on commit 08ef7ca

Please sign in to comment.