Skip to content

Commit

Permalink
improve based on guidelines
Browse files Browse the repository at this point in the history
  • Loading branch information
tiagodavi committed Feb 7, 2022
1 parent 17298d5 commit 06f0a45
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
26 changes: 23 additions & 3 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -9302,15 +9302,25 @@ defmodule Nx do
f32[2]
[0.5, 0.5]
>
### Keeping axes
iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1], keep_axes: true)
#Nx.Tensor<
f32[2][1]
[
[0.25],
[0.25]
]
>
"""
@doc type: :aggregation
@spec variance(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t()
def variance(tensor, opts \\ []) do
%T{shape: shape, names: names} = tensor = to_tensor(tensor)
opts = keyword!(opts, [:axes, ddof: 0])
opts = keyword!(opts, [:axes, ddof: 0, keep_axes: false])
axes = opts[:axes]
ddof = Keyword.fetch!(opts, :ddof)
opts = Keyword.delete(opts, :ddof)
{ddof, opts} = Keyword.pop!(opts, :ddof)

total =
if axes do
Expand Down Expand Up @@ -9372,6 +9382,16 @@ defmodule Nx do
f32[2]
[0.7071067690849304, 0.7071067690849304]
>
### Keeping axes
iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), keep_axes: true)
#Nx.Tensor<
f32[1][1]
[
[1.1180340051651]
]
>
"""
@doc type: :aggregation
@spec standard_deviation(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t()
Expand Down
10 changes: 10 additions & 0 deletions nx/test/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1950,6 +1950,11 @@ defmodule NxTest do
assert Nx.variance(t, axes: [0], ddof: 1) ==
Nx.tensor([2.3333334922790527, 6.333333492279053], names: [:y])
end

test "should keep axes" do
t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
assert Nx.variance(t, keep_axes: true) == Nx.tensor([[2.9166667461395264]])
end
end

describe "standard_deviation/1" do
Expand All @@ -1976,5 +1981,10 @@ defmodule NxTest do
assert Nx.standard_deviation(t, axes: [1], ddof: 1) ==
Nx.tensor([0.7071067690849304, 0.7071067690849304, 0.7071067690849304])
end

test "should keep axes" do
t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
assert Nx.standard_deviation(t, keep_axes: true) == Nx.tensor([[1.7078251838684082]])
end
end
end

0 comments on commit 06f0a45

Please sign in to comment.