Skip to content

Commit

Permalink
Add Nx.covariance (#1289)
Browse files Browse the repository at this point in the history
  • Loading branch information
krstopro authored and josevalim committed Oct 24, 2023
1 parent 423308f commit 10fa593
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 6 deletions.
103 changes: 103 additions & 0 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15703,6 +15703,109 @@ defmodule Nx do
sqrt(variance(tensor, opts))
end

@doc """
A shortcut to `covariance/3` with either `opts` or `mean` as second argument.

The second argument is disambiguated by its type:

  * a list is treated as `opts`, and the mean is computed with
    `Nx.mean/2` over the second-to-last axis of `tensor`;

  * anything else is treated as the `mean` and is forwarded to
    `covariance/3` with empty options.

When the second argument is omitted it defaults to `[]`.
"""
@doc type: :aggregation
@spec covariance(tensor :: Nx.Tensor.t()) :: Nx.Tensor.t()
def covariance(tensor, opts \\ [])

# A list in second position is always interpreted as options, never as a mean.
@spec covariance(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t()
def covariance(tensor, opts) when is_list(opts),
  do: covariance(tensor, Nx.mean(tensor, axes: [-2]), opts)

@spec covariance(tensor :: Nx.Tensor.t(), mean :: Nx.Tensor.t()) :: Nx.Tensor.t()
def covariance(tensor, mean), do: covariance(tensor, mean, [])

@doc """
Computes the covariance matrix of the input tensor.

The covariance of two random variables X and Y is calculated as $Cov(X, Y) = \\frac{1}{N}\\sum_{i=0}^{N-1}{(X_i - \\overline{X}) * (Y_i - \\overline{Y})}$.

The tensor must be at least of rank 2, with shape `{n, d}`. Any additional
dimension will be treated as batch dimensions.

The column mean can be provided as the second argument and it must be
a tensor of shape `{..., d}`, where the batch shape is broadcastable with
that of the input tensor. If not provided, the mean is estimated using `Nx.mean/2`.

If the `:ddof` (delta degrees of freedom) option is given, the divisor `n - ddof`
is used for the sum of the products.

## Options

  * `:ddof` - delta degrees of freedom, a non-negative integer.
    Defaults to `0`.

## Error cases

An `ArgumentError` is raised when the input tensor has rank lower than 2,
when the mean is a scalar (rank 0), or when `:ddof` is not a non-negative
integer.

## Examples

    iex> Nx.covariance(Nx.tensor([[1, 2], [3, 4], [5, 6]]))
    #Nx.Tensor<
      f32[2][2]
      [
        [2.6666667461395264, 2.6666667461395264],
        [2.6666667461395264, 2.6666667461395264]
      ]
    >

    iex> Nx.covariance(Nx.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]))
    #Nx.Tensor<
      f32[2][2][2]
      [
        [
          [2.6666667461395264, 2.6666667461395264],
          [2.6666667461395264, 2.6666667461395264]
        ],
        [
          [2.6666667461395264, 2.6666667461395264],
          [2.6666667461395264, 2.6666667461395264]
        ]
      ]
    >

    iex> Nx.covariance(Nx.tensor([[1, 2], [3, 4], [5, 6]]), ddof: 1)
    #Nx.Tensor<
      f32[2][2]
      [
        [4.0, 4.0],
        [4.0, 4.0]
      ]
    >

    iex> Nx.covariance(Nx.tensor([[1, 2], [3, 4], [5, 6]]), Nx.tensor([4, 3]))
    #Nx.Tensor<
      f32[2][2]
      [
        [3.6666667461395264, 1.6666666269302368],
        [1.6666666269302368, 3.6666667461395264]
      ]
    >
"""
@doc type: :aggregation
@spec covariance(tensor :: Nx.Tensor.t(), mean :: Nx.Tensor.t(), opts :: Keyword.t()) ::
        Nx.Tensor.t()
def covariance(tensor, mean, opts) do
  tensor = to_tensor(tensor)
  mean = to_tensor(mean)
  opts = keyword!(opts, ddof: 0)
  tensor_rank = Nx.rank(tensor)

  if tensor_rank < 2 do
    raise ArgumentError, "expected input tensor of rank at least 2, got #{tensor_rank}"
  end

  if Nx.rank(mean) == 0 do
    raise ArgumentError, "expected mean of rank at least 1, got 0"
  end

  ddof = opts[:ddof]

  if not is_integer(ddof) or ddof < 0 do
    # inspect/1 keeps the message well-formed for any term: plain interpolation
    # renders nil as an empty string and crashes with Protocol.UndefinedError
    # for terms that do not implement String.Chars (tuples, maps, ...).
    raise ArgumentError, "expected ddof to be a non-negative integer, got #{inspect(ddof)}"
  end

  # Center the samples: the mean gets a singleton sample axis so that it
  # broadcasts against the `n` axis (second-to-last) of the input.
  tensor = tensor |> subtract(new_axis(mean, -2)) |> rename(nil)

  # For complex inputs one side of the product is conjugated.
  conj = if Nx.Type.complex?(Nx.type(tensor)), do: Nx.conjugate(tensor), else: tensor

  # The rank is read again from the centered tensor, since broadcasting with a
  # batched mean may have added leading axes. The stepped range is empty for
  # rank 2, so there are no batch axes in that case.
  batch_axes = 0..(Nx.rank(tensor) - 3)//1 |> Enum.to_list()
  total = Nx.axis_size(tensor, -2)

  # Contract over the sample axis and normalize by `n - ddof`.
  Nx.dot(conj, [-2], batch_axes, tensor, [-2], batch_axes) |> divide(total - ddof)
end

@doc """
Calculates the DFT of the given tensor.
Expand Down
7 changes: 1 addition & 6 deletions nx/test/nx/random_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -328,13 +328,8 @@ defmodule Nx.RandomTest do

assert_all_close(Nx.mean(multivariate_normal, axes: [0]), mean, rtol: 0.1)

multivariate_normal_centered = Nx.subtract(multivariate_normal, mean)

assert_all_close(
Nx.divide(
Nx.dot(multivariate_normal_centered, [0], multivariate_normal_centered, [0]),
1000
),
Nx.covariance(multivariate_normal, mean),
covariance,
rtol: 0.1
)
Expand Down
66 changes: 66 additions & 0 deletions nx/test/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2685,6 +2685,72 @@ defmodule NxTest do
end
end

# Covers the public covariance entry points: default mean, explicit mean,
# batched inputs, the :ddof option and the argument validation errors.
describe "covariance/1" do
  test "calculates the covariance of a tensor" do
    t = Nx.tensor([[4, 5], [2, 3], [1, 0]])

    assert Nx.covariance(t) ==
             Nx.tensor([
               [1.5555554628372192, 2.444444417953491],
               [2.444444417953491, 4.222222328186035]
             ])
  end

  test "with tensor batch" do
    t = Nx.tensor([[[4, 5], [2, 3], [1, 0]], [[10, 11], [8, 9], [7, 6]]])

    assert Nx.covariance(t) ==
             Nx.tensor([
               [
                 [1.5555554628372192, 2.444444417953491],
                 [2.444444417953491, 4.222222328186035]
               ],
               [
                 [1.5555554628372192, 2.444444417953491],
                 [2.444444417953491, 4.222222328186035]
               ]
             ])
  end

  test "with mean" do
    t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
    mean = Nx.tensor([3, 2])

    assert Nx.covariance(t, mean) ==
             Nx.tensor([
               [2.0, 2.0],
               [2.0, 4.666666507720947]
             ])
  end

  test "with mean batch" do
    t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
    mean = Nx.tensor([[3, 2], [2, 3]])

    assert Nx.covariance(t, mean) ==
             Nx.tensor([
               [
                 [2.0, 2.0],
                 [2.0, 4.666666507720947]
               ],
               [
                 [1.6666666269302368, 2.3333332538604736],
                 [2.3333332538604736, 4.333333492279053]
               ]
             ])
  end

  test "uses optional ddof" do
    t = Nx.tensor([[4, 5], [2, 3], [1, 0]])

    assert Nx.covariance(t, ddof: 1) ==
             Nx.tensor([
               [2.3333332538604736, 3.6666667461395264],
               [3.6666667461395264, 6.333333492279053]
             ])
  end

  # The explicit mean routes the call straight to covariance/3, so the rank
  # check of the input tensor is the first validation that runs.
  test "raises when the input tensor has rank lower than 2" do
    assert_raise ArgumentError, "expected input tensor of rank at least 2, got 1", fn ->
      Nx.covariance(Nx.tensor([1, 2, 3]), Nx.tensor([1]))
    end
  end

  test "raises when the mean is a scalar" do
    t = Nx.tensor([[4, 5], [2, 3], [1, 0]])

    assert_raise ArgumentError, "expected mean of rank at least 1, got 0", fn ->
      Nx.covariance(t, Nx.tensor(1))
    end
  end

  test "raises when ddof is negative" do
    t = Nx.tensor([[4, 5], [2, 3], [1, 0]])

    assert_raise ArgumentError, "expected ddof to be a non-negative integer, got -1", fn ->
      Nx.covariance(t, ddof: -1)
    end
  end
end

describe "weighted_mean/3" do
test "shape of input differs from shape of weights" do
t = Nx.iota({4, 6})
Expand Down

0 comments on commit 10fa593

Please sign in to comment.