Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Nx.covariance #1289

Merged
merged 14 commits into from
Aug 31, 2023
95 changes: 95 additions & 0 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15703,6 +15703,101 @@ defmodule Nx do
sqrt(variance(tensor, opts))
end

@doc """
josevalim marked this conversation as resolved.
Show resolved Hide resolved
Computes the covariance matrix of the input tensor with shape `{..., n, d}`.

The covariance of two random variables X and Y is calculated as `mean((X - mean(X)) * (Y - mean(Y)))`.
krstopro marked this conversation as resolved.
Show resolved Hide resolved
For every tensor of shape `{n, d}` in the batch, the covariance matrix element at position (i, j)
krstopro marked this conversation as resolved.
Show resolved Hide resolved
is equal to the covariance of the i-th and j-th columns of the tensor.
krstopro marked this conversation as resolved.
Show resolved Hide resolved
Column mean can be provided as the second argument. It must be a tensor of shape `{..., d}`
where the batch shape `...` is broadcastable with that of the input tensor.
If not provided, the column mean is estimated using `Nx.mean`.
krstopro marked this conversation as resolved.
Show resolved Hide resolved
If the `:ddof` (delta degrees of freedom) option is given, the divisor `n - ddof` is used for the sum of the products.

## Examples

iex> Nx.covariance(Nx.tensor([[1, 2], [3, 4], [5, 6]]))
#Nx.Tensor<
f32[2][2]
[
[2.6666667461395264, 2.6666667461395264],
[2.6666667461395264, 2.6666667461395264]
]
>

iex> Nx.covariance(Nx.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]))
#Nx.Tensor<
f32[2][2][2]
[
[
[2.6666667461395264, 2.6666667461395264],
[2.6666667461395264, 2.6666667461395264]
],
[
[2.6666667461395264, 2.6666667461395264],
[2.6666667461395264, 2.6666667461395264]
]
]
>

iex> Nx.covariance(Nx.tensor([[1, 2], [3, 4], [5, 6]]), ddof: 1)
#Nx.Tensor<
f32[2][2]
[
[4.0, 4.0],
[4.0, 4.0]
]
>

iex> Nx.covariance(Nx.tensor([[1, 2], [3, 4], [5, 6]]), Nx.tensor([4, 3]))
#Nx.Tensor<
f32[2][2]
[
[3.6666667461395264, 1.6666666269302368],
[1.6666666269302368, 3.6666667461395264]
]
>
"""
@doc type: :aggregation
@spec covariance(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t()
def covariance(tensor, opts \\ [])

# A list as the second argument is treated as options. The column mean is then
# computed with `Nx.mean/2` over the second-to-last axis and forwarded, together
# with the options, to covariance/3.
def covariance(tensor, opts) when is_list(opts),
  do: covariance(tensor, Nx.mean(tensor, axes: [-2]), opts)

# Any non-list second argument is treated as the column mean; covariance/3 is
# called with an empty option list, so its defaults apply.
@spec covariance(tensor :: Nx.Tensor.t(), mean :: Nx.Tensor.t()) :: Nx.Tensor.t()
def covariance(tensor, mean), do: covariance(tensor, mean, [])

josevalim marked this conversation as resolved.
Show resolved Hide resolved
# covariance/3: covariance matrix computed from an explicitly given column mean.
#
# Arguments:
#   * `tensor` - converted with `to_tensor/1`; must have rank >= 2.
#   * `mean`   - converted with `to_tensor/1`; must have rank >= 1. It is
#     subtracted from `tensor` after a new axis is inserted at position -2.
#   * `opts`   - only `:ddof` is accepted (default `0`); the contracted sum of
#     products is divided by `n - ddof`, where `n` is the size of axis -2.
#
# Raises `ArgumentError` when `tensor` has rank < 2, when `mean` has rank 0, or
# when `:ddof` is not a non-negative integer.
@spec covariance(tensor :: Nx.Tensor.t(), mean :: Nx.Tensor.t(), opts :: Keyword.t()) ::
        Nx.Tensor.t()
def covariance(tensor, mean, opts) do
  tensor = to_tensor(tensor)
  mean = to_tensor(mean)
  opts = keyword!(opts, ddof: 0)
  tensor_rank = Nx.rank(tensor)

  if tensor_rank < 2 do
    raise ArgumentError, "expected input tensor of rank at least 2, got #{tensor_rank}"
  end

  if Nx.rank(mean) == 0 do
    raise ArgumentError, "expected mean of rank at least 1, got 0"
  end

  ddof = opts[:ddof]

  if not is_integer(ddof) or ddof < 0 do
    # inspect/1 renders any term. Plain interpolation goes through String.Chars,
    # which prints nothing for nil and raises Protocol.UndefinedError for terms
    # such as tuples or maps, hiding the intended ArgumentError.
    raise ArgumentError,
          "expected ddof to be a non-negative integer, got #{inspect(ddof)}"
  end

  # Center the data around the given mean.
  # NOTE(review): rename(nil) presumably clears axis names so the dot below is
  # not affected by named axes - confirm against Nx.rename/2.
  tensor = tensor |> subtract(new_axis(mean, -2)) |> rename(nil)
  conj = if Nx.Type.complex?(Nx.type(tensor)), do: Nx.conjugate(tensor), else: tensor

  # The rank is read again here because subtracting a batched mean may have
  # broadcast `tensor` to a higher rank. Every axis except the last two is a
  # batch axis.
  batch_axes = 0..(Nx.rank(tensor) - 3)//1 |> Enum.to_list()
  total = Nx.axis_size(tensor, -2)

  # Contract over axis -2 of both operands, then normalize by `n - ddof`.
  Nx.dot(conj, [-2], batch_axes, tensor, [-2], batch_axes) |> divide(total - ddof)
end

@doc """
Calculates the DFT of the given tensor.

Expand Down
7 changes: 1 addition & 6 deletions nx/test/nx/random_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -328,13 +328,8 @@ defmodule Nx.RandomTest do

assert_all_close(Nx.mean(multivariate_normal, axes: [0]), mean, rtol: 0.1)

multivariate_normal_centered = Nx.subtract(multivariate_normal, mean)

assert_all_close(
Nx.divide(
Nx.dot(multivariate_normal_centered, [0], multivariate_normal_centered, [0]),
1000
),
Nx.covariance(multivariate_normal, mean),
covariance,
rtol: 0.1
)
Expand Down
66 changes: 66 additions & 0 deletions nx/test/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2680,6 +2680,72 @@ defmodule NxTest do
end
end

describe "covariance/1" do
  # Happy paths: plain input, batched input, explicit mean, batched mean, ddof.
  test "calculates the covariance of a tensor" do
    t = Nx.tensor([[4, 5], [2, 3], [1, 0]])

    assert Nx.covariance(t) ==
             Nx.tensor([
               [1.5555554628372192, 2.444444417953491],
               [2.444444417953491, 4.222222328186035]
             ])
  end

  test "with tensor batch" do
    t = Nx.tensor([[[4, 5], [2, 3], [1, 0]], [[10, 11], [8, 9], [7, 6]]])

    assert Nx.covariance(t) ==
             Nx.tensor([
               [
                 [1.5555554628372192, 2.444444417953491],
                 [2.444444417953491, 4.222222328186035]
               ],
               [
                 [1.5555554628372192, 2.444444417953491],
                 [2.444444417953491, 4.222222328186035]
               ]
             ])
  end

  test "with mean" do
    t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
    mean = Nx.tensor([3, 2])

    assert Nx.covariance(t, mean) ==
             Nx.tensor([
               [2.0, 2.0],
               [2.0, 4.666666507720947]
             ])
  end

  test "with mean batch" do
    t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
    mean = Nx.tensor([[3, 2], [2, 3]])

    assert Nx.covariance(t, mean) ==
             Nx.tensor([
               [
                 [2.0, 2.0],
                 [2.0, 4.666666507720947]
               ],
               [
                 [1.6666666269302368, 2.3333332538604736],
                 [2.3333332538604736, 4.333333492279053]
               ]
             ])
  end

  test "uses optional ddof" do
    t = Nx.tensor([[4, 5], [2, 3], [1, 0]])

    assert Nx.covariance(t, ddof: 1) ==
             Nx.tensor([
               [2.3333332538604736, 3.6666667461395264],
               [3.6666667461395264, 6.333333492279053]
             ])
  end

  # Error paths: each validation branch of covariance/3 raises ArgumentError.
  test "raises when the input tensor has rank lower than 2" do
    assert_raise ArgumentError, "expected input tensor of rank at least 2, got 1", fn ->
      Nx.covariance(Nx.tensor([1, 2, 3]), Nx.tensor([2]))
    end
  end

  test "raises when the mean is a scalar" do
    assert_raise ArgumentError, "expected mean of rank at least 1, got 0", fn ->
      Nx.covariance(Nx.tensor([[4, 5], [2, 3], [1, 0]]), Nx.tensor(3))
    end
  end

  test "raises when ddof is negative" do
    assert_raise ArgumentError, "expected ddof to be a non-negative integer, got -1", fn ->
      Nx.covariance(Nx.tensor([[4, 5], [2, 3], [1, 0]]), ddof: -1)
    end
  end
end

describe "weighted_mean/3" do
test "shape of input differs from shape of weights" do
t = Nx.iota({4, 6})
Expand Down
Loading