Skip to content

Commit

Permalink
type inference fixes and tests, work around JuliaLang/julia#22885
Browse files Browse the repository at this point in the history
  • Loading branch information
stevengj committed Jul 20, 2017
1 parent 528e46f commit 174a8a5
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 18 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ rule due to Genz and Malik (1980), until the estimated error `E`
satisfies `E ≤ max(rtol*norm(I), atol)`, i.e. `rtol` and `atol` are
the relative and absolute tolerances requested, respectively.
It also stops if the number of `f` evaluations exceeds `maxevals`.
The default `rtol` is the square root of the precision `eps(T)`
If neither `atol` nor `rtol` are specified, the
default `rtol` is the square root of the precision `eps(T)`
of the coordinate type `T` described above.

The error is estimated by `norm(I - I′)`, where `I′` is an alternative
Expand Down
32 changes: 19 additions & 13 deletions src/HCubature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ end
cubrule(::Type{Val{0}}, ::Type{T}) where {T} = Trivial()
countevals(::Trivial) = 1

function hcubature_(f, a::SVector{n,T}, b::SVector{n,T}, norm, rtol, atol, maxevals) where {n, T<:AbstractFloat}
function hcubature_(f, a::SVector{n,T}, b::SVector{n,T}, norm, rtol_, atol, maxevals) where {n, T<:AbstractFloat}
rtol = rtol_ == 0 == atol ? sqrt(eps(T)) : rtol_
(rtol < 0 || atol < 0) && throw(ArgumentError("invalid negative tolerance"))
maxevals < 0 && throw(ArgumentError("invalid negative maxevals"))

rule = cubrule(Val{n}, T)
I, E, kdiv = rule(f, a,b, norm)
numevals = evals_per_box = countevals(rule)
Expand Down Expand Up @@ -91,6 +95,18 @@ function hcubature_(f, a::SVector{n,T}, b::SVector{n,T}, norm, rtol, atol, maxev
return I,E
end

function hcubature_(f, a::SVector{n,T}, b::SVector{n,S},
norm, rtol, atol, maxevals) where {n, T<:Real, S<:Real}
F = float(promote_type(T, S))
return hcubature_(f, SVector{n,F}(a), SVector{n,F}(b), norm, rtol, atol, maxevals)
end
hcubature_(f, a::AbstractVector{<:Real}, b::AbstractVector{<:Real},
norm, rtol, atol, maxevals) =
hcubature_(f, SVector{length(a)}(a), SVector{length(b)}(b), norm, rtol, atol, maxevals)
hcubature_(f, a::NTuple{n,<:Real}, b::NTuple{n,<:Real},
norm, rtol, atol, maxevals) where {n} =
hcubature_(f, SVector{n}(a), SVector{n}(b), norm, rtol, atol, maxevals)

"""
hcubature(f, a, b; norm=vecnorm, rtol=sqrt(eps), atol=0, maxevals=typemax(Int))
Expand Down Expand Up @@ -132,18 +148,8 @@ test above) is `vecnorm`, but you can pass an alternative norm by
the `norm` keyword argument. (This is especially useful when `f`
returns a vector of integrands with different scalings.)
"""
function hcubature end

hcubature(f, a::SVector{n,T}, b::SVector{n,T};
norm=vecnorm, rtol::Real=sqrt(eps(T)), atol::Real=zero(T), maxevals::Integer=typemax(Int)) where {n, T<:AbstractFloat} =
hcubature(f, a, b; norm=vecnorm, rtol::Real=0, atol::Real=0,
maxevals::Integer=typemax(Int)) =
hcubature_(f, a, b, norm, rtol, atol, maxevals)
function hcubature(f, a::SVector{n,T}, b::SVector{n,S}; kws...) where {n, T<:Real, S<:Real}
F = float(promote_type(T, S))
return hcubature(f, SVector{n,F}(a), SVector{n,F}(b); kws...)
end
hcubature(f, a::AbstractVector{<:Real}, b::AbstractVector{<:Real}; kws...) =
hcubature(f, SVector{length(a)}(a), SVector{length(b)}(b); kws...)
hcubature(f, a::NTuple{n,<:Real}, b::NTuple{n,<:Real}; kws...) where {n} =
hcubature(f, SVector{n}(a), SVector{n}(b); kws...)

end # module
9 changes: 5 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ using Base.Test

@testset "simple" begin
@test hcubature(x -> cos(x[1])*cos(x[2]), [0,0], [1,1])[1] sin(1)^2
hcubature(x -> cos(x[1])*cos(x[2]), (0,0), (1,1))[1]
@test hcubature(x -> cos(x[1]), (0,), (1,))[1] sin(1)
@test hcubature(x -> cos(x[1]), (0.0f0,), (1.0f0,))[1] sin(1.0f0)
@test hcubature(x -> 1.7, SVector{0,Float64}(), SVector{0,Float64}())[1] == 1.7
@inferred(hcubature(x -> cos(x[1])*cos(x[2]), (0,0), (1,1)))[1]
@test @inferred(hcubature(x -> cos(x[1]), (0,), (1,)))[1] sin(1)
@test @inferred(hcubature(x -> cos(x[1]), (0.0f0,), (1.0f0,)))[1] sin(1.0f0)
@test @inferred(hcubature(x -> 1.7, SVector{0,Float64}(), SVector{0,Float64}()))[1] == 1.7
end

# function wrapper for counting evaluations
Expand Down Expand Up @@ -41,4 +41,5 @@ end
@test HCubature.countevals(g) == 1 + 4length(g.p[1]) + length(g.p[3]) + length(g.p[4])
end
end
@test HCubature.countevals(HCubature.Trivial()) == 1
end

0 comments on commit 174a8a5

Please sign in to comment.