
Add Nx.covariance #1289

Merged: 14 commits into elixir-nx:main on Aug 31, 2023
Conversation

krstopro (Member) commented Aug 28, 2023

Added Nx.covariance/2, which estimates the covariance matrix from an input tensor. Also edited the test for properties of the multivariate normal distribution in nx_test.ex. A rough sketch of the intended computation is shown after the list below.
A few things I am not sure about:

  • Currently the function requires the input tensor to be of rank 2. One could make it work for rank-1 tensors as well, but in that case I would prefer the user to use Nx.variance.
  • The mean can be passed as an option. This is convenient when the mean is known or already computed, but it is not the approach taken in the rest of Nx (e.g. the functions for variance or standard deviation) or in NumPy / JAX.
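A minimal sketch of the computation described above, assuming a rank-2 {n, d} input and the default ddof: 0; the sample values are illustrative only, not the PR's final code:

x = Nx.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
# Center each column by subtracting its mean.
centered = Nx.subtract(x, Nx.mean(x, axes: [0]))
# Contract over the n observations and divide by n, giving a {d, d} matrix.
n = elem(Nx.shape(x), 0)
cov = Nx.divide(Nx.dot(centered, [0], centered, [0]), n)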

nx/lib/nx.ex Outdated
opts = keyword!(opts, ddof: 0, mean: nil)

if rank(tensor) != 2 do
  raise ArgumentError, "expected tensor of rank 2, got #{rank(tensor)}"
polvalente (Contributor) commented:

We should at least accept tensors of shape {..., m, n} where the first k-2 dimensions are batch dimensions.

For the mean, I think we need to receive it as an argument, so we'd have:

def covariance(tensor), do: covariance(tensor, Nx.mean(tensor, axes: [-2]), [])
def covariance(tensor, opts) when is_list(opts), do: covariance(tensor, Nx.mean(tensor, axes: [-2]), opts)
def covariance(tensor, mean), do: covariance(tensor, mean, [])
def covariance(tensor, mean, opts) do
...
end

Although @josevalim will be able to say for certain if passing the mean as an option is really a problem.

Contributor added:

And you'd do mean = to_tensor(mean) right after tensor = to_tensor(tensor)

josevalim (Collaborator) commented:

Yes, tensors must be arguments. And I would go with this signature:

def covariance(tensor, opts \\ [])
def covariance(tensor, opts) when is_list(opts), do: covariance(tensor, Nx.mean(tensor, axes: [-2]), opts)
def covariance(tensor, mean), do: covariance(tensor, mean, [])
def covariance(tensor, mean, opts) when is_list(opts) do

:)

krstopro (Member, Author) commented:

I got used to passing arguments as options to the point I forgot Elixir supports multiple-clause functions. 😅

krstopro (Member, Author) commented:

> We should at least accept tensors of shape {..., m, n} where the first k-2 dimensions are batch dimensions.

@polvalente Just to clarify: if we take a batch of datasets as input (shape {..., n, d}), the output should be a batch of covariances (shape {..., d, d}), right?

polvalente (Contributor) replied:

Yes
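For reference, a hedged sketch of the batched behaviour just confirmed (illustrative values, not the PR's implementation): a {batch, n, d} input yields a {batch, d, d} covariance.

x = Nx.iota({2, 4, 3}, type: :f32)
# Center along the observation axis, keeping it so broadcasting works.
centered = Nx.subtract(x, Nx.mean(x, axes: [-2], keep_axes: true))
# Batch over axis 0, contract over the 4 observations on axis 1.
cov = Nx.divide(Nx.dot(centered, [1], [0], centered, [1], [0]), 4)
# Nx.shape(cov) == {2, 3, 3}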

nx/lib/nx.ex Outdated
Comment on lines 15769 to 15775
rank(opts_mean) != 1 ->
  raise ArgumentError, "expected mean to have rank 1, got #{rank(opts_mean)}"

size(opts_mean) != dim ->
  raise ArgumentError,
        "expected input tensor and mean to have same dimensions, got #{dim} and #{size(opts_mean)}"

Contributor commented:

We can skip these validations because subtract should already apply them for us.

krstopro (Member, Author) replied:

One of the reasons I validated the rank of the mean is that the user could pass a single number, which Nx would then broadcast when subtracting. Do we really want this? If the input tensor has shape n x d, I would prefer the user to pass an explicit d-dimensional mean vector.
Agreed on the dimension check.

msluszniak (Contributor) commented:

There is a function that calculates this in Scholar https://hexdocs.pm/scholar/Scholar.Covariance.html#covariance_matrix/2. Is there any specific reason to implement this once again in Nx?

krstopro (Member, Author) commented Aug 29, 2023

@msluszniak I see. I would honestly prefer it in Nx, as the function is rather fundamental and not needed only for machine learning algorithms. I also thought of using it for unit testing Nx.Random.multivariate_normal (though this might be imported from Scholar as well), for example:

multivariate_normal_centered = Nx.subtract(multivariate_normal, mean)

assert_all_close(
  Nx.divide(
    Nx.dot(multivariate_normal_centered, [0], multivariate_normal_centered, [0]),
    1000
  ),

Krsto Proroković added 2 commits August 29, 2023 19:51
polvalente (Contributor) commented:

> There is a function that calculates this in Scholar https://hexdocs.pm/scholar/Scholar.Covariance.html#covariance_matrix/2. Is there any specific reason to implement this once again in Nx?

@msluszniak I didn't know about that one. I believe we should deprecate that in favor of this one in Nx. We could reuse the implementation from Scholar.

@krstopro I see you added a :ddof option as well as the mean argument. Could you check on the possibility of merging both implementations?

krstopro (Member, Author) commented:

@polvalente @msluszniak Looking at it, but this one seems more general at the moment (e.g. it works with batches and complex input). The main difference is in how they handle bias: I would prefer a :biased option as done in Scholar, but I have been following the rest of Nx and used :ddof.

Both implementations contain a bug though :)
If the last axis of the input tensor is named, the function will raise an error (as Nx.dot(x, [0], x, [0]) raises due to the duplicate name). One way to solve this is to extract the names of the tensor, replace the last name with nil using List.replace_at/3, and then rename the tensor. Is there a better way?

polvalente (Contributor) commented:

I think we should force the output tensor to have no names. So you can rename all axes of x to nil and it should resolve.
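A small sketch of that suggestion, assuming x is the input tensor (not the exact code that ended up in the PR):

# Drop every axis name so Nx.dot/4 never sees a duplicate name.
x = Nx.rename(x, List.duplicate(nil, Nx.rank(x)))
Nx.dot(x, [0], x, [0])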

msluszniak (Contributor) commented:

> @msluszniak I didn't know about that one. I believe we should deprecate that in favor of this one in Nx. We could reuse the implementation from Scholar.

Yes, I think it's a good idea to use the implementation from Nx and deprecate the version from Scholar. But to do this, we also need to implement a correlation matrix (which will be trivial once we have covariance).
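As a hedged illustration of that point (assuming cov is a {d, d} covariance tensor; not Scholar's actual implementation), a correlation matrix can be derived roughly like this:

std = Nx.sqrt(Nx.take_diagonal(cov))
corr = Nx.divide(cov, Nx.outer(std, std))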

krstopro (Member, Author) commented Aug 29, 2023

> I think we should force the output tensor to have no names. So you can rename all axes of x to nil and it should resolve.

Yeah, I guess I should do that. I still like the names on the batch axes though.

> Yes, I think it's a good idea to use the implementation from Nx and deprecate the version from Scholar. But to do this, we also need to implement a correlation matrix (which will be trivial once we have covariance).

Agreed.

polvalente (Contributor) commented:

> I think we should force the output tensor to have no names. So you can rename all axes of x to nil and it should resolve.
>
> Yeah, I guess I should do that. I still like the names on the batch axes though.

The problem is that, as we are reducing over a given axis, the semantics of which names should be kept start to get fuzzy.
I think the simpler solution is to remove all axis names.

krstopro (Member, Author) commented:

> I think we should force the output tensor to have no names. So you can rename all axes of x to nil and it should resolve.
>
> Yeah, I guess I should do that. I still like the names on the batch axes though.
>
> The problem is that, as we are reducing over a given axis, the semantics of which names should be kept start to get fuzzy. I think the simpler solution is to remove all axis names.

Indeed, better to remove all the names.

krstopro (Member, Author) commented:

Alright, changed the description and added unit tests for batch inputs. Might be ready to merge.

krstopro and others added 4 commits August 31, 2023 00:03
Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
josevalim merged commit 8d5a19a into elixir-nx:main on Aug 31, 2023
8 of 9 checks passed
josevalim (Collaborator) commented:

💚 💙 💜 💛 ❤️

msluszniak mentioned this pull request Sep 11, 2023
josevalim pushed a commit that referenced this pull request Oct 24, 2023