Skip to content

Commit

Permalink
Merge pull request #147 from DanielVandH/gauss_quad
Browse files Browse the repository at this point in the history
Add Gauss-Legendre quadrature
  • Loading branch information
ChrisRackauckas authored Feb 19, 2023
2 parents 70444a9 + de0a6ee commit f7a60d9
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 12 deletions.
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Integrals"
uuid = "de52edbc-65ea-441a-8357-d3a637375a31"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "3.6.0"
version = "3.7.0"

[deps]
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Expand All @@ -26,10 +26,12 @@ Requires = "1"
SciMLBase = "1.70"
Zygote = "0.4.22, 0.5, 0.6"
julia = "1.6"
FastGaussQuadrature = "0.5"

[extensions]
IntegralsForwardDiffExt = "ForwardDiff"
IntegralsZygoteExt = ["Zygote", "ChainRulesCore"]
IntegralsFastGaussQuadratureExt = "FastGaussQuadrature"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -42,11 +44,13 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"

[targets]
test = ["SciMLSensitivity", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore"]
test = ["SciMLSensitivity", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature"]

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
2 changes: 2 additions & 0 deletions docs/src/solvers/IntegralSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ The following algorithms are available:
- `CubaSUAVE`: SUAVE from Cuba.jl. Requires `using IntegralsCuba`.
- `CubaDivonne`: Divonne from Cuba.jl. Requires `using IntegralsCuba`.
- `CubaCuhre`: Cuhre from Cuba.jl. Requires `using IntegralsCuba`.
- `GaussLegendre`: Uses Gauss-Legendre quadrature with nodes and weights from FastGaussQuadrature.jl.

```@docs
QuadGKJL
HCubatureJL
VEGAS
GaussLegendre
```
51 changes: 51 additions & 0 deletions ext/IntegralsFastGaussQuadratureExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
module IntegralsFastGaussQuadratureExt
using Integrals
if isdefined(Base, :get_extension)
import FastGaussQuadrature
import FastGaussQuadrature: gausslegendre
# and eventually gausschebyshev, etc.
else
import ..FastGaussQuadrature
import ..FastGaussQuadrature: gausslegendre
end
using LinearAlgebra

Integrals.gausslegendre(n) = FastGaussQuadrature.gausslegendre(n)

function gauss_legendre(f, p, lb, ub, nodes, weights)
scale = (ub - lb) / 2
shift = (lb + ub) / 2
I = dot(weights, @. f(scale * nodes + shift, $Ref(p)))
return scale * I
end
function composite_gauss_legendre(f, p, lb, ub, nodes, weights, subintervals)
h = (ub - lb) / subintervals
I = zero(h)
for i in 1:subintervals
_lb = lb + (i - 1) * h
_ub = _lb + h
I += gauss_legendre(f, p, _lb, _ub, nodes, weights)
end
return I
end

function Integrals.__solvebp_call(prob::IntegralProblem, alg::Integrals.GaussLegendre{C},
sensealg, lb, ub, p;
reltol = nothing, abstol = nothing,
maxiters = nothing) where {C}
if isinplace(prob) || lb isa AbstractArray || ub isa AbstractArray
error("GaussLegendre only accepts one-dimensional quadrature problems.")
end
@assert prob.batch == 0
@assert prob.nout == 1
if C
val = composite_gauss_legendre(prob.f, prob.p, lb, ub,
alg.nodes, alg.weights, alg.subintervals)
else
val = gauss_legendre(prob.f, prob.p, lb, ub,
alg.nodes, alg.weights)
end
err = nothing
SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success)
end
end
21 changes: 11 additions & 10 deletions src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,16 @@ function SciMLBase.solve(prob::IntegralProblem,
__solvebp(prob, alg, sensealg, prob.lb, prob.ub, prob.p; kwargs...)
end
# Throw error if alg is not provided, as defaults are not implemented.
SciMLBase.solve(::IntegralProblem) = throw(ArgumentError("""
No integration algorithm `alg` was supplied as the second positional argument.
Reccomended integration algorithms are:
For scalar functions: QuadGKJL()
For ≤ 8 dimensional vector functions: HCubatureJL()
For > 8 dimensional vector functions: MonteCarloIntegration.vegas(f, st, en, kwargs...)
See the docstrings of the different algorithms for more detail.
"""
))
function SciMLBase.solve(::IntegralProblem)
throw(ArgumentError("""
No integration algorithm `alg` was supplied as the second positional argument.
Reccomended integration algorithms are:
For scalar functions: QuadGKJL()
For ≤ 8 dimensional vector functions: HCubatureJL()
For > 8 dimensional vector functions: MonteCarloIntegration.vegas(f, st, en, kwargs...)
See the docstrings of the different algorithms for more detail.
"""))
end

# Give a layer to intercept with AD
__solvebp(args...; kwargs...) = __solvebp_call(args...; kwargs...)
Expand Down Expand Up @@ -188,5 +189,5 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p;
SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success)
end

export QuadGKJL, HCubatureJL, VEGAS
export QuadGKJL, HCubatureJL, VEGAS, GaussLegendre
end # module
40 changes: 40 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,43 @@ struct VEGAS <: SciMLBase.AbstractIntegralAlgorithm
debug::Bool
end
VEGAS(; nbins = 100, ncalls = 1000, debug = false) = VEGAS(nbins, ncalls, debug)

"""
GaussLegendre{C, N, W}
Struct for evaluating an integral via (composite) Gauss-Legendre quadrature.
The field `C` will be `true` if `subintervals > 1`, and `false` otherwise.
The fields `nodes::N` and `weights::W` are defined by
`nodes, weights = gausslegendre(n)` for a given number of nodes `n`.
The field `subintervals::Int64 = 1` (with default value `1`) defines the
number of intervals to partition the original interval of integration
`[a, b]` into, splitting it into `[xⱼ, xⱼ₊₁]` for `j = 1,…,subintervals`,
where `xⱼ = a + (j-1)h` and `h = (b-a)/subintervals`. Gauss-Legendre
quadrature is then applied on each subinterval. For example, if
`[a, b] = [-1, 1]` and `subintervals = 2`, then Gauss-Legendre
quadrature will be applied separately on `[-1, 0]` and `[0, 1]`,
summing the two results.
"""
struct GaussLegendre{C, N, W} <: SciMLBase.AbstractIntegralAlgorithm
nodes::N
weights::W
subintervals::Int64
function GaussLegendre(nodes::N, weights::W, subintervals = 1) where {N, W}
if subintervals > 1
return new{true, N, W}(nodes, weights, subintervals)
elseif subintervals == 1
return new{false, N, W}(nodes, weights, subintervals)
else
throw(ArgumentError("Cannot use a nonpositive number of subintervals."))
end
end
end
function gausslegendre end
function GaussLegendre(; n = 250, subintervals = 1, nodes = nothing, weights = nothing)
if isnothing(nodes) || isnothing(weights)
nodes, weights = gausslegendre(n)
end
return GaussLegendre(nodes, weights, subintervals)
end
1 change: 1 addition & 0 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
function __init__()
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin include("../ext/IntegralsForwardDiffExt.jl") end
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/IntegralsZygoteExt.jl") end
@require FastGaussQuadrature="442a2c76-b920-505d-bb47-c5924d526838" begin include("../ext/IntegralsFastGaussQuadratureExt.jl") end
end
end
99 changes: 99 additions & 0 deletions test/gaussian_quadrature_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
using Integrals, Test, FastGaussQuadrature

#=
f = (x, p) -> x^3 * sin(5x)
n = 250
nodes, weights = gausslegendre(n)
I = gauss_legendre(f, nothing, -1, 1, nodes, weights)
@test I ≈ 2 / (625) * (69sin(5) - 95cos(5))
I = Integrals.composite_gauss_legendre(f, nothing, -1, 1, nodes, weights, 2)
@test I ≈ 2 / (625) * (69sin(5) - 95cos(5))
f = (x, p) -> (x + p) * abs(x)
n = 100
nodes, weights = gausslegendre(n)
I = Integrals.gauss_legendre(f, 0.0, -2, 2, nodes, weights)
Ic = Integrals.composite_gauss_legendre(f, 6, -2, 2, nodes, weights, 5)
@inferred Integrals.gauss_legendre(f, 0.0, -2, 2, nodes, weights)
@inferred Integrals.composite_gauss_legendre(f, 6, -2, 2, nodes, weights, 5)
@test I≈0.0 atol=1e-6
@test Ic≈24 rtol=1e-4
=#

alg = GaussLegendre()
n = 250
nd, wt = gausslegendre(n)
@test alg.nodes == nd
@test alg.weights == wt
@test alg.subintervals == 1
alg = GaussLegendre(n = 125, subintervals = 3)
n = 125
nd, wt = gausslegendre(n)
@test alg.nodes == nd
@test alg.weights == wt
@test alg.subintervals == 3
@test typeof(alg).parameters[1]
nd, wt = gausslegendre(275)
alg = GaussLegendre(nodes = nd, weights = wt)
@test !typeof(alg).parameters[1]
@test alg.nodes == nd
@test alg.weights == wt
@test alg.subintervals == 1
alg = GaussLegendre(nodes = nd, weights = wt, subintervals = 20)
@test typeof(alg).parameters[1]
@test alg.nodes == nd
@test alg.weights == wt
@test alg.subintervals == 20

f = (x, p) -> 5x + sin(x) - p * exp(x)
prob = IntegralProblem(f, -5, 3, 3.3)
alg = GaussLegendre()
sol = solve(prob, alg)
@test isnothing(sol.chi)
@test sol.alg === alg
@test sol.prob === prob
@test isnothing(sol.resid)
@test SciMLBase.successful_retcode(sol)
@test sol.u -exp(3) * 3.3 + 3.3 / exp(5) - 40 + cos(5) - cos(3)
alg = GaussLegendre(subintervals = 7)
sol = solve(prob, alg)
@test sol.u -exp(3) * 3.3 + 3.3 / exp(5) - 40 + cos(5) - cos(3)

f = (x, p) -> exp(-x^2)
prob = IntegralProblem(f, 0.0, Inf)
alg = GaussLegendre()
sol = solve(prob, alg)
@test sol.u sqrt(π)/2
alg = GaussLegendre(subintervals=1)
@test sol.u sqrt(π)/2
alg = GaussLegendre(subintervals=17)
@test sol.u sqrt(π)/2

prob = IntegralProblem(f, -Inf, Inf)
alg = GaussLegendre()
sol = solve(prob, alg)
@test sol.u sqrt(π)
alg = GaussLegendre(subintervals=1)
@test sol.u sqrt(π)
alg = GaussLegendre(subintervals=17)
@test sol.u sqrt(π)

prob = IntegralProblem(f, -Inf, 0.0)
alg = GaussLegendre()
sol = solve(prob, alg)
@test sol.u sqrt(π)/2
alg = GaussLegendre(subintervals=1)
@test sol.u sqrt(π)/2
alg = GaussLegendre(subintervals=17)
@test sol.u sqrt(π)/2

# Make sure broadcasting correctly handles the argument p
f = (x, p) -> 1 + x + x^p[1] - cos(x*p[2]) + exp(x)*p[3]
p = [0.3, 1.3, -0.5]
prob = IntegralProblem(f, 2, 6.3, p)
alg = GaussLegendre()
sol = solve(prob, alg)
@test sol.u -240.25235266303063249920743158729
alg = GaussLegendre(n = 500, subintervals = 17)
sol = solve(prob, alg)
@test sol.u -240.25235266303063249920743158729
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ dev_subpkg("IntegralsCubature")
@time @safetestset "Interface Tests" begin include("interface_tests.jl") end
@time @safetestset "Derivative Tests" begin include("derivative_tests.jl") end
@time @safetestset "Infinite Integral Tests" begin include("inf_integral_tests.jl") end
@time @safetestset "Gaussian Quadrature Tests" begin include("gaussian_quadrature_tests.jl") end

0 comments on commit f7a60d9

Please sign in to comment.