Skip to content

Commit

Permalink
Avoid dynamic dispatch on callback
Browse files Browse the repository at this point in the history
  • Loading branch information
giordano committed Mar 4, 2017
1 parent cc004de commit 7256436
Showing 1 changed file with 99 additions and 104 deletions.
203 changes: 99 additions & 104 deletions src/Cuba.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,6 @@ const KEY = 0

### Functions

# Return pointer for "integrand", to be passed as "integrand" argument to Cuba
# functions.
integrand_ptr(integrand::Function) = cfunction(integrand, Cint,
(Ref{Cint}, # ndim
Ptr{Cdouble}, # x
Ref{Cint}, # ncomp
Ptr{Cdouble}, # f
Ptr{Void})) # userdata

# Note on implementation: instead of passing the function that performs
# calculations as "integrand" argument to integrator routines, we pass the
# pointer to this function and use "func_" to actually perform calculations.
Expand All @@ -93,58 +84,62 @@ integrand_ptr(integrand::Function) = cfunction(integrand, Cint,
# this, in particular the section about "qsort_r" ("Passing closures via
# pass-through pointers"). Thanks to Steven G. Johnson for pointing to this.
function generic_integrand!(ndim::Cint, x_::Ptr{Cdouble}, ncomp::Cint,
f_::Ptr{Cdouble}, func_::Ptr{Void})
f_::Ptr{Cdouble}, func!)
# Get arrays from "x_" and "f_" pointers.
x = unsafe_wrap(Array, x_, (ndim,))
f = unsafe_wrap(Array, f_, (ncomp,))
# Get the function from "func_" pointer.
func! = unsafe_pointer_to_objref(func_)
func!(x, f)
return Cint(0)
end

# Return pointer for "integrand", to be passed as "integrand" argument to Cuba functions.
integrand_ptr{T}(integrand::T) = cfunction(generic_integrand!, Cint,
(Ref{Cint}, # ndim
Ptr{Cdouble}, # x
Ref{Cint}, # ncomp
Ptr{Cdouble}, # f
Ref{typeof(integrand)})) # userdata

function __init__()
# This pointer needs to be available at runtime. See "Module initialization
# and precompilation" section of Julia manual.
global const c_generic_integrand! = integrand_ptr(generic_integrand!)
Cuba.cores(0, 10000)
end

