Skip to content

Commit

Permalink
Gaussian quadrature (#141)
Browse files Browse the repository at this point in the history
* basic version of gauss-legendre

* fstrings for my sanity

* fstrings for my sanity

* weights and points multidimensional

* transform xi,wi correctly

* basic version of gauss-legendre

* fstrings for my sanity

* fstrings for my sanity

* weights and points multidimensional

* transform xi,wi correctly

* let function to integrate accept args, c.f. scipy.nquad

* any edits

* add numpy import

* autoray

* add Gaussian quadrature methods

* fix import

* change anp.inf to numpy.inf

* fix interval transformation and clean up

* make sure tensors are on same device

* make sure tensors are on same devicepart 2

* make sure tensors are on same devicepart 3

* make sure tensors are on same devicepart 4

* make sure tensors are on same devicepart 5

* add special import

* add tests to /tests

* run autopep8, add docstring

* (feat): cache for roots.

* (feat): refactor out grid integration procedure

* (feat): gaussian integration refactored, some tests passing

* (fix): scaling constant

* (chore): higher dim integrals testing

* (feat): weights correct for multi-dim integrands.

* (fix): correct number of argument.

* (fix): remove non-legendre tests.

* (fix): import GaussLegendre

* (fix): ensure grid and weights are correct type

* (style): docstrings.

* (fix): default `grid_func`

* (fix): `_cached_poi...` returns tuple, not ndarray

* (fix): propagate `backend` correctly.

* (chore): export base types

* (feat): add jit for gausssian

* (feat): backward diff

* (fix): env issue

* Fixed tests badge

* (chore): cleanup

* (fix): `intergal` -> `integral`

* (chore): add tutorial

* (fix): change to `argnums` to work around decorator

* (fix): add fix from other PR

* (feat): add (broken) tests for gauss jit

* (chore): remove unused import

* (fix): use `item` for `N` when `jit` with `jax`

* (fix): `domain` for jit gauss `calculate_result`

* (chore): `black`

* (chore): erroneous diff

* (chore): remove erroneous print

* (fix): correct comment

* (fix): clean up gaussian tests

* (chore): add comments.

* (chore): formatting

* (fix): error of 1D integral

* (fix): increase bounds.

---------

Co-authored-by: ilan-gold <ilanbassgold@gmail.com>
Co-authored-by: Pablo Gómez <contact@pablo-gomez.net>
  • Loading branch information
3 people authored Apr 19, 2023
1 parent 64a0188 commit cd53251
Show file tree
Hide file tree
Showing 14 changed files with 678 additions and 248 deletions.
36 changes: 36 additions & 0 deletions docs/source/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -756,4 +756,40 @@ Now let's see how to do this a bit more simply, and in a way that provides signf
torch.all(torch.isclose(result_vectorized, result)) # True!
Custom Integrators
------------------

It is of course possible to extend our provided Integrators, perhaps for a special class of functions or for a new algorithm.

.. code:: ipython3
class GaussHermite(Gaussian):
"""Gauss Hermite quadrature rule in torch, for integrals of the form :math:`\\int_{-\\infty}^{+\\infty} e^{-x^{2}} f(x) dx`. It will correctly integrate
polynomials of degree :math:`2n - 1` or less over the interval
:math:`[-\\infty, \\infty]` with weight function :math:`f(x) = e^{-x^2}`. See https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature
"""
def __init__(self):
super().__init__()
self.name = "Gauss-Hermite"
self.root_fn = scipy.special.roots_hermite
self.default_integration_domain = [[-1 * numpy.inf, numpy.inf]]
self.wrapper_func = None
@staticmethod
def _apply_composite_rule(cur_dim_areas, dim, hs, domain):
"""Apply "composite" rule for gaussian integrals
cur_dim_areas will contain the areas per dimension
"""
# We collapse dimension by dimension
for _ in range(dim):
cur_dim_areas = anp.sum(cur_dim_areas, axis=len(cur_dim_areas.shape) - 1)
return cur_dim_areas
gh=torchquad.GaussHermite()
integral=gh.integrate(lambda x: 1-x,dim=1,N=200) #integral from -inf to inf of np.exp(-(x**2))*(1-x)
# Computed integral was 1.7724538509055168.
# analytic result = sqrt(pi)
6 changes: 6 additions & 0 deletions torchquad/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from .integration.simpson import Simpson
from .integration.boole import Boole
from .integration.vegas import VEGAS
from .integration.gaussian import GaussLegendre
from .integration.grid_integrator import GridIntegrator
from .integration.base_integrator import BaseIntegrator

from .integration.rng import RNG

Expand All @@ -22,12 +25,15 @@
from .utils.deployment_test import _deployment_test

__all__ = [
"GridIntegrator",
"BaseIntegrator",
"IntegrationGrid",
"MonteCarlo",
"Trapezoid",
"Simpson",
"Boole",
"VEGAS",
"GaussLegendre",
"RNG",
"plot_convergence",
"plot_runtime",
Expand Down
33 changes: 27 additions & 6 deletions torchquad/integration/base_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,41 @@ def integrate(self):
NotImplementedError("This is an abstract base class. Should not be called.")
)

def _eval(self, points):
def _eval(self, points, weights=None, args=None):
"""Call evaluate_integrand to evaluate self._fn function at the passed points and update self._nr_of_evals
Args:
points (backend tensor): Integration points
weights (backend tensor, optional): Integration weights. Defaults to None.
args (list or tuple, optional): Any arguments required by the function. Defaults to None.
"""
result, num_points = self.evaluate_integrand(self._fn, points)
result, num_points = self.evaluate_integrand(
self._fn, points, weights=weights, args=args
)
self._nr_of_fevals += num_points
return result

@staticmethod
def evaluate_integrand(fn, points):
def evaluate_integrand(fn, points, weights=None, args=None):
"""Evaluate the integrand function at the passed points
Args:
fn (function): Integrand function
points (backend tensor): Integration points
weights (backend tensor, optional): Integration weights. Defaults to None.
args (list or tuple, optional): Any arguments required by the function. Defaults to None.
Returns:
backend tensor: Integrand function output
int: Number of evaluated points
"""
num_points = points.shape[0]
result = fn(points)

if args is None:
args = ()

result = fn(points, *args)

if infer_backend(result) != infer_backend(points):
warnings.warn(
"The passed function's return value has a different numerical backend than the passed points. Will try to convert. Note that this may be slow as it results in memory transfers between CPU and GPU, if torchquad uses the GPU."
Expand All @@ -67,17 +78,27 @@ def evaluate_integrand(fn, points):
f"where first dimension matches length of passed elements. "
)

if weights is not None:
if (
len(result.shape) > 1
): # if the the integrand is multi-dimensional, we need to reshape/repeat weights so they can be broadcast in the *=
integrand_shape = anp.array(
result.shape[1:], like=infer_backend(points)
)
weights = anp.repeat(
anp.expand_dims(weights, axis=1), anp.prod(integrand_shape)
).reshape((weights.shape[0], *(integrand_shape)))
result *= weights

return result, num_points

@staticmethod
def _check_inputs(dim=None, N=None, integration_domain=None):
"""Used to check input validity
Args:
dim (int, optional): Dimensionality of function to integrate. Defaults to None.
N (int, optional): Total number of integration points. Defaults to None.
integration_domain (list or backend tensor, optional): Integration domain, e.g. [[0,1],[1,2]]. Defaults to None.
Raises:
ValueError: if inputs are not compatible with each other.
"""
Expand Down
2 changes: 1 addition & 1 deletion torchquad/integration/boole.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def integrate(self, fn, dim, N=None, integration_domain=None, backend=None):
return super().integrate(fn, dim, N, integration_domain, backend)

@staticmethod
def _apply_composite_rule(cur_dim_areas, dim, hs):
def _apply_composite_rule(cur_dim_areas, dim, hs, domain):
"""Apply composite Boole quadrature.
cur_dim_areas will contain the areas per dimension
"""
Expand Down
152 changes: 152 additions & 0 deletions torchquad/integration/gaussian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import numpy
from autoray import numpy as anp
from .grid_integrator import GridIntegrator


class Gaussian(GridIntegrator):
"""Gaussian quadrature methods inherit from this. Default behaviour is Gauss-Legendre quadrature on [-1,1]."""

def __init__(self):
super().__init__()
self.name = "Gauss-Legendre"
self.root_fn = numpy.polynomial.legendre.leggauss
self.root_args = ()
self.default_integration_domain = [[-1, 1]]
self.transform_interval = True
self._cache = {}

def integrate(self, fn, dim, N=8, integration_domain=None, backend=None):
"""Integrates the passed function on the passed domain using Simpson's rule.
Args:
fn (func): The function to integrate over.
dim (int): Dimensionality of the integration domain.
N (int, optional): Total number of sample points to use for the integration. Should be odd. Defaults to 3 points per dimension if None is given.
integration_domain (list or backend tensor, optional): Integration domain, e.g. [[-1,1],[0,1]]. Defaults to [-1,1]^dim. It also determines the numerical backend if possible.
backend (string, optional): Numerical backend. This argument is ignored if the backend can be inferred from integration_domain. Defaults to the backend from the latest call to set_up_backend or "torch" for backwards compatibility.
Returns:
backend-specific number: Integral value
"""
return super().integrate(fn, dim, N, integration_domain, backend)

def _weights(self, N, dim, backend, requires_grad=False):
"""return the weights, broadcast across the dimensions, generated from the polynomial of choice
Args:
N (int): number of nodes
dim (int): number of dimensions
backend (string): which backend array to return
Returns:
backend tensor: the weights
"""
weights = anp.array(self._cached_points_and_weights(N)[1], like=backend)
if backend == "torch":
weights.requires_grad = requires_grad
return anp.prod(
anp.array(
anp.stack(
list(anp.meshgrid(*([weights] * dim))), like=backend, dim=0
)
),
axis=0,
).ravel()
else:
return anp.prod(
anp.meshgrid(*([weights] * dim), like=backend), axis=0
).ravel()

def _roots(self, N, backend, requires_grad=False):
"""return the roots generated from the polynomial of choice
Args:
N (int): number of nodes
backend (string): which backend array to return
Returns:
backend tensor: the roots
"""
roots = anp.array(self._cached_points_and_weights(N)[0], like=backend)
if requires_grad:
roots.requires_grad = True
return roots

@property
def _grid_func(self):
"""
function for generating a grid to be integrated over i.e., the polynomial roots, resized to the domain.
"""

def f(a, b, N, requires_grad, backend=None):
return self._resize_roots(a, b, self._roots(N, backend, requires_grad))

return f

def _resize_roots(self, a, b, roots): # scale from [-1,1] to [a,b]
"""resize the roots based on domain of [a,b]
Args:
a (backend tensor): lower bound
b (backend tensor): upper bound
roots (backend tensor): polynomial nodes
Returns:
backend tensor: rescaled roots
"""
return roots

# credit for the idea https://github.com/scipy/scipy/blob/dde50595862a4f9cede24b5d1c86935c30f1f88a/scipy/integrate/_quadrature.py#L72
def _cached_points_and_weights(self, N):
"""wrap the calls to get weights/roots in a cache
Args:
N (int): number of nodes to return
backend (string): which backend to use
Returns:
tuple: nodes and weights
"""
root_args = (N, *self.root_args)
if not isinstance(N, int):
if hasattr(N, "item"):
root_args = (N.item(), *self.root_args)
else:
raise NotImplementedError(
f"N {N} is not an int and lacks an `item` method"
)
if root_args in self._cache:
return self._cache[root_args]
self._cache[root_args] = self.root_fn(*root_args)
return self._cache[root_args]

@staticmethod
def _apply_composite_rule(cur_dim_areas, dim, hs, domain):
"""Apply "composite" rule for gaussian integrals
cur_dim_areas will contain the areas per dimension
"""
# We collapse dimension by dimension
for cur_dim in range(dim):
cur_dim_areas = (
0.5
* (domain[cur_dim][1] - domain[cur_dim][0])
* anp.sum(cur_dim_areas, axis=len(cur_dim_areas.shape) - 1)
)
return cur_dim_areas


class GaussLegendre(Gaussian):
"""Gauss Legendre quadrature rule in torch. See https://en.wikipedia.org/wiki/Gaussian_quadrature#Gauss%E2%80%93Legendre_quadrature.
Examples
--------
>>> gl=torchquad.GaussLegendre()
>>> integral = gl.integrate(lambda x:np.sin(x), dim=1, N=101, integration_domain=[[0,5]]) #integral from 0 to 5 of np.sin(x)
|TQ-INFO| Computed integral was 0.7163378000259399 #analytic result = 1-np.cos(5)"""

def __init__(self):
super().__init__()

def _resize_roots(self, a, b, roots): # scale from [-1,1] to [a,b]
return ((b - a) / 2) * roots + ((a + b) / 2)
Loading

0 comments on commit cd53251

Please sign in to comment.