add utility functions to NX #616

Merged
merged 5 commits into from
Jan 31, 2022

Conversation

tiagodavi
Contributor

This PR adds some utility functions to Nx.

@josevalim
Collaborator

Thank you @tiagodavi! @polvalente / @seanmor5, do you think this belongs in Nx? Or some statistics oriented package?

@polvalente
Contributor

> Thank you @tiagodavi! @polvalente / @seanmor5, do you think this belongs in Nx? Or some statistics oriented package?

@josevalim I think std/var/avg are common and basic enough to belong in Nx. For instance, random variables are often normalized through the standard deviation, and those are Nx vectors.

However, perhaps we could think about moving them to another module.

Speaking of avg (or mean), do we have it yet? Could be a good companion function here.

@josevalim
Collaborator

We do have mean. @tiagodavi, can you please rename those functions so they have full names instead of abbreviations? We typically avoid abbreviations in Nx.

nx/lib/nx.ex Outdated
1.25
>
"""
@doc type: :element
Collaborator

I think the type is :aggregation or similar. Please check what Nx.mean does :)

Contributor Author

I'll take a look. Thank you.

nx/lib/nx.ex Outdated
|> subtract(mean)
|> power(2)
|> sum()
|> divide(total)
Member

We should document the exact formula, in particular that we divide by N. In statistical use cases it's often desired to divide by N - 1 when calculating sample variance/stdev, since that gives an unbiased estimate.

For the record, numpy.var has a ddof option and divides by N - ddof, torch.var has a more explicit unbiased option, while tfp.stats.variance always uses N.
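For reference, the two conventions in this comment can be written as one formula. This is only a sketch of the notation: N is the number of elements, and the symbol d standing for ddof is chosen here for illustration.

```latex
\bar{x} = \frac{1}{N} \sum_{i=1}^{N} x_i,
\qquad
\operatorname{Var}_{d}(x) = \frac{1}{N - d} \sum_{i=1}^{N} \left( x_i - \bar{x} \right)^2
```

With d = 0 this divides by N, as the code under review does. With d = 1 it divides by N - 1, which gives the unbiased sample estimate mentioned above.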

Contributor Author

Added. Good call.

nx/lib/nx.ex Outdated

tensor
|> Nx.subtract(mean)
|> Nx.divide(std(tensor))
Contributor

What about a zero std value here?

Contributor Author

What do you mean by zero?

Collaborator

I believe the standard_deviation can return 0 (for example if all values are equal). In this case, this function will error. Are there other formulas to implement standard_scale?

Contributor

Note that std is 0 if and only if all the values are equal. In that case there is no way to achieve unit variance, so the data should be left as-is. I also looked at how scikit-learn calculates it. Their implementation is based on this article: http://www.cs.yale.edu/publications/techreports/tr222.pdf.

Contributor

But this problematic case is still handled specially. What is more, for numerical stability they also catch the cases where the std is non-zero but the values are almost constant. https://github.com/scikit-learn/scikit-learn/blob/7e1e6d09bcc2eaeba98f7e737aac2ac782f0e5f1/sklearn/preprocessing/_data.py#L84
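To make the zero-std concern concrete, here is a minimal sketch of one possible guard. It works on plain Elixir lists rather than Nx tensors so that it has no dependencies, and the module name `ScaleSketch` and the choice to substitute a scale of 1 are assumptions for illustration, not what this PR or Nx implements. With this guard, constant input is centered but never divided by zero.

```elixir
defmodule ScaleSketch do
  # Hypothetical helper, not part of Nx: standard-scales a non-empty list of numbers.
  def standard_scale(values) when is_list(values) and values != [] do
    n = length(values)
    mean = Enum.sum(values) / n

    # Population variance (divide by N).
    variance =
      values
      |> Enum.map(fn x -> (x - mean) * (x - mean) end)
      |> Enum.sum()
      |> Kernel./(n)

    std = :math.sqrt(variance)

    # If every value is equal, std is 0.0 and unit variance is unreachable,
    # so fall back to a scale of 1.0 instead of dividing by zero.
    scale = if std == 0.0, do: 1.0, else: std

    Enum.map(values, fn x -> (x - mean) / scale end)
  end
end
```

A near-constant input, where std is tiny but non-zero, would still need a tolerance check, which this sketch leaves out.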

Collaborator

Yeah, let's give up on the standard_scale/normalize for now then and revisit later.

nx/lib/nx.ex Outdated
"""
@doc type: :element
@spec var(tensor :: Nx.Tensor.t()) :: Nx.Tensor.t()
def var(%Nx.Tensor{shape: shape} = tensor) do
Contributor

What is the definition of the variance of a 0-dimensional tensor? Perhaps we should have a check against that.

Contributor Author

iex(3)> Nx.variance Nx.tensor(5)
#Nx.Tensor<
  f32
  0.0
>

nx/lib/nx.ex Outdated
"""
@doc type: :aggregation
@spec standard_deviation(tensor :: Nx.Tensor.t(), ddof :: number()) :: Nx.Tensor.t()
def standard_deviation(tensor, ddof \\ 0)
Collaborator

Since we may want to add an axes option to those functions, perhaps ddof should be an option too? See how mean is done; we want to accept the same options. :)

@seanmor5
Collaborator

I am in favor of standard deviation and variance, but I'm not sure standard scale belongs in Nx. It feels more like the responsibility of another library related to data preprocessing.

@tiagodavi
Contributor Author

I am trying to do something like this to respect the guidelines, but not sure how to access the tensor over its axes to compute the variance at the end:

```elixir
def variance(tensor, opts \\ []) do
  %T{shape: shape, names: names} = tensor = to_tensor(tensor)

  mean_den =
    if axes = opts[:axes] do
      mean_den(shape, Nx.Shape.normalize_axes(shape, axes, names))
    else
      mean_den(shape, nil)
    end

  ddof = Keyword.get(opts, :ddof, 0)
  mean = mean(tensor, Keyword.take(opts, [:axes, keep_axes: false]))

  axes = axes(tensor)

  IO.inspect axes

  tensor[axes] # I am assuming I need to calculate based on its axes because mean and mean_den are based on it.
  |> subtract(mean)
  |> power(2)
  |> sum()
  |> divide(mean_den - ddof)
end
```

@josevalim
Collaborator

Oh, I see. Let’s not worry about the axes version for now then!

@tiagodavi
Contributor Author

@josevalim done. I fixed the spacing, removed the standard_scale function, and added the ddof option to variance and standard_deviation, without axes for the time being.

nx/lib/nx.ex Outdated
%T{shape: shape} = tensor = to_tensor(tensor)

total = Tuple.product(shape)
ddof = Keyword.get(opts, :ddof, 0)
Contributor

Suggested change
- ddof = Keyword.get(opts, :ddof, 0)
+ ddof = Keyword.fetch!(opts, :ddof)
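The suggestion presumably relies on the options having been merged with their defaults earlier in the function, which the hunk above does not show. A minimal sketch of that pattern in plain Elixir follows. The module name `OptsSketch`, the function name, and the defaults list are assumptions for illustration, not the code in nx/lib/nx.ex.

```elixir
defmodule OptsSketch do
  # Hypothetical defaults for illustration; the real option list may differ.
  @defaults [ddof: 0]

  def ddof(opts \\ []) do
    # Merge the caller's options over the defaults first...
    opts = Keyword.merge(@defaults, opts)

    # ...so the key is always present and fetch!/2 cannot raise here.
    Keyword.fetch!(opts, :ddof)
  end
end
```

With the default kept in one place, a second default in `Keyword.get/3` becomes redundant, and `Keyword.fetch!/2` fails loudly if the merge step is ever removed.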

Contributor Author

done

nx/lib/nx.ex Outdated
Comment on lines 9229 to 9230
tensor = to_tensor(tensor)

Contributor

Suggested change
- tensor = to_tensor(tensor)

In this specific case we can just delegate to variance.
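A sketch of what delegating looks like, using plain Elixir lists as a dependency-free stand-in for tensors. The module name `StatsSketch` and the list-based signatures are assumptions for illustration and are not the functions merged in this PR.

```elixir
defmodule StatsSketch do
  # Hypothetical list-based stand-in for the tensor function in this PR.
  def variance(values, opts \\ []) when is_list(values) do
    ddof = Keyword.get(opts, :ddof, 0)
    n = length(values)
    mean = Enum.sum(values) / n

    sum_of_squares =
      values
      |> Enum.map(fn x -> (x - mean) * (x - mean) end)
      |> Enum.sum()

    # Divide by N - ddof: ddof 0 divides by N, ddof 1 divides by N - 1.
    sum_of_squares / (n - ddof)
  end

  # No input conversion here: variance/2 already validates its argument,
  # so the standard deviation is just the square root of its result.
  def standard_deviation(values, opts \\ []) do
    :math.sqrt(variance(values, opts))
  end
end
```

For example, `StatsSketch.variance([1, 2, 3, 4])` evaluates to `1.25`, and `StatsSketch.standard_deviation([1, 2, 3, 4])` is its square root.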

Contributor Author

done

@josevalim josevalim merged commit 6259450 into elixir-nx:main Jan 31, 2022
@josevalim
Collaborator

💚 💙 💜 💛 ❤️

7 participants