
Reactant: add extension to prevent stackoverflow #206

Merged
merged 1 commit into FluxML:master from wsmoses:reext on Jan 5, 2025

Conversation

wsmoses (Contributor) commented Jan 5, 2025

Add extension to prevent stackoverflow when using Reactant.

Origin of bug in FluxML/Fluxperimental.jl#28

simple train!: Error During Test at /Users/wmoses/git/Fluxperimental.jl/test/reactant.jl:152
  Got exception outside of a @test
  StackOverflowError:
  Stacktrace:
    [1] mlirOperationCreate
      @ ~/git/Reactant.jl/src/mlir/libMLIR_h.jl:1005 [inlined]
    [2] create_operation(name::String, loc::Reactant.MLIR.IR.Location; results::Vector{Reactant.MLIR.IR.Type}, operands::Vector{Reactant.MLIR.IR.Value}, owned_regions::Vector{Reactant.MLIR.IR.Region}, successors::Vector{Reactant.MLIR.IR.Block}, attributes::Vector{Reactant.MLIR.IR.NamedAttribute}, result_inference::Bool)
      @ Reactant.MLIR.IR ~/git/Reactant.jl/src/mlir/IR/Operation.jl:319
    [3] constant(; output::Reactant.MLIR.IR.Type, value::Reactant.MLIR.IR.Attribute, location::Reactant.MLIR.IR.Location)
      @ Reactant.MLIR.Dialects.stablehlo ~/git/Reactant.jl/src/mlir/Dialects/StableHLO.jl:1134
    [4] constant
      @ ~/git/Reactant.jl/src/mlir/Dialects/StableHLO.jl:1126 [inlined]
    [5] constant(x::Array{Float32, 0}; location::Reactant.MLIR.IR.Location)
      @ Reactant.Ops ~/git/Reactant.jl/src/Ops.jl:74
    [6] constant(x::Array{Float32, 0})
      @ Reactant.Ops ~/git/Reactant.jl/src/Ops.jl:69
    [7] promote_to(::Type{Reactant.TracedRNumber{Float32}}, rhs::Int64)
      @ Reactant.TracedRNumberOverrides ~/git/Reactant.jl/src/TracedRNumber.jl:70
    [8] TracedRNumber
      @ ~/git/Reactant.jl/src/TracedRNumber.jl:56 [inlined]
    [9] convert
      @ ./number.jl:7 [inlined]
   [10] zero
      @ ./number.jl:309 [inlined]
   [11] float
      @ ./float.jl:311 [inlined]
   [12] _eps(T::Type{Reactant.TracedRNumber{Float32}}, e::Float64) (repeats 79970 times)
      @ Optimisers ~/.julia/packages/Optimisers/a4OnF/src/utils.jl:20
Test Summary:                           | Pass  Error  Broken  Total     Time

CarloLucibello (Member) commented:

For the learning-rate problem in #205, on the other hand, do we have to relax the type constraint?

wsmoses (Contributor, Author) commented Jan 5, 2025

These are two different issues, but yes, for that one the learning rate will need to be relaxed. Essentially, TracedRArrays/TracedRNumbers represent data being compiled. A traced value is never convertible to a regular float (because at compile time the data doesn't exist; it represents a potential input, or a set of operations on that input). Thus if you do f(TracedRValue) for any f, you'll get another traced return type (unless of course the function returns a constant or doesn't depend on the input).

For the learning rate, if you want the learning rate to be a parameter to the compiled function (and not baked into the compiled function), it needs to be a traced value. Of course you could alternatively have the learning rate be a regular Float64, but then the compiled function will always use whatever learning rate you passed in at compile time.
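
To make the trade-off concrete, here is a minimal sketch, assuming Reactant's @compile / ConcreteRArray / ConcreteRNumber API (scale is a made-up function for illustration):

using Reactant

scale(x, eta) = eta .* x

x = Reactant.ConcreteRArray(ones(Float32, 4))

# Option 1: eta is a traced argument, so it remains a runtime parameter;
# the compiled function can later be called with a different rate.
eta = Reactant.ConcreteRNumber(0.1f0)
scale_param = @compile scale(x, eta)
scale_param(x, Reactant.ConcreteRNumber(0.01f0))

# Option 2: eta is a plain Float32 at trace time, so 0.1f0 is baked into
# the compiled code and every later call uses that constant rate.
scale_baked = @compile scale(x, 0.1f0)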

CarloLucibello merged commit 12b7f31 into FluxML:master on Jan 5, 2025
3 of 4 checks passed
mcabbott (Member) commented Jan 5, 2025

Do we really need an extension for this? The code is:

_eps(T::Type{<:AbstractFloat}, e) = T(e)
_eps(T::Type{<:Number}, e) = _eps(real(float(T)), e) 

written assuming that real(float(T)) must produce an AbstractFloat. That's apparently false here, but it would also be false for, say, Dual numbers, or maybe Unitful quantities, or who knows what else.

The second method could simply be this:

_eps(T::Type{<:Number}, e) = real(float(T))(e) 

which will make one attempt to convert & then give up. Why not do that?
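
As a quick sanity check of the one-attempt version (a sketch, with results as comments):

_eps(T::Type{<:AbstractFloat}, e) = T(e)
_eps(T::Type{<:Number}, e) = real(float(T))(e)

_eps(Float32, 1e-7)       # 1.0f-7   (direct hit on the first method)
_eps(Int, 1e-3)           # 0.001    (real(float(Int)) === Float64)
_eps(Complex{Int}, 1e-3)  # 0.001    (real strips the imaginary part)
# A type whose float(T) is itself now gets one constructor-style call,
# which either succeeds or raises a MethodError instead of overflowing.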

wsmoses (Contributor, Author) commented Jan 5, 2025

Yeah, honestly this is a bug in _eps here IMO.

So presently it is

_eps(T::Type{<:AbstractFloat}, e) = T(e)
# catch complex and integers
_eps(T::Type{<:Number}, e) = _eps(real(float(T)), e) 
# avoid small e being rounded to zero
_eps(T::Type{Float16}, e) = e == 0 ? T(0) : max(T(1e-7), T(e))

So essentially the float function takes whatever the type is and converts it into a floating-point version of that type.

e.g.

julia> float(3)
3.0

julia> float(Complex(3))
3.0 + 0.0im

We have a version of float that takes a traced number (e.g. TracedRNumber{Int} or TracedRNumber{Complex{Int}}) and converts it to a TracedRNumber{Float64}, etc.

So take the second method, _eps(T::Type{<:Number}, e) = _eps(real(float(T)), e): if T is TracedRNumber{Float64}, then float(T) is also TracedRNumber{Float64}. Of course that is already real, so real(float(T)) is TracedRNumber{Float64} again. The call is therefore literally
_eps(TracedRNumber{Float64}, x) = _eps(TracedRNumber{Float64}, x), so it recurses infinitely.

Honestly, if you want to catch the complex/integer cases, you could just special-case the method on Integer/Complex instead of on Number; a toy sketch of both the bug and that fix follows.
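
To see both the recursion and the suggested special-casing in isolation, here is a toy sketch (Traced is a made-up stand-in, not Reactant's actual type):

# A Number that is its own float() and real(), but not an AbstractFloat:
struct Traced{T<:Number} <: Number
    val::T
end
Base.float(::Type{Traced{T}}) where {T} = Traced{float(T)}
Base.real(::Type{Traced{T}}) where {T} = Traced{real(T)}

# The current definitions:
_eps(T::Type{<:AbstractFloat}, e) = T(e)
_eps(T::Type{<:Number}, e) = _eps(real(float(T)), e)
# _eps(Traced{Float64}, 1e-8)  # real(float(T)) === T, so the second
                               # method calls itself -> StackOverflowError

# Special-casing on Integer/Complex instead of Number avoids the trap:
_eps2(T::Type{<:AbstractFloat}, e) = T(e)
_eps2(T::Type{<:Union{Integer,Complex}}, e) = _eps2(real(float(T)), e)
# _eps2(Traced{Float64}, 1e-8) # now a MethodError, not a stack overflow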

wsmoses (Contributor, Author) commented Jan 5, 2025

I mean sure, go for it.

I'm not sufficiently familiar with the internals of this package to have known that change would be okay (e.g. someone later down the line really assuming it needs to be a float).

Which could eventually be hit here: #205

mcabbott (Member) commented Jan 5, 2025

I guess I wish the description of what the problem actually is had been in the first message (since it requires knowing nothing about fancy Reactant stuff, just number types), and that we didn't try to merge the first draft at lightspeed.

Want to revert and simplify?

For #205, is there a clear writeup of all of this somewhere? At some point numbers became 0-dim arrays in a way that I didn't understand; maybe this has changed... what are we actually designing around?

wsmoses (Contributor, Author) commented Jan 5, 2025

We no longer have that 0-dim array stuff. Instead there are now just two types: TracedRNumber [for numbers] and TracedRArray [for arrays].

Essentially the same thing I said above applies to traced types in general (and of course you can getindex a TracedRArray and get a TracedRNumber); a minimal sketch is below.
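
A minimal sketch, assuming Reactant's @compile / ConcreteRArray API:

using Reactant

first_plus_one(a) = a[1] + 1  # while tracing, a is a TracedRArray and
                              # a[1] is a TracedRNumber, never a Float32

x = Reactant.ConcreteRArray(Float32[1, 2, 3])
f = @compile first_plus_one(x)
f(x)  # 2.0f0, returned as a concrete Reactant number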

wsmoses added a commit to wsmoses/Optimisers.jl that referenced this pull request Jan 5, 2025
wsmoses mentioned this pull request Jan 5, 2025
wsmoses deleted the reext branch Jan 5, 2025
mcabbott added a commit that referenced this pull request Jan 5, 2025
* Revert "Reactant: add extension to prevent stackoverflow (#206)"

This reverts commit 12b7f31.

* Change eps to not be recursive

* Update src/utils.jl

---------

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>