Gaussian quadrature #141 (Merged)
Changes from all commits (76 commits):
basic version of gauss-legendre (elastufka, ba7d1e7)
fstrings for my sanity (elastufka, 4bc447b)
fstrings for my sanity (elastufka, 345da2c)
weights and points multidimensional (elastufka, 170ae2a)
transform xi,wi correctly (elastufka, 1d27e08)
basic version of gauss-legendre (elastufka, 4547646)
fstrings for my sanity (elastufka, 2fa40b0)
fstrings for my sanity (elastufka, 399234c)
weights and points multidimensional (elastufka, 0c8aef3)
transform xi,wi correctly (elastufka, 7bf8fa0)
let function to integrate accept args, c.f. scipy.nquad (elastufka, d630a4f)
Merge branch 'main' of https://github.com/elastufka/torchquad into ga… (elastufka, 83da697)
let function to integrate accept args, c.f. scipy.nquad (elastufka, b5cca10)
any edits (elastufka, fc37852)
add numpy import (elastufka, 2d79c48)
autoray (elastufka, 444aaef)
add Gaussian quadrature methods (elastufka, 9e196d4)
fix import (elastufka, c08ddfa)
change anp.inf to numpy.inf (elastufka, 25b0747)
fix interval transformation and clean up (elastufka, adae7d7)
make sure tensors are on same device (elastufka, 24357bc)
make sure tensors are on same device part 2 (elastufka, 5c458a4)
make sure tensors are on same device part 3 (elastufka, 58d7de7)
make sure tensors are on same device part 4 (elastufka, 684a1b5)
make sure tensors are on same device part 5 (elastufka, 3ad979c)
add special import (elastufka, 9a94a55)
Merge remote-tracking branch 'esa/develop' into gaussian-quadrature (elastufka, 6d42ab8)
add tests to /tests (elastufka, 3fd4a00)
run autopep8, add docstring (elastufka, 50f3785)
Merge branch 'multi_dim_integrand' into gaussian_quadrature (ilan-gold, 868c8f2)
(feat): cache for roots. (ilan-gold, 5073116)
(feat): refactor out grid integration procedure (ilan-gold, 5868263)
(feat): gaussian integration refactored, some tests passing (ilan-gold, 2d4adf4)
(fix): scaling constant (ilan-gold, f0b0859)
(chore): higher dim integrals testing (ilan-gold, 44f063b)
(feat): weights correct for multi-dim integrands. (ilan-gold, 96f2af3)
(fix): correct number of argument. (ilan-gold, 0c21d1e)
(fix): remove non-legendre tests. (ilan-gold, 1bf58fb)
Merge branch 'multi_dim_integrand' into gaussian_quadrature (ilan-gold, 1d2a53c)
(fix): import GaussLegendre (ilan-gold, c6a6a85)
(fix): ensure grid and weights are correct type (ilan-gold, b2ac2f3)
(style): docstrings. (ilan-gold, 9a33275)
(fix): default `grid_func` (ilan-gold, f619485)
(fix): `_cached_poi...` returns tuple, not ndarray (ilan-gold, 69c5354)
(fix): propagate `backend` correctly. (ilan-gold, 5ab18ce)
Merge branch 'develop' into gaussian-quadrature (ilan-gold, 738d18f)
(chore): export base types (ilan-gold, ce20c33)
(feat): add jit for gausssian (ilan-gold, 40c5e58)
(feat): backward diff (ilan-gold, 86c9c20)
(fix): env issue (ilan-gold, 76c4a24)
Fixed tests badge (gomezzz, 76fc60a)
Merge branch 'main' into gaussian_quadrature (ilan-gold, 5e2668d)
Merge remote-tracking branch 'trunk/develop' into gaussian_quadrature (ilan-gold, 215fc5f)
(chore): cleanup (ilan-gold, edfd308)
(fix): `intergal` -> `integral` (ilan-gold, e3827fa)
(chore): add tutorial (ilan-gold, 3ff0d2b)
(fix): change to `argnums` to work around decorator (ilan-gold, 09ecbd2)
Merge branch 'gaussian-quadrature' of https://github.com/elastufka/to… (ilan-gold, 644d420)
Merge branch 'develop' into gaussian_quadrature (ilan-gold, 0b9def6)
(fix): add fix from other PR (ilan-gold, c6f06de)
,erge (ilan-gold, 88bd469)
(feat): add (broken) tests for gauss jit (ilan-gold, 308a990)
(chore): remove unused import (ilan-gold, 0b1a12a)
(fix): use `item` for `N` when `jit` with `jax` (ilan-gold, d732ca4)
(fix): `domain` for jit gauss `calculate_result` (ilan-gold, 93ba243)
(chore): `black` (ilan-gold, c31ccae)
(chore): erroneous diff (ilan-gold, 24b94b8)
(chore): remove erroneous print (ilan-gold, d44e42d)
(fix): correct comment (ilan-gold, b8c18b9)
(fix): clean up gaussian tests (ilan-gold, 1582735)
(chore): add comments. (ilan-gold, 491cce3)
(chore): formatting (ilan-gold, b3d442e)
(fix): error of 1D integral (ilan-gold, 270de4d)
Merge remote-tracking branch 'trunk/fixing-coverage-workflow' into ga… (ilan-gold, 7383d5d)
(fix): increase bounds. (ilan-gold, 31441c6)
Merge branch 'develop' into gaussian_quadrature (ilan-gold)
New file (152 lines added):
import numpy
from autoray import numpy as anp
from .grid_integrator import GridIntegrator


class Gaussian(GridIntegrator):
    """Gaussian quadrature methods inherit from this. Default behaviour is Gauss-Legendre quadrature on [-1,1]."""

    def __init__(self):
        super().__init__()
        self.name = "Gauss-Legendre"
        self.root_fn = numpy.polynomial.legendre.leggauss
        self.root_args = ()
        self.default_integration_domain = [[-1, 1]]
        self.transform_interval = True
        self._cache = {}

    def integrate(self, fn, dim, N=8, integration_domain=None, backend=None):
        """Integrates the passed function on the passed domain using Gaussian quadrature.

        Args:
            fn (func): The function to integrate over.
            dim (int): Dimensionality of the integration domain.
            N (int, optional): Total number of sample points to use for the integration. Defaults to 8.
            integration_domain (list or backend tensor, optional): Integration domain, e.g. [[-1,1],[0,1]]. Defaults to [-1,1]^dim. It also determines the numerical backend if possible.
            backend (string, optional): Numerical backend. This argument is ignored if the backend can be inferred from integration_domain. Defaults to the backend from the latest call to set_up_backend or "torch" for backwards compatibility.

        Returns:
            backend-specific number: Integral value
        """
        return super().integrate(fn, dim, N, integration_domain, backend)

    def _weights(self, N, dim, backend, requires_grad=False):
        """Return the weights, broadcast across the dimensions, generated from the polynomial of choice.

        Args:
            N (int): number of nodes
            dim (int): number of dimensions
            backend (string): which backend array to return
            requires_grad (bool, optional): whether the returned tensor should track gradients (torch backend only). Defaults to False.

        Returns:
            backend tensor: the weights
        """
        weights = anp.array(self._cached_points_and_weights(N)[1], like=backend)
        if backend == "torch":
            weights.requires_grad = requires_grad
            return anp.prod(
                anp.array(
                    anp.stack(
                        list(anp.meshgrid(*([weights] * dim))), like=backend, dim=0
                    )
                ),
                axis=0,
            ).ravel()
        else:
            return anp.prod(
                anp.meshgrid(*([weights] * dim), like=backend), axis=0
            ).ravel()

    def _roots(self, N, backend, requires_grad=False):
        """Return the roots generated from the polynomial of choice.

        Args:
            N (int): number of nodes
            backend (string): which backend array to return
            requires_grad (bool, optional): whether the returned tensor should track gradients. Defaults to False.

        Returns:
            backend tensor: the roots
        """
        roots = anp.array(self._cached_points_and_weights(N)[0], like=backend)
        if requires_grad:
            roots.requires_grad = True
        return roots

    @property
    def _grid_func(self):
        """Function for generating a grid to be integrated over, i.e. the polynomial roots, resized to the domain."""

        def f(a, b, N, requires_grad, backend=None):
            return self._resize_roots(a, b, self._roots(N, backend, requires_grad))

        return f

    def _resize_roots(self, a, b, roots):  # scale from [-1,1] to [a,b]
        """Resize the roots based on domain of [a,b].

        Args:
            a (backend tensor): lower bound
            b (backend tensor): upper bound
            roots (backend tensor): polynomial nodes

        Returns:
            backend tensor: rescaled roots
        """
        return roots

    # credit for the idea https://github.com/scipy/scipy/blob/dde50595862a4f9cede24b5d1c86935c30f1f88a/scipy/integrate/_quadrature.py#L72
    def _cached_points_and_weights(self, N):
        """Wrap the calls to get weights/roots in a cache.

        Args:
            N (int): number of nodes to return

        Returns:
            tuple: nodes and weights
        """
        root_args = (N, *self.root_args)
        if not isinstance(N, int):
            if hasattr(N, "item"):
                root_args = (N.item(), *self.root_args)
            else:
                raise NotImplementedError(
                    f"N {N} is not an int and lacks an `item` method"
                )
        if root_args in self._cache:
            return self._cache[root_args]
        self._cache[root_args] = self.root_fn(*root_args)
        return self._cache[root_args]

    @staticmethod
    def _apply_composite_rule(cur_dim_areas, dim, hs, domain):
        """Apply "composite" rule for gaussian integrals.

        cur_dim_areas will contain the areas per dimension
        """
        # We collapse dimension by dimension
        for cur_dim in range(dim):
            cur_dim_areas = (
                0.5
                * (domain[cur_dim][1] - domain[cur_dim][0])
                * anp.sum(cur_dim_areas, axis=len(cur_dim_areas.shape) - 1)
            )
        return cur_dim_areas


class GaussLegendre(Gaussian):
    """Gauss-Legendre quadrature rule in torch. See https://en.wikipedia.org/wiki/Gaussian_quadrature#Gauss%E2%80%93Legendre_quadrature.

    Examples
    --------
    >>> gl = torchquad.GaussLegendre()
    >>> integral = gl.integrate(lambda x: np.sin(x), dim=1, N=101, integration_domain=[[0, 5]])  # integral of np.sin(x) from 0 to 5
    |TQ-INFO| Computed integral was 0.7163378000259399  # analytic result = 1 - np.cos(5)
    """

    def __init__(self):
        super().__init__()

    def _resize_roots(self, a, b, roots):  # scale from [-1,1] to [a,b]
        return ((b - a) / 2) * roots + ((a + b) / 2)
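For orientation, here is a minimal usage sketch mirroring the docstring example above. It assumes torchquad is installed with this PR applied and that GaussLegendre is exported at the package top level; the analytic reference is 1 - cos(5) ≈ 0.71634:

import numpy as np
import torchquad

gl = torchquad.GaussLegendre()
# 101-node Gauss-Legendre rule on [0, 5]: the roots on [-1, 1] are rescaled
# by _resize_roots, and _apply_composite_rule contributes the (b - a)/2 factor
integral = gl.integrate(
    lambda x: np.sin(x), dim=1, N=101, integration_domain=[[0, 5]]
)
# expected output, per the docstring:
# |TQ-INFO| Computed integral was 0.7163378000259399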
Review comment:
Do we want to expose these to the users? 🤔
Or is this with the idea of them also defining their own integrators based on them? Then it could be nice but we might want to either give an example or at least mention them in the docs 🤔 (if you prefer we can create a new issue for that)
Reply:
I do want to expose these. I use them for a custom integrator. I can add something in the docs.
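As a purely hypothetical sketch of what such a custom integrator could look like (the class name GaussChebyshev, the top-level import of Gaussian, and the reuse of the Legendre-style rescaling are my assumptions, not part of this PR), one could swap root_fn for another numpy.polynomial generator:

import numpy
from torchquad import Gaussian  # assumes the "(chore): export base types" commit exposes this


class GaussChebyshev(Gaussian):
    """Hypothetical Gauss-Chebyshev (first kind) rule: approximates
    integrals weighted by 1/sqrt(1 - x^2) rather than plain integrals."""

    def __init__(self):
        super().__init__()
        self.name = "Gauss-Chebyshev"
        # numpy.polynomial.chebyshev.chebgauss(N) returns the N Chebyshev
        # nodes on [-1, 1] and their (equal, pi/N) quadrature weights
        self.root_fn = numpy.polynomial.chebyshev.chebgauss

    def _resize_roots(self, a, b, roots):  # scale from [-1,1] to [a,b]
        return ((b - a) / 2) * roots + ((a + b) / 2)

Note that the "(fix): remove non-legendre tests." commit above suggests non-Legendre rules were not exercised by this PR's test suite, so this sketch is illustrative only.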