# One function to rule them all.
for (CubaInt, prefix) in ((Int32, ""), (Int64, "ll"))
@eval begin
function dointegrate(algorithm::Symbol,
# First common arguments.
integrand::Ptr{Void}, ndim::Integer, ncomp::Integer,
userdata::Ptr{Void}, nvec::$CubaInt, reltol::Real,
abstol::Real, flags::Integer, seed::Integer,
minevals::$CubaInt, maxevals::$CubaInt,
# Vegas-specific arguments.
nstart::$CubaInt, nincrease::$CubaInt,
nbatch::$CubaInt, gridno::Integer,
# Suave-specific arguments.
nnew::$CubaInt, nmin::$CubaInt, flatness::Real,
# Divonne-specific arguments.
key1::Integer, key2::Integer, key3::Integer,
maxpass::Integer, border::Real, maxchisq::Real,
mindeviation::Real, ngiven::$CubaInt, ldxgiven::Integer,
xgiven::Any, nextra::$CubaInt, peakfinder::Ptr{Void},
# Cuhre-specific argument.
key::Integer,
# Final common arguments.
statefile::AbstractString, spin::Ptr{Void})
nregions = Ref{Cint}(0)
neval = Ref{$CubaInt}(0)
fail = Ref{Cint}(0)
integral = zeros(Cdouble, ncomp)
error = zeros(Cdouble, ncomp)
prob = zeros(Cdouble, ncomp)
function dointegrate{T}(algorithm::Symbol,
# First common arguments.
func::T, ndim::Integer, ncomp::Integer,
nvec::$CubaInt, reltol::Real,
abstol::Real, flags::Integer, seed::Integer,
minevals::$CubaInt, maxevals::$CubaInt,
# Vegas-specific arguments.
nstart::$CubaInt, nincrease::$CubaInt,
nbatch::$CubaInt, gridno::Integer,
# Suave-specific arguments.
nnew::$CubaInt, nmin::$CubaInt, flatness::Real,
# Divonne-specific arguments.
key1::Integer, key2::Integer, key3::Integer,
maxpass::Integer, border::Real, maxchisq::Real,
mindeviation::Real, ngiven::$CubaInt, ldxgiven::Integer,
xgiven::Any, nextra::$CubaInt, peakfinder::Ptr{Void},
# Cuhre-specific argument.
key::Integer,
# Final common arguments.
statefile::AbstractString, spin::Ptr{Void})
integrand = integrand_ptr(func)
nregions = Ref{Cint}(0)
neval = Ref{$CubaInt}(0)
fail = Ref{Cint}(0)
integral = zeros(Cdouble, ncomp)
error = zeros(Cdouble, ncomp)
prob = zeros(Cdouble, ncomp)
if algorithm == :Cuhre
ccall(($(prefix * "Cuhre"), libcuba), Cdouble,
(Cint, # ndim
Cint, # ncomp
Ptr{Void}, # integrand
Ptr{Void}, # userdata
Any, # userdata
$CubaInt, # nvec
Cdouble, # reltol
Cdouble, # abstol
Expand All @@ -161,7 +156,7 @@ for (CubaInt, prefix) in ((Int32, ""), (Int64, "ll"))
Ptr{Cdouble}, # error
Ptr{Cdouble}),# prob
# Input
ndim, ncomp, integrand, userdata, nvec, reltol,
ndim, ncomp, integrand, func, nvec, reltol,
abstol, flags, minevals, maxevals, key, statefile, spin,
# Output
nregions, neval, fail, integral, error, prob)
Expand All @@ -170,7 +165,7 @@ for (CubaInt, prefix) in ((Int32, ""), (Int64, "ll"))
(Cint, # ndim
Cint, # ncomp
Ptr{Void}, # integrand
Ptr{Void}, # userdata
Any, # userdata
$CubaInt, # nvec
Cdouble, # reltol
Cdouble, # abstol
Expand All @@ -190,7 +185,7 @@ for (CubaInt, prefix) in ((Int32, ""), (Int64, "ll"))
Ptr{Cdouble}, # error
Ptr{Cdouble}),# prob
# Input
ndim, ncomp, integrand, userdata, nvec,
ndim, ncomp, integrand, func, nvec,
reltol, abstol, flags, seed, minevals, maxevals,
nstart, nincrease, nbatch, gridno, statefile, spin,
# Output
Expand All @@ -200,7 +195,7 @@ for (CubaInt, prefix) in ((Int32, ""), (Int64, "ll"))
(Cint, # ndim
Cint, # ncomp
Ptr{Void}, # integrand
Ptr{Void}, # userdata
Any, # userdata
$CubaInt, # nvec
Cdouble, # reltol
Cdouble, # abstol
Expand All @@ -220,7 +215,7 @@ for (CubaInt, prefix) in ((Int32, ""), (Int64, "ll"))
Ptr{Cdouble}, # error
Ptr{Cdouble}),# prob
# Input
ndim, ncomp, integrand, userdata, nvec,
ndim, ncomp, integrand, func, nvec,
reltol, abstol, flags, seed, minevals, maxevals,
nnew, nmin, flatness, statefile, spin,
# Output
Expand All @@ -230,7 +225,7 @@ for (CubaInt, prefix) in ((Int32, ""), (Int64, "ll"))
(Cint, # ndim
Cint, # ncomp
Ptr{Void}, # integrand
Ptr{Void}, # userdata
Any, # userdata
$CubaInt, # nvec
Cdouble, # reltol
Cdouble, # abstol
Expand Down Expand Up @@ -259,7 +254,7 @@ for (CubaInt, prefix) in ((Int32, ""), (Int64, "ll"))
Ptr{Cdouble}, # error
Ptr{Cdouble}),# prob
# Input
ndim, ncomp, integrand, userdata, nvec, reltol,
ndim, ncomp, integrand, func, nvec, reltol,
abstol, flags, seed, minevals, maxevals, key1, key2, key3,
maxpass, border, maxchisq, mindeviation, ngiven, ldxgiven,
xgiven, nextra, peakfinder, statefile, spin,
Expand Down Expand Up @@ -299,22 +294,22 @@ vegas, llvegas
for (CubaInt, prefix) in ((Int32, ""), (Int64, "ll"))
func = Symbol(prefix, "vegas")
@eval begin
$func(integrand::Function, ndim::Integer=1, ncomp::Integer=1;
nvec::Integer=NVEC,
reltol::Real=RELTOL, abstol::Real=ABSTOL, flags::Integer=FLAGS,
seed::Integer=SEED, minevals::Real=MINEVALS, maxevals::Real=MAXEVALS,
nstart::Integer=NSTART, nincrease::Integer=NINCREASE,
nbatch::Integer=NBATCH, gridno::Integer=GRIDNO,
statefile::AbstractString=STATEFILE, spin::Ptr{Void}=SPIN) =
dointegrate(:Vegas, c_generic_integrand!::Ptr{Void}, ndim, ncomp,
pointer_from_objref(integrand), $CubaInt(nvec), float(reltol),
float(abstol), flags, seed, trunc($CubaInt, minevals),
trunc($CubaInt, maxevals), $CubaInt(nstart),
$CubaInt(nincrease), $CubaInt(nbatch),
gridno, $CubaInt(NNEW), $CubaInt(NMIN), FLATNESS, KEY1, KEY2,
KEY3, MAXPASS, BORDER, MAXCHISQ, MINDEVIATION,
$CubaInt(NGIVEN), LDXGIVEN, XGIVEN,
$CubaInt(NEXTRA), PEAKFINDER, KEY, statefile, spin)
$func{T}(integrand::T, ndim::Integer=1, ncomp::Integer=1;
nvec::Integer=NVEC,
reltol::Real=RELTOL, abstol::Real=ABSTOL, flags::Integer=FLAGS,
seed::Integer=SEED, minevals::Real=MINEVALS, maxevals::Real=MAXEVALS,
nstart::Integer=NSTART, nincrease::Integer=NINCREASE,
nbatch::Integer=NBATCH, gridno::Integer=GRIDNO,
statefile::AbstractString=STATEFILE, spin::Ptr{Void}=SPIN) =
dointegrate(:Vegas, integrand, ndim, ncomp,
$CubaInt(nvec), float(reltol),
float(abstol), flags, seed, trunc($CubaInt, minevals),
trunc($CubaInt, maxevals), $CubaInt(nstart),
$CubaInt(nincrease), $CubaInt(nbatch),
gridno, $CubaInt(NNEW), $CubaInt(NMIN), FLATNESS, KEY1, KEY2,
KEY3, MAXPASS, BORDER, MAXCHISQ, MINDEVIATION,
$CubaInt(NGIVEN), LDXGIVEN, XGIVEN,
$CubaInt(NEXTRA), PEAKFINDER, KEY, statefile, spin)
end
end

