diff --git a/test/candlex_test.exs b/test/candlex_test.exs index 8e5b56a..c130e59 100644 --- a/test/candlex_test.exs +++ b/test/candlex_test.exs @@ -1807,6 +1807,43 @@ defmodule CandlexTest do ) end + test "slice" do + t([1, 2, 3, 4, 5, 6]) + |> Nx.slice([0], [3]) + |> assert_equal(t([1, 2, 3])) + + # t([1, 2, 3, 4, 5, 6]) + # |> Nx.slice([0], [6], strides: [2]) + # |> assert_equal(t([1, 3, 5])) + + # t([[1, 2], [3, 4], [5, 6]]) + # |> Nx.slice([0, 0], [3, 2], strides: [2, 1]) + # |> assert_equal(t( + # [ + # [1, 2], + # [5, 6] + # ] + # )) + end + + test "squeeze" do + t([[[[[1]]]]]) + |> Nx.squeeze() + |> assert_equal(t(1)) + + t([[[[1]]], [[[2]]]]) + |> Nx.squeeze() + |> assert_equal(t([1, 2])) + + t([[1, 2, 3]], names: [:x, :y]) + |> Nx.squeeze(axes: [:x]) + |> assert_equal(t([1, 2, 3])) + + t([[1], [2]], names: [:x, :y]) + |> Nx.squeeze(axes: [:y]) + |> assert_equal(t([1, 2])) + end + test "put_slice" do t([0, 1, 2, 3, 4]) |> Nx.put_slice([2], Nx.tensor([5, 6]))