Use/export LogExpFunctions.jl? #252

Open
devmotion opened this issue Dec 22, 2020 · 14 comments

@devmotion
Contributor

The implementation of logsumexp in StatsFuns is quite optimized (see, e.g., JuliaStats/StatsFuns.jl#97): it works on GPUs, is numerically more stable than the implementation in NNlib, and uses a one-pass algorithm.

I am wondering if NNlib should remove its own implementation and just reexport StatsFuns.logsumexp?

More generally, maybe it would make sense to unify some of the duplicate implementations in both packages of, e.g., softmax, softmax!, sigmoid, and softplus?
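
For reference, here is a minimal sketch of the one-pass idea mentioned above (illustration only; the actual StatsFuns implementation is more general and more carefully optimized). The running maximum is updated on the fly and the accumulated sum is rescaled whenever a larger element is seen, so the data is traversed only once and no intermediate exp.(x .- maximum(x)) array is materialized:

# Minimal sketch of a streaming one-pass logsumexp (illustration only).
function logsumexp_onepass(xs)
    m = -Inf   # running maximum
    s = 0.0    # running sum of exp(x - m) over the elements seen so far
    for x in xs
        if x <= m
            s += exp(x - m)
        else
            s = s * exp(m - x) + 1   # rescale the accumulated sum to the new maximum
            m = x
        end
    end
    return m + log(s)
end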

@CarloLucibello
Member

CarloLucibello commented Dec 26, 2020

I am wondering if NNlib should remove its own implementation and just reexport StatsFuns.logsumexp?

I guess we should. I'm just wary of adding another dependency; there have already been some complaints about latency (see #224).
Any hope StatsFuns could ditch its Rmath dependency?

More generally, maybe it would make sense to unify some of the duplicate implementations in both packages of, e.g., softmax, softmax!, sigmoid, and softplus?

Are all these GPU and AD friendly?

@devmotion
Contributor Author

Any hope StatsFuns could ditch its Rmath dependency?

I don't know the maintainers' exact plans, but I think the dependency is supposed to be removed eventually. There are some open issues regarding Rmath (e.g., JuliaStats/Distributions.jl#1509), and there was a discussion about moving the log/exp functions to a separate package (JuliaStats/StatsFuns.jl#46).

Are all these GPU and AD friendly?

IIRC not, which is why I used the term "unify" here. I just checked, and it seems softmax (without a dims argument), softplus, logit, and logistic work with CuArray, but for some of them I get warnings such as

┌ Warning: calls to Base intrinsics might be GPU incompatible
│   exception =
│    You called log(x::Float32) in Base.Math at special/log.jl:289, maybe you intended to call log(x::Float32) in CUDA at /home/davwi492/.julia/packages/CUDA/YeS8q/src/device/intrinsics/math.jl:73 instead?
│    Stacktrace:
│     [1] log at special/log.jl:289
│     [2] broadcast_kernel at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:59
└ @ GPUCompiler /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/irgen.jl:68

invsoftplus throws an error, though; the problem can be reduced to

julia> map(log ∘ expm1, CUDA.rand(Float32, 5))
┌ Warning: calls to Base intrinsics might be GPU incompatible
│   exception =
│    You called log(x::Float32) in Base.Math at special/log.jl:289, maybe you intended to call log(x::Float32) in CUDA at /home/davwi492/.julia/packages/CUDA/YeS8q/src/device/intrinsics/math.jl:73 instead?
│    Stacktrace:
│     [1] log at special/log.jl:289
│     [2] broadcast_kernel at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:59
└ @ GPUCompiler /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/irgen.jl:68
ERROR: InvalidIRError: compiling kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceArray{Float32,1,1}, Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},Base.var"#62#63"{typeof(log),typeof(expm1)},Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}}, Int64) resulted in invalid LLVM IR
Reason: unsupported call to the Julia runtime (call to jl_f_tuple)
Stacktrace:
 [1] expm1 at math.jl:367
 [2] #62 at operators.jl:875
 [3] _broadcast_getindex_evalf at broadcast.jl:648
 [4] _broadcast_getindex at broadcast.jl:621
 [5] getindex at broadcast.jl:575
 [6] broadcast_kernel at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:62
Stacktrace:
 [1] check_ir(::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget,CUDA.CUDACompilerParams}, ::LLVM.Module) at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/validation.jl:123
 [2] macro expansion at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:239 [inlined]
 [3] macro expansion at /home/davwi492/.julia/packages/TimerOutputs/ZmKD7/src/TimerOutput.jl:206 [inlined]
 [4] codegen(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:237
 [5] compile(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:39
 [6] compile at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:35 [inlined]
 [7] cufunction_compile(::GPUCompiler.FunctionSpec; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/davwi492/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:310
 [8] cufunction_compile(::GPUCompiler.FunctionSpec) at /home/davwi492/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:305
 [9] check_cache(::Dict{UInt64,Any}, ::Any, ::Any, ::GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#12",Tuple{CUDA.CuKernelContext,CuDeviceArray{Float32,1,1},Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},Base.var"#62#63"{typeof(log),typeof(expm1)},Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Int64}}, ::UInt64; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:40
 [10] broadcast_kernel at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:60 [inlined]
 [11] cached_compilation at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:65 [inlined]
 [12] cufunction(::GPUArrays.var"#broadcast_kernel#12", ::Type{Tuple{CUDA.CuKernelContext,CuDeviceArray{Float32,1,1},Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},Base.var"#62#63"{typeof(log),typeof(expm1)},Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Int64}}; name::Nothing, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/davwi492/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:297
 [13] cufunction at /home/davwi492/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:294 [inlined]
 [14] #launch_heuristic#853 at /home/davwi492/.julia/packages/CUDA/YeS8q/src/gpuarrays.jl:19 [inlined]
 [15] launch_heuristic at /home/davwi492/.julia/packages/CUDA/YeS8q/src/gpuarrays.jl:17 [inlined]
 [16] copyto! at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:66 [inlined]
 [17] copyto! at ./broadcast.jl:886 [inlined]
 [18] copy at ./broadcast.jl:862 [inlined]
 [19] materialize(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,Base.var"#62#63"{typeof(log),typeof(expm1)},Tuple{CuArray{Float32,1}}}) at ./broadcast.jl:837
 [20] map(::Function, ::CuArray{Float32,1}) at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:89
 [21] top-level scope at REPL[52]:1

@devmotion
Contributor Author

BTW StatsFuns already depends on ChainRulesCore implicitly via SpecialFunctions, so it seems custom ChainRules-based adjoints could be added to StatsFuns without introducing any additional dependencies.
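
As a rough illustration of what such an adjoint could look like (a sketch only, assuming StatsFuns' logistic and the current ChainRulesCore API, not the rule that was eventually added), one could write:

using ChainRulesCore
using StatsFuns: logistic

# Sketch of a ChainRules-based adjoint for `logistic`;
# the derivative of the logistic function is σ(x) * (1 - σ(x)).
function ChainRulesCore.rrule(::typeof(logistic), x::Real)
    y = logistic(x)
    logistic_pullback(Δ) = (NoTangent(), Δ * y * (1 - y))
    return y, logistic_pullback
end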

@cossio

cossio commented Jul 9, 2021

We should use https://github.com/JuliaStats/LogExpFunctions.jl, which doesn't depend on Rmath.
Note that StatsFuns just re-exports the functions from LogExpFunctions.
See #331.

@devmotion
Contributor Author

Yes, this issue was one motivation for moving the functions to LogExpFunctions 🙂

CarloLucibello changed the title from "Use/export StatsFuns.logsumexp?" to "Use/export LogSumExp.jl?" on Jul 11, 2021
CarloLucibello changed the title from "Use/export LogSumExp.jl?" to "Use/export LogExpFunctions.jl?" on Jul 11, 2021
@CarloLucibello
Member

LogExpFunctions.jl should define the rrules. We could do it here, but the original repo is the natural place.

Also, if we need this, we'll need to define separate implementations for CuArrays in NNlibCUDA.

@devmotion
Contributor Author

FYI, I recently added the ChainRules definitions to LogExpFunctions.

@DhairyaLGandhi
Member

Great, we can move some of the definitions there

@devmotion
Contributor Author

Which definitions? ChainRules? LogExpFunctions already contains derivatives for all functions it defines.

@CarloLucibello
Member

FYI, I recently added the ChainRules definitions to LogExpFunctions.

Not something that we typically pay much attention to (although we should!), but are the rules themselves differentiable?

@devmotion
Contributor Author

Nobody has tested it, but they should be, since they only involve basic functions or functions from LogExpFunctions for which rules are defined: https://github.com/JuliaStats/LogExpFunctions.jl/blob/master/src/chainrules.jl. For logsumexp and softmax in particular, it might be more efficient to use custom second derivatives instead of differentiating through the rules. In general, I don't know if anyone has ever differentiated through rrules and frules (I would assume someone has tried at least?).
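
As an illustration of why such rules can be differentiated again (a sketch in the spirit of the linked file, not its exact code; the helper _softmax is only for the example), the softmax pullback needs nothing beyond broadcasting and a sum, with no mutation:

using ChainRulesCore

# Sketch: an rrule for a vector softmax whose pullback avoids mutation,
# so it can in principle be differentiated again by an AD system.
_softmax(x) = (y = exp.(x .- maximum(x)); y ./ sum(y))

function ChainRulesCore.rrule(::typeof(_softmax), x::AbstractVector)
    y = _softmax(x)
    function softmax_pullback(Δ)
        ȳ = unthunk(Δ)
        x̄ = y .* (ȳ .- sum(ȳ .* y))   # Jᵀȳ with J = Diagonal(y) - y*y'
        return (NoTangent(), x̄)
    end
    return y, softmax_pullback
end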

@mcabbott
Member

mcabbott commented Nov 6, 2021

There are a few rules which themselves have rules, as for sum here: https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/mapreduce.jl#L36

Those look likely to work, to me, although perhaps you could find ways to make the second one more efficient.

Why do they have Ωcopy = copy(Ω) though?

@devmotion
Contributor Author

devmotion commented Nov 6, 2021

Mutation of the primal result of softmax leads to an incorrect pullback. The copy ensures that the pullback is always correct, regardless of downstream computations.
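
A hypothetical illustration of the hazard (the names softmax, x, and Δ are placeholders for the example, not a particular implementation):

# If the pullback closed over the primal result `y` without copying it,
# in-place mutation of `y` by downstream code would corrupt the gradient:
y, softmax_pullback = rrule(softmax, x)
y .= max.(y, 0.01)            # downstream code mutates the primal result in place
_, x̄ = softmax_pullback(Δ)    # the pullback now sees the mutated values -> wrong gradient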

@mcabbott
Member

mcabbott commented Nov 6, 2021

Sure, I guess I mean, did this come up somewhere?

Every rule in https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/arraymath.jl (except for + & -) closes over things without preventative copies. So they rely on you not mutating arguments/results elsewhere. Changing that seems like it would roughly double memory usage.

Looking in the docs quickly, I don't actually see mention of such questions. Maybe @oxinabox has thoughts?
