Reactant: add extension to prevent stackoverflow #206
Conversation
For the problem with the learning rate in #205, instead, do we have to relax the type constraint?
These are two different issues, but yes, for that issue the learning rate will need to be relaxed. Essentially, `TracedRArray`s/`TracedRNumber`s represent data being compiled. A traced value is never convertible to a regular float, because at compile time the data doesn't exist; it represents a potential input, or a set of operations on that input.

For the learning rate, if you want it to be a parameter of the compiled function (and not baked into the compiled function), it needs to be a traced value. Of course, you could alternatively have the learning rate be a regular `Float64`, but then the compiled function will always use whatever learning rate you passed in at compile time.
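The baked-in versus parameter distinction can be illustrated with a plain-Julia closure analogy (no Reactant involved; `step` and `make_step` are hypothetical names invented for this sketch):

```julia
# A toy SGD update. Passing lr as an argument is like making it a traced
# input to the compiled function: it can change on every call.
step(w, g, lr) = w - lr * g

# Capturing lr in a closure is like baking a plain Float64 into the
# compiled function: the value is fixed at definition ("compile") time.
make_step(lr) = (w, g) -> w - lr * g

step_fixed = make_step(0.1)
step_fixed(1.0, 2.0)    # always uses lr = 0.1 → 0.8
step(1.0, 2.0, 0.05)    # lr supplied per call → 0.9
```

In Reactant terms, the closure corresponds to passing a plain `Float64` at compile time, while the extra argument corresponds to a traced learning rate.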
Do we really need an extension for this? The code is written apparently assuming that `real(float(T))` always produces an `AbstractFloat`. The second method could simply make one attempt to convert and then give up. Why not do that?
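The snippet itself was elided from this thread, but a non-recursive version along the lines described might look like this (a sketch, not the actual Flux code, and it omits the `Float16` special case):

```julia
_eps(T::Type{<:AbstractFloat}, e) = T(e)

# One conversion attempt, no recursion: if real(float(T)) is still not an
# AbstractFloat, construction throws a MethodError instead of overflowing.
_eps(T::Type{<:Number}, e) = real(float(T))(e)

_eps(Float32, 1e-4)        # Float32(1e-4)
_eps(Int, 1e-4)            # converts via Float64
_eps(Complex{Int}, 1e-4)   # real(Complex{Float64}) is Float64
```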
Yeah, honestly this is a bug in `_eps` here IMO. Presently it is:

```julia
_eps(T::Type{<:AbstractFloat}, e) = T(e)
# catch complex and integers
_eps(T::Type{<:Number}, e) = _eps(real(float(T)), e)
# avoid small e being rounded to zero
_eps(T::Type{Float16}, e) = e == 0 ? T(0) : max(T(1e-7), T(e))
```

Essentially, the `float` function takes whatever the type is and converts it into a floating-point version of that type, e.g.

```julia
julia> float(3)
3.0

julia> float(Complex(3))
3.0 + 0.0im
```

We have a version of `float` that takes a traced number (e.g. `TracedRNumber{Int}` or `TracedRNumber{Complex{Int}}`) and converts it to a `TracedRNumber{Float64}`, etc. So the second case recurses forever: a `TracedRNumber{Float64}` is a `Number` but not an `AbstractFloat`, so `_eps(T::Type{<:Number}, e)` keeps dispatching to itself and you get the stack overflow. Honestly, if you want to catch the complex/int stuff, you could just special-case it on int/complex instead of on `Number`.
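To see concretely why the `Number` method loops on traced types, here is a minimal stand-in (the `Traced` wrapper is invented for this sketch; it mimics only the one relevant property of `TracedRNumber`, namely that `float` of it stays inside the wrapper and never yields an `AbstractFloat`):

```julia
# A wrapper that is a Number but never an AbstractFloat, like TracedRNumber.
struct Traced{T<:Number} <: Number end

Base.float(::Type{Traced{T}}) where {T} = Traced{float(T)}
Base.real(::Type{Traced{T}}) where {T} = Traced{real(T)}

# The fallback _eps methods from above:
_eps(T::Type{<:AbstractFloat}, e) = T(e)
_eps(T::Type{<:Number}, e) = _eps(real(float(T)), e)

# real(float(Traced{Float64})) is Traced{Float64} again, so the Number
# method re-dispatches to itself forever; calling
# _eps(Traced{Float64}, 0.1) throws a StackOverflowError.
real(float(Traced{Float64})) === Traced{Float64}  # true
```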
I mean, sure, go for it. I'm not sufficiently familiar with the internals of this package that I would have known that change to be okay (e.g. someone later down the line really assuming it needs to be a float), which eventually could be hit here: #205
I guess I wish the description of what the problem actually is had been in the first message (since it requires knowing nothing about fancy Reactant stuff, just number types), and that we didn't try to merge the first draft at lightspeed. Want to revert and simplify? For #205, is there a clear writeup of all of this somewhere? At some point numbers became 0-dim arrays in some way that I didn't understand; maybe this has changed. What are we actually designing around?
We no longer have that 0-dim array stuff. Instead we now have two types: `TracedRNumber` (for numbers) and `TracedRArray` (for arrays). Essentially the same thing I said above applies to general traced types (and of course you can `getindex` a `TracedRArray` and get a `TracedRNumber`).
* Revert "Reactant: add extension to prevent stackoverflow (#206)" (this reverts commit 12b7f31)
* Change `eps` to not be recursive
* Update src/utils.jl

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Add an extension to prevent a stack overflow when using Reactant.
Origin of the bug: FluxML/Fluxperimental.jl#28