diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex
index 4132a2c80e..e73c2a2b86 100644
--- a/nx/lib/nx.ex
+++ b/nx/lib/nx.ex
@@ -15703,6 +15703,109 @@ defmodule Nx do
     sqrt(variance(tensor, opts))
   end
 
+  @doc """
+  A shortcut to `covariance/3` with either `opts` or `mean` as second argument.
+  """
+  @doc type: :aggregation
+  @spec covariance(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t()
+  def covariance(tensor, opts \\ [])
+
+  @spec covariance(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t()
+  def covariance(tensor, opts) when is_list(opts),
+    do: covariance(tensor, Nx.mean(tensor, axes: [-2]), opts)
+
+  @spec covariance(tensor :: Nx.Tensor.t(), mean :: Nx.Tensor.t()) :: Nx.Tensor.t()
+  def covariance(tensor, mean), do: covariance(tensor, mean, [])
+
+  @doc """
+  Computes the covariance matrix of the input tensor.
+
+  The covariance of two random variables X and Y is calculated as $Cov(X, Y) = \\frac{1}{N}\\sum_{i=0}^{N-1}{(X_i - \\overline{X}) * (Y_i - \\overline{Y})}$.
+
+  The tensor must be at least of rank 2, with shape `{n, d}`. Any additional
+  dimension will be treated as batch dimensions.
+
+  The column mean can be provided as the second argument and it must be
+  a tensor of shape `{..., d}`, where the batch shape is broadcastable with
+  that of the input tensor. If not provided, the mean is estimated using `Nx.mean/2`.
+
+  If the `:ddof` (delta degrees of freedom) option is given, the divisor `n - ddof`
+  is used for the sum of the products.
+
+  ## Examples
+
+      iex> Nx.covariance(Nx.tensor([[1, 2], [3, 4], [5, 6]]))
+      #Nx.Tensor<
+        f32[2][2]
+        [
+          [2.6666667461395264, 2.6666667461395264],
+          [2.6666667461395264, 2.6666667461395264]
+        ]
+      >
+
+      iex> Nx.covariance(Nx.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]))
+      #Nx.Tensor<
+        f32[2][2][2]
+        [
+          [
+            [2.6666667461395264, 2.6666667461395264],
+            [2.6666667461395264, 2.6666667461395264]
+          ],
+          [
+            [2.6666667461395264, 2.6666667461395264],
+            [2.6666667461395264, 2.6666667461395264]
+          ]
+        ]
+      >
+
+      iex> Nx.covariance(Nx.tensor([[1, 2], [3, 4], [5, 6]]), ddof: 1)
+      #Nx.Tensor<
+        f32[2][2]
+        [
+          [4.0, 4.0],
+          [4.0, 4.0]
+        ]
+      >
+
+      iex> Nx.covariance(Nx.tensor([[1, 2], [3, 4], [5, 6]]), Nx.tensor([4, 3]))
+      #Nx.Tensor<
+        f32[2][2]
+        [
+          [3.6666667461395264, 1.6666666269302368],
+          [1.6666666269302368, 3.6666667461395264]
+        ]
+      >
+  """
+  @doc type: :aggregation
+  @spec covariance(tensor :: Nx.Tensor.t(), mean :: Nx.Tensor.t(), opts :: Keyword.t()) ::
+          Nx.Tensor.t()
+  def covariance(tensor, mean, opts) do
+    tensor = to_tensor(tensor)
+    mean = to_tensor(mean)
+    opts = keyword!(opts, ddof: 0)
+    tensor_rank = Nx.rank(tensor)
+
+    if tensor_rank < 2 do
+      raise ArgumentError, "expected input tensor of rank at least 2, got #{tensor_rank}"
+    end
+
+    if Nx.rank(mean) == 0 do
+      raise ArgumentError, "expected mean of rank at least 1, got 0"
+    end
+
+    ddof = opts[:ddof]
+
+    if not is_integer(ddof) or ddof < 0 do
+      raise ArgumentError, "expected ddof to be a non-negative integer, got #{ddof}"
+    end
+
+    tensor = tensor |> subtract(new_axis(mean, -2)) |> rename(nil)
+    conj = if Nx.Type.complex?(Nx.type(tensor)), do: Nx.conjugate(tensor), else: tensor
+    batch_axes = 0..(Nx.rank(tensor) - 3)//1 |> Enum.to_list()
+    total = Nx.axis_size(tensor, -2)
+    Nx.dot(conj, [-2], batch_axes, tensor, [-2], batch_axes) |> divide(total - ddof)
+  end
+
   @doc """
   Calculates the DFT of the given tensor.
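For context on what the new `covariance/3` computes, here is a minimal sketch (not part of the patch) of the plain 2-D, non-batched case, written only against public Nx calls already used above (`Nx.mean/2`, `Nx.subtract/2`, `Nx.transpose/1`, `Nx.dot/2`, `Nx.divide/2`). The `CovarianceSketch` module and `manual_covariance/2` helper are illustrative names, and the sketch skips the complex-conjugation and batching branches handled by the real implementation:

```elixir
# Illustrative sketch only: recomputes the non-batched case by hand so the
# centering + dot-product + `n - ddof` divisor in the patch is easy to follow.
# Complex inputs and batch dimensions are intentionally ignored here.
defmodule CovarianceSketch do
  def manual_covariance(t, opts \\ []) do
    ddof = Keyword.get(opts, :ddof, 0)
    n = Nx.axis_size(t, 0)

    # Center every column around its mean, as in the docstring formula.
    centered = Nx.subtract(t, Nx.mean(t, axes: [0]))

    # Sum of products of the centered columns, divided by n - ddof.
    centered
    |> Nx.transpose()
    |> Nx.dot(centered)
    |> Nx.divide(n - ddof)
  end
end

t = Nx.tensor([[1, 2], [3, 4], [5, 6]])
CovarianceSketch.manual_covariance(t)
# Expected to agree with the first doctest above:
# [[2.666..., 2.666...], [2.666..., 2.666...]]
```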
diff --git a/nx/test/nx/random_test.exs b/nx/test/nx/random_test.exs
index 12a5f1d304..e0025df031 100644
--- a/nx/test/nx/random_test.exs
+++ b/nx/test/nx/random_test.exs
@@ -328,13 +328,8 @@ defmodule Nx.RandomTest do
       assert_all_close(Nx.mean(multivariate_normal, axes: [0]), mean, rtol: 0.1)
 
-      multivariate_normal_centered = Nx.subtract(multivariate_normal, mean)
-
       assert_all_close(
-        Nx.divide(
-          Nx.dot(multivariate_normal_centered, [0], multivariate_normal_centered, [0]),
-          1000
-        ),
+        Nx.covariance(multivariate_normal, mean),
         covariance,
         rtol: 0.1
       )
diff --git a/nx/test/nx_test.exs b/nx/test/nx_test.exs
index e3b825a78a..564691e813 100644
--- a/nx/test/nx_test.exs
+++ b/nx/test/nx_test.exs
@@ -2680,6 +2680,72 @@ defmodule NxTest do
     end
   end
 
+  describe "covariance/1" do
+    test "calculates the covariance of a tensor" do
+      t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
+
+      assert Nx.covariance(t) ==
+               Nx.tensor([
+                 [1.5555554628372192, 2.444444417953491],
+                 [2.444444417953491, 4.222222328186035]
+               ])
+    end
+
+    test "with tensor batch" do
+      t = Nx.tensor([[[4, 5], [2, 3], [1, 0]], [[10, 11], [8, 9], [7, 6]]])
+
+      assert Nx.covariance(t) ==
+               Nx.tensor([
+                 [
+                   [1.5555554628372192, 2.444444417953491],
+                   [2.444444417953491, 4.222222328186035]
+                 ],
+                 [
+                   [1.5555554628372192, 2.444444417953491],
+                   [2.444444417953491, 4.222222328186035]
+                 ]
+               ])
+    end
+
+    test "with mean" do
+      t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
+      mean = Nx.tensor([3, 2])
+
+      assert Nx.covariance(t, mean) ==
+               Nx.tensor([
+                 [2.0, 2.0],
+                 [2.0, 4.666666507720947]
+               ])
+    end
+
+    test "with mean batch" do
+      t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
+      mean = Nx.tensor([[3, 2], [2, 3]])
+
+      assert Nx.covariance(t, mean) ==
+               Nx.tensor([
+                 [
+                   [2.0, 2.0],
+                   [2.0, 4.666666507720947]
+                 ],
+                 [
+                   [1.6666666269302368, 2.3333332538604736],
+                   [2.3333332538604736, 4.333333492279053]
+                 ]
+               ])
+    end
+
+    test "uses optional ddof" do
+      t = Nx.tensor([[4, 5], [2, 3], [1, 0]])
+
+      assert Nx.covariance(t, ddof: 1) ==
+               Nx.tensor([
+                 [2.3333332538604736, 3.6666667461395264],
+                 [3.6666667461395264, 6.333333492279053]
+               ])
+    end
+  end
+
   describe "weighted_mean/3" do
     test "shape of input differs from shape of weights" do
       t = Nx.iota({4, 6})
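As a review aid (not part of the patch): the `random_test.exs` change above swaps the hand-rolled estimator for the new function, and for samples of shape `{n, d}` with the default `ddof: 0` the two expressions agree up to float rounding, since `covariance/3` also divides by the sample count when a mean is supplied. The sample data below is made up purely for illustration:

```elixir
# Hypothetical data, for illustration only.
samples = Nx.tensor([[1.0, 2.0], [2.0, 4.1], [3.0, 5.9], [4.0, 8.2]])
mean = Nx.mean(samples, axes: [0])
n = Nx.axis_size(samples, 0)

# The expression removed from random_test.exs, with `n` standing in for the
# hard-coded 1000 sample count used there:
centered = Nx.subtract(samples, mean)
old_estimate = Nx.divide(Nx.dot(centered, [0], centered, [0]), n)

# The call that replaces it:
new_estimate = Nx.covariance(samples, mean)

# old_estimate and new_estimate match up to float rounding.
```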