Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gaussian quadrature #141

Merged
merged 76 commits into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
317f2ab
basic version of gauss-legendre
elastufka Apr 7, 2022
ba7d1e7
fstrings for my sanity
elastufka Apr 7, 2022
4bc447b
fstrings for my sanity
elastufka Apr 7, 2022
345da2c
weights and points multidimensional
elastufka Apr 7, 2022
170ae2a
transform xi,wi correctly
elastufka Apr 8, 2022
1d27e08
basic version of gauss-legendre
elastufka Apr 7, 2022
4547646
fstrings for my sanity
elastufka Apr 7, 2022
2fa40b0
fstrings for my sanity
elastufka Apr 7, 2022
399234c
weights and points multidimensional
elastufka Apr 7, 2022
0c8aef3
transform xi,wi correctly
elastufka Apr 8, 2022
7bf8fa0
let function to integrate accept args, c.f. scipy.nquad
elastufka Apr 8, 2022
d630a4f
Merge branch 'main' of https://github.com/elastufka/torchquad into ga…
elastufka Apr 8, 2022
83da697
let function to integrate accept args, c.f. scipy.nquad
elastufka Apr 8, 2022
b5cca10
any edits
elastufka Apr 13, 2022
fc37852
add numpy import
elastufka Apr 13, 2022
2d79c48
autoray
elastufka Apr 13, 2022
444aaef
add Gaussian quadrature methods
elastufka Apr 21, 2022
9e196d4
fix import
elastufka Apr 22, 2022
c08ddfa
change anp.inf to numpy.inf
elastufka Apr 25, 2022
25b0747
fix interval transformation and clean up
elastufka Apr 26, 2022
adae7d7
make sure tensors are on same device
elastufka Apr 26, 2022
24357bc
make sure tensors are on same devicepart 2
elastufka Apr 26, 2022
5c458a4
make sure tensors are on same devicepart 3
elastufka Apr 26, 2022
58d7de7
make sure tensors are on same devicepart 4
elastufka Apr 26, 2022
684a1b5
make sure tensors are on same devicepart 5
elastufka Apr 26, 2022
3ad979c
add special import
elastufka May 12, 2022
9a94a55
Merge remote-tracking branch 'esa/develop' into gaussian-quadrature
elastufka May 16, 2022
6d42ab8
add tests to /tests
elastufka May 16, 2022
3fd4a00
run autopep8, add docstring
elastufka May 30, 2022
50f3785
Merge branch 'multi_dim_integrand' into gaussian_quadrature
ilan-gold Jan 13, 2023
868c8f2
(feat): cache for roots.
ilan-gold Jan 13, 2023
5073116
(feat): refactor out grid integration procedure
ilan-gold Jan 13, 2023
5868263
(feat): gaussian integration refactored, some tests passing
ilan-gold Jan 13, 2023
2d4adf4
(fix): scaling constant
ilan-gold Jan 13, 2023
f0b0859
(chore): higher dim integrals testing
ilan-gold Jan 13, 2023
44f063b
(feat): weights correct for multi-dim integrands.
ilan-gold Jan 14, 2023
96f2af3
(fix): correct number of argument.
ilan-gold Jan 14, 2023
0c21d1e
(fix): remove non-legendre tests.
ilan-gold Jan 14, 2023
1bf58fb
Merge branch 'multi_dim_integrand' into gaussian_quadrature
ilan-gold Jan 14, 2023
1d2a53c
(fix): import GaussLegendre
ilan-gold Jan 14, 2023
c6a6a85
(fix): ensure grid and weights are correct type
ilan-gold Jan 14, 2023
b2ac2f3
(style): docstrings.
ilan-gold Jan 14, 2023
9a33275
(fix): default `grid_func`
ilan-gold Jan 14, 2023
f619485
(fix): `_cached_poi...` returns tuple, not ndarray
ilan-gold Jan 14, 2023
69c5354
(fix): propagate `backend` correctly.
ilan-gold Jan 14, 2023
5ab18ce
Merge branch 'develop' into gaussian-quadrature
ilan-gold Jan 16, 2023
738d18f
(chore): export base types
ilan-gold Jan 16, 2023
ce20c33
(feat): add jit for gausssian
ilan-gold Jan 17, 2023
40c5e58
(feat): backward diff
ilan-gold Jan 19, 2023
86c9c20
(fix): env issue
ilan-gold Jan 19, 2023
76c4a24
Fixed tests badge
gomezzz Mar 1, 2023
76fc60a
Merge branch 'main' into gaussian_quadrature
ilan-gold Mar 3, 2023
5e2668d
Merge remote-tracking branch 'trunk/develop' into gaussian_quadrature
ilan-gold Mar 3, 2023
215fc5f
(chore): cleanup
ilan-gold Mar 3, 2023
edfd308
(fix): `intergal` -> `integral`
ilan-gold Mar 3, 2023
e3827fa
(chore): add tutorial
ilan-gold Mar 3, 2023
3ff0d2b
(fix): change to `argnums` to work around decorator
ilan-gold Mar 6, 2023
09ecbd2
Merge branch 'gaussian-quadrature' of https://github.com/elastufka/to…
ilan-gold Mar 6, 2023
644d420
Merge branch 'develop' into gaussian_quadrature
ilan-gold Mar 8, 2023
0b9def6
(fix): add fix from other PR
ilan-gold Mar 8, 2023
c6f06de
,erge
ilan-gold Mar 8, 2023
88bd469
(feat): add (broken) tests for gauss jit
ilan-gold Mar 8, 2023
308a990
(chore): remove unused import
ilan-gold Mar 8, 2023
0b1a12a
(fix): use `item` for `N` when `jit` with `jax`
ilan-gold Mar 16, 2023
d732ca4
(fix): `domain` for jit gauss `calculate_result`
ilan-gold Mar 16, 2023
93ba243
(chore): `black`
ilan-gold Mar 16, 2023
c31ccae
(chore): erroneous diff
ilan-gold Mar 17, 2023
24b94b8
(chore): remove erroneous print
ilan-gold Mar 20, 2023
d44e42d
(fix): correct comment
ilan-gold Mar 20, 2023
b8c18b9
(fix): clean up gaussian tests
ilan-gold Mar 20, 2023
1582735
(chore): add comments.
ilan-gold Mar 20, 2023
491cce3
(chore): formatting
ilan-gold Mar 20, 2023
b3d442e
(fix): error of 1D integral
ilan-gold Mar 20, 2023
270de4d
Merge remote-tracking branch 'trunk/fixing-coverage-workflow' into ga…
ilan-gold Mar 27, 2023
7383d5d
(fix): increase bounds.
ilan-gold Apr 18, 2023
31441c6
Merge branch 'develop' into gaussian_quadrature
ilan-gold Apr 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions torchquad/integration/gaussian.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy
import scipy
from autoray import numpy as anp
from .grid_integrator import GridIntegrator

