Skip to content

Commit

Permalink
Add support for vectorization
Browse files Browse the repository at this point in the history
`nvec` was not actually used in `generic_integrand!`.  Fix #10.
  • Loading branch information
giordano committed Jun 23, 2017
1 parent 00dbc0e commit d1e1797
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
26 changes: 25 additions & 1 deletion src/Cuba.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ const KEY = 0
# error. See http://julialang.org/blog/2013/05/callback for more information on
# this, in particular the section about "qsort_r" ("Passing closures via
# pass-through pointers"). Thanks to Steven G. Johnson for pointing to this.
#
# For better performance, when nvec == 1 we store a simple Vector, instead of a Matrix with
# second dimension equal to 1.
function generic_integrand!(ndim::Cint, x_::Ptr{Cdouble}, ncomp::Cint,
f_::Ptr{Cdouble}, func!)
# Get arrays from "x_" and "f_" pointers.
Expand All @@ -91,6 +94,14 @@ function generic_integrand!(ndim::Cint, x_::Ptr{Cdouble}, ncomp::Cint,
func!(x, f)
return Cint(0)
end
function generic_integrand!(ndim::Cint, x_::Ptr{Cdouble}, ncomp::Cint,
f_::Ptr{Cdouble}, func!, nvec::Cint)
# Get arrays from "x_" and "f_" pointers.
x = unsafe_wrap(Array, x_, (ndim, nvec))
f = unsafe_wrap(Array, f_, (ncomp, nvec))
func!(x, f)
return Cint(0)
end

# Return pointer for "integrand", to be passed as "integrand" argument to Cuba functions.
integrand_ptr{T}(integrand::T) = cfunction(generic_integrand!, Cint,
Expand All @@ -99,6 +110,13 @@ integrand_ptr{T}(integrand::T) = cfunction(generic_integrand!, Cint,
Ref{Cint}, # ncomp
Ptr{Cdouble}, # f
Ref{T})) # userdata
integrand_ptr_nvec{T}(integrand::T) = cfunction(generic_integrand!, Cint,
(Ref{Cint}, # ndim
Ptr{Cdouble}, # x
Ref{Cint}, # ncomp
Ptr{Cdouble}, # f
Ref{T}, # userdata
Ref{Cint})) # nvec

abstract Integrand{T}

Expand Down Expand Up @@ -138,7 +156,13 @@ function Base.show(io::IO, x::Integral)
end

@inline function dointegrate{T}(x::Integrand{T})
integrand = integrand_ptr(x.func)
# Choose the integrand function wrapper based on the value of `nvec`. This function is
# called only once, so the overhead of the following if should be negligible.
if x.nvec == 1
integrand = integrand_ptr(x.func)
else
integrand = integrand_ptr_nvec(x.func)
end
integral = Vector{Cdouble}(x.ncomp)
error = Vector{Cdouble}(x.ncomp)
prob = Vector{Cdouble}(x.ncomp)
Expand Down
15 changes: 15 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,21 @@ end
@test isapprox(result[1], pi*log(2), atol=3e-12)
end

@testset "Vectorization" begin
for alg in (vegas, suave, divonne, cuhre)
result1, err1, _ = alg((x,f) -> f[1] = x[1] + cos(x[2]) - exp(x[3]), 3)
result2, err2, _ = alg((x,f) -> f[1,:] .= x[1,:] .+ cos.(x[2,:]) .- exp.(x[3,:]),
3, nvec = 10)
@test result1 == result2
@test err1 == err2
result1, err1, _ = alg((x,f) -> begin f[1] = sin(x[1]); f[2] = sqrt(x[2]) end, 2, 2)
result2, err2, _ = alg((x,f) -> begin f[1,:] .= sin.(x[1,:]); f[2,:] .= sqrt.(x[2,:]) end,
2, 2, nvec = 10)
@test result1 == result2
@test err1 == err2
end
end

# Make sure these functions don't crash.
Cuba.init(C_NULL, C_NULL)
Cuba.exit(C_NULL, C_NULL)
Expand Down

0 comments on commit d1e1797

Please sign in to comment.