Expand Down Expand Up @@ -345,20 +340,20 @@ suave, llsuave
for (CubaInt, prefix) in ((Int32, ""), (Int64, "ll"))
func = Symbol(prefix, "suave")
@eval begin
$func(integrand::Function, ndim::Integer=1, ncomp::Integer=1;
nvec::Integer=NVEC, reltol::Real=RELTOL,
abstol::Real=ABSTOL, flags::Integer=FLAGS, seed::Integer=SEED,
minevals::Real=MINEVALS, maxevals::Real=MAXEVALS, nnew::Integer=NNEW,
nmin::Integer=NMIN, flatness::Real=FLATNESS,
statefile::AbstractString=STATEFILE, spin::Ptr{Void}=SPIN) =
dointegrate(:Suave, c_generic_integrand!::Ptr{Void}, ndim, ncomp,
pointer_from_objref(integrand), $CubaInt(nvec), float(reltol),
float(abstol), flags, seed, trunc($CubaInt, minevals),
trunc($CubaInt, maxevals), $CubaInt(NSTART),
$CubaInt(NINCREASE), $CubaInt(NBATCH), GRIDNO, $CubaInt(nnew),
$CubaInt(nmin), flatness, KEY1, KEY2, KEY3, MAXPASS,
BORDER, MAXCHISQ, MINDEVIATION, $CubaInt(NGIVEN), LDXGIVEN,
XGIVEN, $CubaInt(NEXTRA), PEAKFINDER, KEY, statefile, spin)
$func{T}(integrand::T, ndim::Integer=1, ncomp::Integer=1;
nvec::Integer=NVEC, reltol::Real=RELTOL,
abstol::Real=ABSTOL, flags::Integer=FLAGS, seed::Integer=SEED,
minevals::Real=MINEVALS, maxevals::Real=MAXEVALS, nnew::Integer=NNEW,
nmin::Integer=NMIN, flatness::Real=FLATNESS,
statefile::AbstractString=STATEFILE, spin::Ptr{Void}=SPIN) =
dointegrate(:Suave, integrand, ndim, ncomp,
$CubaInt(nvec), float(reltol),
float(abstol), flags, seed, trunc($CubaInt, minevals),
trunc($CubaInt, maxevals), $CubaInt(NSTART),
$CubaInt(NINCREASE), $CubaInt(NBATCH), GRIDNO, $CubaInt(nnew),
$CubaInt(nmin), flatness, KEY1, KEY2, KEY3, MAXPASS,
BORDER, MAXCHISQ, MINDEVIATION, $CubaInt(NGIVEN), LDXGIVEN,
XGIVEN, $CubaInt(NEXTRA), PEAKFINDER, KEY, statefile, spin)
end
end