Expand Down Expand Up @@ -105,9 +104,15 @@ def _cached_points_and_weights(self, N):
tuple: nodes and weights
"""
root_args = (N, *self.root_args)
if not isinstance(N, int):
if hasattr(N, "item"):
root_args = (N.item(), *self.root_args)
else:
raise NotImplementedError(
f"N {N} is not an int and lacks an `item` method"
)
if root_args in self._cache:
return self._cache[root_args]

self._cache[root_args] = self.root_fn(*root_args)
return self._cache[root_args]

Expand Down
18 changes: 13 additions & 5 deletions torchquad/integration/grid_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,11 @@ def get_jit_compiled_integrate(
self.calculate_grid, static_argnames=["N"]
)
self._jax_jit_calculate_result = jax.jit(
self.calculate_result, static_argnames=["dim", "n_per_dim"]
self.calculate_result,
static_argnums=(
1,
2,
), # dim and n_per_dim
)
jit_calculate_grid = self._jax_jit_calculate_grid
jit_calculate_result = self._jax_jit_calculate_result
Expand Down Expand Up @@ -207,7 +211,7 @@ def step1(integration_domain):

dim = int(integration_domain.shape[0])

def step3(function_values, hs):
def step3(function_values, hs, integration_domain):
return self.calculate_result(
function_values, dim, n_per_dim, hs, integration_domain
)
Expand All @@ -233,15 +237,17 @@ def step3(function_values, hs):
if function_values.requires_grad:
function_values = function_values.detach()
function_values.requires_grad = True
step3 = torch.jit.trace(step3, (function_values, hs))
step3 = torch.jit.trace(
step3, (function_values, hs, integration_domain)
)

# Define a compiled integrate function
def compiled_integrate(fn, integration_domain):
grid_points, hs, _ = step1(integration_domain)
function_values, _ = self.evaluate_integrand(
fn, grid_points, weights=self._weights(n_per_dim, dim, backend)
)
result = step3(function_values, hs)
result = step3(function_values, hs, integration_domain)
return result

return compiled_integrate
Expand All @@ -253,7 +259,9 @@ def compiled_integrate(fn, integration_domain):
def lazy_compiled_integrate(fn, integration_domain):
if compiled_func[0] is None:
compiled_func[0] = do_compile(fn)
return compiled_func[0](fn, integration_domain)
res = compiled_func[0](fn, integration_domain)
print(res)
return res

return lazy_compiled_integrate

Expand Down
162 changes: 101 additions & 61 deletions torchquad/tests/gauss_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,69 +12,59 @@
def _run_gaussian_tests(backend, _precision):
"""Test the integrate function in integration.gaussian for the given backend."""

integrators = [GaussLegendre()]
gauss = GaussLegendre()

# 1D Tests
N = 60

for integrator in integrators:
ii = integrator
errors, funcs = compute_integration_test_errors(
ii.integrate,
{"N": N, "dim": 1},
integration_dim=1,
use_complex=True,
backend=backend,
)
print(
f"1D {integrator} Test passed. N: {N}, backend: {backend}, Errors: {errors}"
)
# Polynomials up to degree 1 can be integrated almost exactly with gaussian.
for err, test_function in zip(errors, funcs):
assert test_function.get_order() > 1 or err < 2e-11
for error in errors:
assert error < 1e-5
errors, funcs = compute_integration_test_errors(
gauss.integrate,
{"N": N, "dim": 1},
integration_dim=1,
use_complex=True,
backend=backend,
)
print(f"1D {gauss} Test passed. N: {N}, backend: {backend}, Errors: {errors}")
# Polynomials up to degree 1 can be integrated almost exactly with gaussian.
for err, test_function in zip(errors, funcs):
assert test_function.get_order() > 1 or err < 2e-11
for error in errors:
assert error < 1e-5

N = 2 # integration points, here 2 for order check (2 points should lead to almost 0 err for low order polynomials)
for integrator in integrators:
ii = integrator
errors, funcs = compute_integration_test_errors(
ii.integrate,
{"N": N, "dim": 1},
integration_dim=1,
use_complex=True,
backend=backend,
)
print(
f"1D {integrator} Test passed. N: {N}, backend: {backend}, Errors: {errors}"
)
# All polynomials up to degree = 1 should be 0
# If this breaks check if test functions in helper_functions changed.
for err, test_function in zip(errors, funcs):
assert test_function.get_order() > 1 or err < 1e-15
for error in errors[:2]:
assert error < 1e-15

errors, funcs = compute_integration_test_errors(
gauss.integrate,
{"N": N, "dim": 1},
integration_dim=1,
use_complex=True,
backend=backend,
)
print(f"1D {gauss} Test passed. N: {N}, backend: {backend}, Errors: {errors}")
# All polynomials up to degree = 1 should be 0
# If this breaks check if test functions in helper_functions changed.
for err, test_function in zip(errors, funcs):
assert test_function.get_order() > 1 or err < 1e-15
for error in errors[:2]:
assert error < 1e-15

# 3D Tests
N = 60**3
for integrator in integrators:
ii = integrator
errors, funcs = compute_integration_test_errors(
ii.integrate,
{"N": N, "dim": 3},
integration_dim=3,
use_complex=True,
backend=backend,
)
print(
f"3D {integrator} Test passed. N: {N}, backend: {backend}, Errors: {errors}"

errors, funcs = compute_integration_test_errors(
gauss.integrate,
{"N": N, "dim": 3},
integration_dim=3,
use_complex=True,
backend=backend,
)
print(f"3D {gauss} Test passed. N: {N}, backend: {backend}, Errors: {errors}")
for err, test_function in zip(errors, funcs):
assert test_function.get_order() > 1 or (
err < 1e-12 if test_function.is_integrand_1d else err < 1e-11
)
for err, test_function in zip(errors, funcs):
assert test_function.get_order() > 1 or (
err < 1e-12 if test_function.is_integrand_1d else err < 1e-11
)
for error in errors:
assert error < 6e-3
for error in errors:
assert error < 6e-3

# Tensorflow crashes with an Op:StridedSlice UnimplementedError with 10
# dimensions
Expand All @@ -84,22 +74,72 @@ def _run_gaussian_tests(backend, _precision):

# 10D Tests
N = (60**3) * 3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you really mean *3 ? Not sure how many point this is per dim? 🤔

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this was a guess. I'm not really sure what we should be going for here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

N_d points per dim would lead to (N_d - 1) * (degree+1) points in that dimension, right? Then I would propose to maybe go for 310 or something like that? For Boole, e.g. , by that logic we picked 510 but that is already quite some points 🙃

for integrator in integrators:
ii = integrator

errors, funcs = compute_integration_test_errors(
gauss.integrate,
{"N": N, "dim": 10},
integration_dim=10,
use_complex=True,
backend=backend,
)
print(f"10D {gauss} Test passed. N: {N}, backend: {backend}, Errors: {errors}")
for err, test_function in zip(errors, funcs):
assert test_function.get_order() > 1 or err < 1e-11
for error in errors:
assert error < 7000

# JIT Tests
if backend != "numpy":
N = 60
jit_integrate = None

def integrate(*args, **kwargs):
# this function initializes the jit_integrate variable with a jit'ed integrate function
# which is then re-used on all other integrations (as is the point of JIT).
nonlocal jit_integrate
if jit_integrate is None:
jit_integrate = gauss.get_jit_compiled_integrate(
dim=1, N=N, backend=backend
)
return jit_integrate(*args, **kwargs)

errors, funcs = compute_integration_test_errors(
integrate,
{},
integration_dim=1,
use_complex=True,
backend=backend,
filter_test_functions=lambda x: x.is_integrand_1d,
)

print(
f"1D Gaussian JIT Test passed. N: {N}, backend: {backend}, Errors: {errors}"
)
# Polynomials up to degree 1 can be integrated almost exactly with gaussian.
for err, test_function in zip(errors, funcs):
assert test_function.get_order() > float("inf") or err < 2e-10
for error in errors:
assert error < 1e-5

jit_integrate = (
None # set to None again so can be re-used with new integrand shape
)

errors, funcs = compute_integration_test_errors(
ii.integrate,
{"N": N, "dim": 10},
integration_dim=10,
integrate,
{},
integration_dim=1,
use_complex=True,
backend=backend,
filter_test_functions=lambda x: x.integrand_dims == [2, 2, 2],
)
print(
f"10D {integrator} Test passed. N: {N}, backend: {backend}, Errors: {errors}"
f"1D Gaussian JIT Test passed for [2, 2, 2] dimensional integrands. N: {N}, backend: {backend}, Errors: {errors}"
)
for err, test_function in zip(errors, funcs):
assert test_function.get_order() > 1 or err < 1e-11
assert test_function.get_order() > 1 or err < 2e-10
for error in errors:
assert error < 7000
assert error < 1e-5


test_integrate_numpy = setup_test_for_backend(_run_gaussian_tests, "numpy", "float64")
Expand Down