diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 2b53358093..72ad7860b6 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -5417,7 +5417,7 @@ defmodule Nx do @doc """ Returns the mean for the tensor. - If the `:axis` option is given, it aggregates over + If the `:axes` option is given, it aggregates over that dimension, effectively removing it. `axes: [0]` implies aggregating over the highest order dimension and so forth. If the axis is negative, then counts @@ -9278,21 +9278,63 @@ defmodule Nx do f32 1.6666666269302368 > + + iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [0]) + #Nx.Tensor< + f32[2] + [1.0, 1.0] + > + + iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1]) + #Nx.Tensor< + f32[2] + [0.25, 0.25] + > + + iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [0], ddof: 1) + #Nx.Tensor< + f32[2] + [2.0, 2.0] + > + + iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1], ddof: 1) + #Nx.Tensor< + f32[2] + [0.5, 0.5] + > + + ### Keeping axes + + iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1], keep_axes: true) + #Nx.Tensor< + f32[2][1] + [ + [0.25], + [0.25] + ] + > """ @doc type: :aggregation @spec variance(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t() def variance(tensor, opts \\ []) do - %T{shape: shape} = tensor = to_tensor(tensor) + %T{shape: shape, names: names} = tensor = to_tensor(tensor) + opts = keyword!(opts, [:axes, ddof: 0, keep_axes: false]) + axes = opts[:axes] + {ddof, opts} = Keyword.pop!(opts, :ddof) + + total = + if axes do + mean_den(shape, Nx.Shape.normalize_axes(shape, axes, names)) + else + size(shape) + end - opts = keyword!(opts, ddof: 0) - total = size(shape) - ddof = Keyword.fetch!(opts, :ddof) - mean = mean(tensor) + mean = mean(tensor, Keyword.put(opts, :keep_axes, true)) tensor |> subtract(mean) |> power(2) - |> sum() + |> sum(opts) |> divide(total - ddof) end @@ -9316,6 +9358,40 @@ defmodule Nx do f32 1.29099440574646 > + + iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [0]) + #Nx.Tensor< + f32[2] + [1.0, 1.0] + > + + iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [1]) + #Nx.Tensor< + f32[2] + [0.5, 0.5] + > + + iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [0], ddof: 1) + #Nx.Tensor< + f32[2] + [1.4142135381698608, 1.4142135381698608] + > + + iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [1], ddof: 1) + #Nx.Tensor< + f32[2] + [0.7071067690849304, 0.7071067690849304] + > + + ### Keeping axes + + iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), keep_axes: true) + #Nx.Tensor< + f32[1][1] + [ + [1.1180340051651] + ] + > """ @doc type: :aggregation @spec standard_deviation(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t() diff --git a/nx/test/nx_test.exs b/nx/test/nx_test.exs index 53879e2f68..caf9de05f9 100644 --- a/nx/test/nx_test.exs +++ b/nx/test/nx_test.exs @@ -1922,26 +1922,53 @@ defmodule NxTest do end describe "variance/1" do - test "should calculate the variance of a tensor" do + test "calculates variance of a tensor" do t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) assert Nx.variance(t) == Nx.tensor(2.9166667461395264) end - test "should use the optional ddof" do + test "uses optional ddof" do t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) assert Nx.variance(t, ddof: 1) == Nx.tensor(3.5) end + + test "uses optional axes" do + t = Nx.tensor([[4, 5], [2, 3], [1, 0]], names: [:x, :y]) + + assert Nx.variance(t, axes: [:x]) == + Nx.tensor([1.5555557012557983, 4.222222328186035], names: [:y]) + + t = Nx.tensor([[4, 5], [2, 3], [1, 0]], names: [:x, :y]) + assert Nx.variance(t, axes: [:y]) == Nx.tensor([0.25, 0.25, 0.25], names: [:x]) + end + + test "uses optional keep axes" do + t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) + assert Nx.variance(t, keep_axes: true) == Nx.tensor([[2.9166667461395264]]) + end end describe "standard_deviation/1" do - test "should calculate the standard deviation of a tensor" do + test "calculates the standard deviation of a tensor" do t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) assert Nx.standard_deviation(t) == Nx.tensor(1.707825127659933) end - test "should use the optional ddof" do + test "uses optional ddof" do t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) assert Nx.standard_deviation(t, ddof: 1) == Nx.tensor(1.8708287477493286) end + + test "uses optional axes" do + t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) + + assert Nx.standard_deviation(t, axes: [0]) == + Nx.tensor([1.247219204902649, 2.054804801940918]) + end + + test "uses optional keep axes" do + t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) + assert Nx.standard_deviation(t, keep_axes: true) == Nx.tensor([[1.7078251838684082]]) + end end end