improve variance and standard deviation
tiagodavi committed Feb 4, 2022
1 parent eb20bb9 commit dcbd2b4
Showing 2 changed files with 96 additions and 7 deletions.
70 changes: 63 additions & 7 deletions nx/lib/nx.ex
@@ -5349,7 +5349,7 @@ defmodule Nx do
@doc """
Returns the mean for the tensor.
If the `:axis` option is given, it aggregates over
If the `:axes` option is given, it aggregates over
that dimension, effectively removing it. `axes: [0]`
implies aggregating over the highest order dimension
and so forth. If the axis is negative, then counts
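As a point of reference for the corrected wording, here is a minimal illustration of what the `:axes` option does for `mean/2`; the tensor and the printed result are an illustrative example, not part of this commit:

    iex> Nx.mean(Nx.tensor([[1, 2], [3, 4]]), axes: [0])
    #Nx.Tensor<
      f32[2]
      [2.0, 3.0]
    >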
@@ -9187,21 +9187,53 @@ defmodule Nx do
         f32
         1.6666666269302368
       >
+
+      iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [0])
+      #Nx.Tensor<
+        f32[2]
+        [1.0, 1.0]
+      >
+
+      iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1])
+      #Nx.Tensor<
+        f32[2]
+        [0.25, 0.25]
+      >
+
+      iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [0], ddof: 1)
+      #Nx.Tensor<
+        f32[2]
+        [2.0, 2.0]
+      >
+
+      iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1], ddof: 1)
+      #Nx.Tensor<
+        f32[2]
+        [0.5, 0.5]
+      >
   """
   @doc type: :aggregation
   @spec variance(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t()
   def variance(tensor, opts \\ []) do
-    %T{shape: shape} = tensor = to_tensor(tensor)
-
-    opts = keyword!(opts, ddof: 0)
-    total = size(shape)
+    %T{shape: shape, names: names} = tensor = to_tensor(tensor)
+    opts = keyword!(opts, [:axes, ddof: 0])
+    axes = opts[:axes]
     ddof = Keyword.fetch!(opts, :ddof)
-    mean = mean(tensor)
+    opts = Keyword.delete(opts, :ddof)
+
+    total =
+      if axes do
+        mean_den(shape, Nx.Shape.normalize_axes(shape, axes, names))
+      else
+        size(shape)
+      end
+
+    mean = mean(tensor, Keyword.put(opts, :keep_axes, true))

     tensor
     |> subtract(mean)
     |> power(2)
-    |> sum()
+    |> sum(opts)
     |> divide(total - ddof)
   end
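The rewritten body follows the textbook estimator: center the tensor on its mean (computed with `keep_axes: true` so that the subtraction broadcasts), square, sum over the requested axes, and divide by the number of reduced elements minus `ddof`. Below is a minimal sketch of the same pipeline written with public `Nx` calls; the tensor, the `axes: [0]` choice, and the use of `Nx.shape/1` in place of the private `mean_den/2` helper are illustrative assumptions:

    t = Nx.tensor([[1, 2], [3, 4]])
    ddof = 1
    # number of elements collapsed when reducing over axis 0
    n = elem(Nx.shape(t), 0)

    # keep the reduced axis so the subtraction broadcasts across rows
    mean = Nx.mean(t, axes: [0], keep_axes: true)

    t
    |> Nx.subtract(mean)
    |> Nx.power(2)
    |> Nx.sum(axes: [0])
    |> Nx.divide(n - ddof)
    # => an f32[2] tensor [2.0, 2.0], matching the `axes: [0], ddof: 1` doctest above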

@@ -9225,6 +9257,30 @@
         f32
         1.29099440574646
       >
+
+      iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [0])
+      #Nx.Tensor<
+        f32[2]
+        [1.0, 1.0]
+      >
+
+      iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [1])
+      #Nx.Tensor<
+        f32[2]
+        [0.5, 0.5]
+      >
+
+      iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [0], ddof: 1)
+      #Nx.Tensor<
+        f32[2]
+        [1.4142135381698608, 1.4142135381698608]
+      >
+
+      iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [1], ddof: 1)
+      #Nx.Tensor<
+        f32[2]
+        [0.7071067690849304, 0.7071067690849304]
+      >
   """
   @doc type: :aggregation
   @spec standard_deviation(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t()
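The body of `standard_deviation/2` is collapsed in this view, but the new doctest values are consistent with it being the square root of `variance/2` under the same options. Assuming that relationship holds, the last doctest above can be reproduced directly:

    iex> Nx.sqrt(Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1], ddof: 1))
    #Nx.Tensor<
      f32[2]
      [0.7071067690849304, 0.7071067690849304]
    >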
33 changes: 33 additions & 0 deletions nx/test/nx_test.exs
@@ -1931,6 +1931,25 @@ defmodule NxTest do
       t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
       assert Nx.variance(t, ddof: 1) == Nx.tensor(3.5)
     end
+
+    test "should use the optional axes on x" do
+      t = Nx.tensor([[4, 5], [2, 3], [1, 0]], names: [:x, :y])
+
+      assert Nx.variance(t, axes: [:x]) ==
+               Nx.tensor([1.5555557012557983, 4.222222328186035], names: [:y])
+    end
+
+    test "should use the optional axes on y" do
+      t = Nx.tensor([[4, 5], [2, 3], [1, 0]], names: [:x, :y])
+      assert Nx.variance(t, axes: [:y]) == Nx.tensor([0.25, 0.25, 0.25], names: [:x])
+    end
+
+    test "should use the optional axes and ddof" do
+      t = Nx.tensor([[4, 5], [2, 3], [1, 0]], names: [:x, :y])
+
+      assert Nx.variance(t, axes: [0], ddof: 1) ==
+               Nx.tensor([2.3333334922790527, 6.333333492279053], names: [:y])
+    end
   end

   describe "standard_deviation/1" do
@@ -1943,5 +1962,19 @@
       t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
       assert Nx.standard_deviation(t, ddof: 1) == Nx.tensor(1.8708287477493286)
     end
+
+    test "should use the optional axes" do
+      t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
+
+      assert Nx.standard_deviation(t, axes: [0]) ==
+               Nx.tensor([1.247219204902649, 2.054804801940918])
+    end
+
+    test "should use the optional axes and ddof" do
+      t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
+
+      assert Nx.standard_deviation(t, axes: [1], ddof: 1) ==
+               Nx.tensor([0.7071067690849304, 0.7071067690849304, 0.7071067690849304])
+    end
   end
 end
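The standard-deviation expectations are the square roots of the corresponding variance values: sqrt(14/9) ≈ 1.2472191 and sqrt(38/9) ≈ 2.0548047 for `axes: [0]`. Assuming `standard_deviation/2` delegates to `variance/2`, which the matching f32 rounding suggests, the asserted tensor can be reproduced as:

    iex> Nx.sqrt(Nx.variance(Nx.tensor([[4, 5], [2, 3], [1, 0]]), axes: [0]))
    #Nx.Tensor<
      f32[2]
      [1.247219204902649, 2.054804801940918]
    >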
