Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gaussian quadrature #141

Merged
merged 76 commits into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
317f2ab
basic version of gauss-legendre
elastufka Apr 7, 2022
ba7d1e7
fstrings for my sanity
elastufka Apr 7, 2022
4bc447b
fstrings for my sanity
elastufka Apr 7, 2022
345da2c
weights and points multidimensional
elastufka Apr 7, 2022
170ae2a
transform xi,wi correctly
elastufka Apr 8, 2022
1d27e08
basic version of gauss-legendre
elastufka Apr 7, 2022
4547646
fstrings for my sanity
elastufka Apr 7, 2022
2fa40b0
fstrings for my sanity
elastufka Apr 7, 2022
399234c
weights and points multidimensional
elastufka Apr 7, 2022
0c8aef3
transform xi,wi correctly
elastufka Apr 8, 2022
7bf8fa0
let function to integrate accept args, c.f. scipy.nquad
elastufka Apr 8, 2022
d630a4f
Merge branch 'main' of https://github.com/elastufka/torchquad into ga…
elastufka Apr 8, 2022
83da697
let function to integrate accept args, c.f. scipy.nquad
elastufka Apr 8, 2022
b5cca10
any edits
elastufka Apr 13, 2022
fc37852
add numpy import
elastufka Apr 13, 2022
2d79c48
autoray
elastufka Apr 13, 2022
444aaef
add Gaussian quadrature methods
elastufka Apr 21, 2022
9e196d4
fix import
elastufka Apr 22, 2022
c08ddfa
change anp.inf to numpy.inf
elastufka Apr 25, 2022
25b0747
fix interval transformation and clean up
elastufka Apr 26, 2022
adae7d7
make sure tensors are on same device
elastufka Apr 26, 2022
24357bc
make sure tensors are on same devicepart 2
elastufka Apr 26, 2022
5c458a4
make sure tensors are on same devicepart 3
elastufka Apr 26, 2022
58d7de7
make sure tensors are on same devicepart 4
elastufka Apr 26, 2022
684a1b5
make sure tensors are on same devicepart 5
elastufka Apr 26, 2022
3ad979c
add special import
elastufka May 12, 2022
9a94a55
Merge remote-tracking branch 'esa/develop' into gaussian-quadrature
elastufka May 16, 2022
6d42ab8
add tests to /tests
elastufka May 16, 2022
3fd4a00
run autopep8, add docstring
elastufka May 30, 2022
50f3785
Merge branch 'multi_dim_integrand' into gaussian_quadrature
ilan-gold Jan 13, 2023
868c8f2
(feat): cache for roots.
ilan-gold Jan 13, 2023
5073116
(feat): refactor out grid integration procedure
ilan-gold Jan 13, 2023
5868263
(feat): gaussian integration refactored, some tests passing
ilan-gold Jan 13, 2023
2d4adf4
(fix): scaling constant
ilan-gold Jan 13, 2023
f0b0859
(chore): higher dim integrals testing
ilan-gold Jan 13, 2023
44f063b
(feat): weights correct for multi-dim integrands.
ilan-gold Jan 14, 2023
96f2af3
(fix): correct number of argument.
ilan-gold Jan 14, 2023
0c21d1e
(fix): remove non-legendre tests.
ilan-gold Jan 14, 2023
1bf58fb
Merge branch 'multi_dim_integrand' into gaussian_quadrature
ilan-gold Jan 14, 2023
1d2a53c
(fix): import GaussLegendre
ilan-gold Jan 14, 2023
c6a6a85
(fix): ensure grid and weights are correct type
ilan-gold Jan 14, 2023
b2ac2f3
(style): docstrings.
ilan-gold Jan 14, 2023
9a33275
(fix): default `grid_func`
ilan-gold Jan 14, 2023
f619485
(fix): `_cached_poi...` returns tuple, not ndarray
ilan-gold Jan 14, 2023
69c5354
(fix): propagate `backend` correctly.
ilan-gold Jan 14, 2023
5ab18ce
Merge branch 'develop' into gaussian-quadrature
ilan-gold Jan 16, 2023
738d18f
(chore): export base types
ilan-gold Jan 16, 2023
ce20c33
(feat): add jit for gausssian
ilan-gold Jan 17, 2023
40c5e58
(feat): backward diff
ilan-gold Jan 19, 2023
86c9c20
(fix): env issue
ilan-gold Jan 19, 2023
76c4a24
Fixed tests badge
gomezzz Mar 1, 2023
76fc60a
Merge branch 'main' into gaussian_quadrature
ilan-gold Mar 3, 2023
5e2668d
Merge remote-tracking branch 'trunk/develop' into gaussian_quadrature
ilan-gold Mar 3, 2023
215fc5f
(chore): cleanup
ilan-gold Mar 3, 2023
edfd308
(fix): `intergal` -> `integral`
ilan-gold Mar 3, 2023
e3827fa
(chore): add tutorial
ilan-gold Mar 3, 2023
3ff0d2b
(fix): change to `argnums` to work around decorator
ilan-gold Mar 6, 2023
09ecbd2
Merge branch 'gaussian-quadrature' of https://github.com/elastufka/to…
ilan-gold Mar 6, 2023
644d420
Merge branch 'develop' into gaussian_quadrature
ilan-gold Mar 8, 2023
0b9def6
(fix): add fix from other PR
ilan-gold Mar 8, 2023
c6f06de
,erge
ilan-gold Mar 8, 2023
88bd469
(feat): add (broken) tests for gauss jit
ilan-gold Mar 8, 2023
308a990
(chore): remove unused import
ilan-gold Mar 8, 2023
0b1a12a
(fix): use `item` for `N` when `jit` with `jax`
ilan-gold Mar 16, 2023
d732ca4
(fix): `domain` for jit gauss `calculate_result`
ilan-gold Mar 16, 2023
93ba243
(chore): `black`
ilan-gold Mar 16, 2023
c31ccae
(chore): erroneous diff
ilan-gold Mar 17, 2023
24b94b8
(chore): remove erroneous print
ilan-gold Mar 20, 2023
d44e42d
(fix): correct comment
ilan-gold Mar 20, 2023
b8c18b9
(fix): clean up gaussian tests
ilan-gold Mar 20, 2023
1582735
(chore): add comments.
ilan-gold Mar 20, 2023
491cce3
(chore): formatting
ilan-gold Mar 20, 2023
b3d442e
(fix): error of 1D integral
ilan-gold Mar 20, 2023
270de4d
Merge remote-tracking branch 'trunk/fixing-coverage-workflow' into ga…
ilan-gold Mar 27, 2023
7383d5d
(fix): increase bounds.
ilan-gold Apr 18, 2023
31441c6
Merge branch 'develop' into gaussian_quadrature
ilan-gold Apr 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*** Based on https://github.com/othneildrew/Best-README-Template
-->

