add utility functions to NX #616
Conversation
Thank you @tiagodavi! @polvalente / @seanmor5, do you think this belongs in Nx? Or some statistics-oriented package?
@josevalim I think std/var/avg are common and basic enough to belong in Nx. For instance, random variables are often normalized through the standard deviation, and those are Nx vectors. However, perhaps we could think about moving them to another module. Speaking of avg (or mean), do we have it yet? Could be a good companion function here.
We do have `Nx.mean`.
nx/lib/nx.ex
Outdated
```elixir
      1.25
    >
  """
  @doc type: :element
```
I think the type is `:aggregation` or similar. Please check what `Nx.mean` does :)
I'll take a look, thank you!
nx/lib/nx.ex
Outdated
```elixir
|> subtract(mean)
|> power(2)
|> sum()
|> divide(total)
```
We should document the exact formula, in particular that we divide by `N`. In statistical use cases it's often desired to divide by `N - 1` when calculating sample variance/stdev, since that gives an unbiased estimate. For the record, `numpy.var` has a `ddof` option and divides by `N - ddof`, `torch.var` has a more explicit `unbiased` option, while `tfp.stats.variance` always uses `N`.
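As a minimal sketch of what a `ddof`-aware variance could look like (the module name and option handling below are illustrative, not this PR's final code; `Nx.power/2` follows this PR's naming, later renamed `Nx.pow/2`):

```elixir
defmodule VarianceSketch do
  # Divisor is N - ddof, mirroring numpy.var: ddof: 0 gives the population
  # variance, ddof: 1 the unbiased sample variance.
  def variance(tensor, opts \\ []) do
    ddof = Keyword.get(opts, :ddof, 0)
    tensor = Nx.to_tensor(tensor)
    n = Nx.size(tensor)
    mean = Nx.mean(tensor)

    tensor
    |> Nx.subtract(mean)
    |> Nx.power(2)
    |> Nx.sum()
    |> Nx.divide(n - ddof)
  end
end
```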
Added. Good call.
nx/lib/nx.ex
Outdated
```elixir
tensor
|> Nx.subtract(mean)
|> Nx.divide(std(tensor))
```
What about a zero `std` value here?
What do you mean by zero?
I believe `standard_deviation` can return 0 (for example, if all values are equal). In this case, this function will error. Are there other formulas to implement `standard_scale`?
Note that std is 0 if and only if all the values are equal. In that case there is no way to achieve unit variance, so the data should be left as-is. I've also checked how it is calculated in scikit-learn; they base it on this article: http://www.cs.yale.edu/publications/techreports/tr222.pdf.
But this problematic case is still handled specially there. What's more, they also catch cases where the std is non-zero but the values are almost constant, for numerical stability: https://github.com/scikit-learn/scikit-learn/blob/7e1e6d09bcc2eaeba98f7e737aac2ac782f0e5f1/sklearn/preprocessing/_data.py#L84
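For illustration, one common way to sidestep the zero-std case, in the spirit of scikit-learn's handling rather than anything decided in this PR, is to replace a zero scale with 1 before dividing, so constant data passes through unchanged (a sketch assuming an `Nx.standard_deviation/1` like the one discussed here):

```elixir
tensor = Nx.tensor([3.0, 3.0, 3.0])

# Where std is 0 (all values equal), divide by 1 instead.
std = Nx.standard_deviation(tensor)
safe_std = Nx.select(Nx.equal(std, 0), 1, std)

tensor
|> Nx.subtract(Nx.mean(tensor))
|> Nx.divide(safe_std)
```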
Yeah, let's give up on the standard_scale/normalize for now then and revisit later.
nx/lib/nx.ex
Outdated
```elixir
  """
  @doc type: :element
  @spec var(tensor :: Nx.Tensor.t()) :: Nx.Tensor.t()
  def var(%Nx.Tensor{shape: shape} = tensor) do
```
What is the definition of the variance of a 0-dimensional tensor? Perhaps we should have a check against that.
```elixir
iex(3)> Nx.variance(Nx.tensor(5))
#Nx.Tensor<
  f32
  0.0
>
```
nx/lib/nx.ex
Outdated
```elixir
  """
  @doc type: :aggregation
  @spec standard_deviation(tensor :: Nx.Tensor.t(), ddof :: number()) :: Nx.Tensor.t()
  def standard_deviation(tensor, ddof \\ 0)
```
Since we may want to add an `axes` option to those functions, perhaps `ddof` must be an option too? See how `mean` is done; we want to accept the same options. :)
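Under the suggested options-based API, the call shapes would look roughly like this (hypothetical usage; no option name beyond `:ddof` is settled here):

```elixir
t = Nx.tensor([4, 9, 15, 25, 35])

Nx.standard_deviation(t)           # population stdev, divides by N
Nx.standard_deviation(t, ddof: 1)  # sample stdev, divides by N - 1
```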
I am in favor of standard deviation and variance, but I'm not sure standard scale belongs in Nx. It feels more like the responsibility of another library related to data preprocessing.
I am trying to do something like this to respect the guidelines, but I'm not sure how to access the tensor over its axes to compute the variance at the end:
Oh, I see. Let's not worry about the axes version for now then!
@josevalim Done. Fixed the spaces, removed the standard_scale function, and added the ddof option to variance and standard_deviation, without axes for the time being.
nx/lib/nx.ex
Outdated
%T{shape: shape} = tensor = to_tensor(tensor) | ||
|
||
total = Tuple.product(shape) | ||
ddof = Keyword.get(opts, :ddof, 0) |
Suggested change:
```diff
- ddof = Keyword.get(opts, :ddof, 0)
+ ddof = Keyword.fetch!(opts, :ddof)
```
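For context on the suggestion: `Keyword.get/3` falls back to a default, while `Keyword.fetch!/2` raises when the key is absent, which is the stricter assertion once defaults are filled in upstream:

```elixir
Keyword.get([ddof: 1], :ddof, 0)  #=> 1
Keyword.get([], :ddof, 0)         #=> 0
Keyword.fetch!([ddof: 1], :ddof)  #=> 1
Keyword.fetch!([], :ddof)         #=> ** (KeyError) key :ddof not found in: []
```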
Done.
nx/lib/nx.ex
Outdated
tensor = to_tensor(tensor) | ||
|
Suggested change:
```diff
- tensor = to_tensor(tensor)
```
In this specific case we can just delegate to `variance`.
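A sketch of that delegation (assuming `variance/2` accepts the same options, as in the hunks above):

```elixir
def standard_deviation(tensor, opts \\ []) do
  # variance/2 already normalizes its input via to_tensor/1, so standard
  # deviation reduces to the square root of its result.
  tensor
  |> variance(opts)
  |> Nx.sqrt()
end
```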
Done.
💚 💙 💜 💛 ❤️
This PR adds some utility functions to Nx.