Expand Down Expand Up @@ -398,28 +393,28 @@ divonne, lldivonne
for (CubaInt, prefix) in ((Int32, ""), (Int64, "ll"))
func = Symbol(prefix, "divonne")
@eval begin
function $func{R<:Real}(integrand::Function, ndim::Integer=1, ncomp::Integer=1;
nvec::Integer=NVEC, reltol::Real=RELTOL,
abstol::Real=ABSTOL, flags::Integer=FLAGS,
seed::Integer=SEED, minevals::Real=MINEVALS,
maxevals::Real=MAXEVALS, key1::Integer=KEY1,
key2::Integer=KEY2, key3::Integer=KEY3,
maxpass::Integer=MAXPASS, border::Real=BORDER,
maxchisq::Real=MAXCHISQ,
mindeviation::Real=MINDEVIATION,
ngiven::Integer=NGIVEN, ldxgiven::Integer=LDXGIVEN,
xgiven::AbstractArray{R}=zeros(Cdouble, ldxgiven,
ngiven),
nextra::Integer=NEXTRA,
peakfinder::Ptr{Void}=PEAKFINDER,
statefile::AbstractString=STATEFILE,
spin::Ptr{Void}=SPIN)
function $func{T,R<:Real}(integrand::T, ndim::Integer=1, ncomp::Integer=1;
nvec::Integer=NVEC, reltol::Real=RELTOL,
abstol::Real=ABSTOL, flags::Integer=FLAGS,
seed::Integer=SEED, minevals::Real=MINEVALS,
maxevals::Real=MAXEVALS, key1::Integer=KEY1,
key2::Integer=KEY2, key3::Integer=KEY3,
maxpass::Integer=MAXPASS, border::Real=BORDER,
maxchisq::Real=MAXCHISQ,
mindeviation::Real=MINDEVIATION,
ngiven::Integer=NGIVEN, ldxgiven::Integer=LDXGIVEN,
xgiven::AbstractArray{R}=zeros(Cdouble, ldxgiven,
ngiven),
nextra::Integer=NEXTRA,
peakfinder::Ptr{Void}=PEAKFINDER,
statefile::AbstractString=STATEFILE,
spin::Ptr{Void}=SPIN)
# Divonne requires "ndim" to be at least 2, even for an integral over a one
# dimensional domain. Instead, we don't prevent users from setting wrong
# "ndim" values like 0 or negative ones.
ndim == 1 && (ndim = 2)
return dointegrate(:Divonne, c_generic_integrand!::Ptr{Void}, ndim, ncomp,
pointer_from_objref(integrand), $CubaInt(nvec), float(reltol),
return dointegrate(:Divonne, integrand, ndim, ncomp,
$CubaInt(nvec), float(reltol),
float(abstol), flags, seed, trunc($CubaInt, minevals),
trunc($CubaInt, maxevals), $CubaInt(NSTART),
$CubaInt(NINCREASE), $CubaInt(NBATCH), GRIDNO,
Expand Down Expand Up @@ -455,17 +450,17 @@ cuhre, llcuhre
for (CubaInt, prefix) in ((Int32, ""), (Int64, "ll"))
func = Symbol(prefix, "cuhre")
@eval begin
function $func(integrand::Function, ndim::Integer=1, ncomp::Integer=1;
nvec::Integer=NVEC, reltol::Real=RELTOL, abstol::Real=ABSTOL,
flags::Integer=FLAGS, minevals::Real=MINEVALS,
maxevals::Real=MAXEVALS, key::Integer=KEY,
statefile::AbstractString=STATEFILE, spin::Ptr{Void}=SPIN)
function $func{T}(integrand::T, ndim::Integer=1, ncomp::Integer=1;
nvec::Integer=NVEC, reltol::Real=RELTOL, abstol::Real=ABSTOL,
flags::Integer=FLAGS, minevals::Real=MINEVALS,
maxevals::Real=MAXEVALS, key::Integer=KEY,
statefile::AbstractString=STATEFILE, spin::Ptr{Void}=SPIN)
# Cuhre requires "ndim" to be at least 2, even for an integral over a one
# dimensional domain. Instead, we don't prevent users from setting wrong
# "ndim" values like 0 or negative ones.
ndim == 1 && (ndim = 2)
return dointegrate(:Cuhre, c_generic_integrand!::Ptr{Void}, ndim, ncomp,
pointer_from_objref(integrand), $CubaInt(nvec), float(reltol),
return dointegrate(:Cuhre, integrand, ndim, ncomp,
$CubaInt(nvec), float(reltol),
float(abstol), flags, SEED, trunc($CubaInt, minevals),
trunc($CubaInt, maxevals), $CubaInt(NSTART),
$CubaInt(NINCREASE), $CubaInt(NBATCH), GRIDNO,
Expand Down

0 comments on commit 7256436

Please sign in to comment.