![Read the Docs (version)](https://img.shields.io/readthedocs/torchquad/main?style=flat-square) ![GitHub Workflow Status (branch)](https://img.shields.io/github/actions/workflow/status/esa/torchquad/.github/workflows/run_tests.yml/main?style=flat-square) ![GitHub last commit](https://img.shields.io/github/last-commit/esa/torchquad?style=flat-square)
![Read the Docs (version)](https://img.shields.io/readthedocs/torchquad/main?style=flat-square) [![Tests](https://github.com/esa/torchquad/actions/workflows/run_tests.yml/badge.svg)](https://github.com/esa/torchquad/actions/workflows/run_tests.yml) ![GitHub last commit](https://img.shields.io/github/last-commit/esa/torchquad?style=flat-square)
![GitHub](https://img.shields.io/github/license/esa/torchquad?style=flat-square) ![Conda (channel only)](https://img.shields.io/conda/vn/conda-forge/torchquad?style=flat-square) ![PyPI](https://img.shields.io/pypi/v/torchquad?style=flat-square) ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchquad?style=flat-square)

![GitHub contributors](https://img.shields.io/github/contributors/esa/torchquad?style=flat-square)
Expand Down
36 changes: 36 additions & 0 deletions docs/source/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -756,4 +756,40 @@ Now let's see how to do this a bit more simply, and in a way that provides signf

torch.all(torch.isclose(result_vectorized, result)) # True!

Custom Integrators
------------------

It is of course possible to extend our provided Integrators, perhaps for a special class of functions or for a new algorithm.

.. code:: ipython3

class GaussHermite(Gaussian):
"""Gauss Hermite quadrature rule in torch, for integrals of the form :math:`\\int_{-\\infty}^{+\\infty} e^{-x^{2}} f(x) dx`. It will correctly integrate
polynomials of degree :math:`2n - 1` or less over the interval
:math:`[-\\infty, \\infty]` with weight function :math:`f(x) = e^{-x^2}`. See https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature
"""

def __init__(self):
super().__init__()
self.name = "Gauss-Hermite"
self.root_fn = scipy.special.roots_hermite
self.default_integration_domain = [[-1 * numpy.inf, numpy.inf]]
self.wrapper_func = None

@staticmethod
def _apply_composite_rule(cur_dim_areas, dim, hs, domain):
"""Apply "composite" rule for gaussian integrals
cur_dim_areas will contain the areas per dimension
"""
# We collapse dimension by dimension
for _ in range(dim):
cur_dim_areas = anp.sum(cur_dim_areas, axis=len(cur_dim_areas.shape) - 1)
return cur_dim_areas

gh=torchquad.GaussHermite()
integral=gh.integrate(lambda x: 1-x,dim=1,N=200) #integral from -inf to inf of np.exp(-(x**2))*(1-x)
# Computed integral was 1.7724538509055168.
# analytic result = sqrt(pi)



8 changes: 5 additions & 3 deletions torchquad/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from .integration.simpson import Simpson
from .integration.boole import Boole
from .integration.vegas import VEGAS
from .integration.gaussian import GaussLegendre
from .integration.grid_integrator import GridIntegrator
from .integration.base_integrator import BaseIntegrator

from .integration.rng import RNG

Expand All @@ -22,16 +25,15 @@
from .utils.deployment_test import _deployment_test

__all__ = [
"GridIntegrator",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to expose these to the users? 🤔

Or is this with the idea of them also defining their own integrators based on them? Then it could be nice but we might want to either give an example or at least mention them in the docs 🤔 (if you prefer we can create a new issue for that)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do want to expose these. I use them for a custom integrator. I can add something in the docs.

"BaseIntegrator",
"IntegrationGrid",
"MonteCarlo",
"Trapezoid",
"Simpson",
"Boole",
"VEGAS",
"GaussLegendre",
"GaussJacobi",
"GaussLaguerre",
"GaussHermite",
"RNG",
"plot_convergence",
"plot_runtime",
Expand Down
22 changes: 16 additions & 6 deletions torchquad/integration/base_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ def _eval(self, points, weights=None, args=None):
weights (backend tensor, optional): Integration weights. Defaults to None.
args (list or tuple, optional): Any arguments required by the function. Defaults to None.
"""
result, num_points = self.evaluate_integrand(self._fn, points,weights=weights,args=args)
result, num_points = self.evaluate_integrand(
self._fn, points, weights=weights, args=args
)
self._nr_of_fevals += num_points
return result

@staticmethod
def evaluate_integrand(fn, points, weights=None,args=None):
def evaluate_integrand(fn, points, weights=None, args=None):
"""Evaluate the integrand function at the passed points

Args:
Expand All @@ -56,12 +58,12 @@ def evaluate_integrand(fn, points, weights=None,args=None):
int: Number of evaluated points
"""
num_points = points.shape[0]

if args is None:
args = ()

result = fn(points, *args)

if infer_backend(result) != infer_backend(points):
warnings.warn(
"The passed function's return value has a different numerical backend than the passed points. Will try to convert. Note that this may be slow as it results in memory transfers between CPU and GPU, if torchquad uses the GPU."
Expand All @@ -77,8 +79,15 @@ def evaluate_integrand(fn, points, weights=None,args=None):
)

if weights is not None:
if len(result.shape) > 1:
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved
integrand_shape = anp.array(
result.shape[1:], like=infer_backend(points)
)
weights = anp.repeat(
anp.expand_dims(weights, axis=1), anp.prod(integrand_shape)
).reshape((weights.shape[0], *(integrand_shape)))
result *= weights

return result, num_points

@staticmethod
Expand All @@ -105,3 +114,4 @@ def _check_inputs(dim=None, N=None, integration_domain=None):
if dim is not None and dim != dim_domain:
raise ValueError(
"The dimension of the integration domain must match the passed function dimensionality dim."
)
2 changes: 1 addition & 1 deletion torchquad/integration/boole.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def integrate(self, fn, dim, N=None, integration_domain=None, backend=None):
return super().integrate(fn, dim, N, integration_domain, backend)

@staticmethod
def _apply_composite_rule(cur_dim_areas, dim, hs):
def _apply_composite_rule(cur_dim_areas, dim, hs, domain):
"""Apply composite Boole quadrature.
cur_dim_areas will contain the areas per dimension
"""
Expand Down
Loading