
(feat): add customization API for IntegrationGrid #173

Merged (69 commits, May 10, 2023)

Commits
317f2ab
basic version of gauss-legendre
elastufka Apr 7, 2022
ba7d1e7
fstrings for my sanity
elastufka Apr 7, 2022
4bc447b
fstrings for my sanity
elastufka Apr 7, 2022
345da2c
weights and points multidimensional
elastufka Apr 7, 2022
170ae2a
transform xi,wi correctly
elastufka Apr 8, 2022
1d27e08
basic version of gauss-legendre
elastufka Apr 7, 2022
4547646
fstrings for my sanity
elastufka Apr 7, 2022
2fa40b0
fstrings for my sanity
elastufka Apr 7, 2022
399234c
weights and points multidimensional
elastufka Apr 7, 2022
0c8aef3
transform xi,wi correctly
elastufka Apr 8, 2022
7bf8fa0
let function to integrate accept args, c.f. scipy.nquad
elastufka Apr 8, 2022
d630a4f
Merge branch 'main' of https://github.com/elastufka/torchquad into ga…
elastufka Apr 8, 2022
83da697
let function to integrate accept args, c.f. scipy.nquad
elastufka Apr 8, 2022
b5cca10
any edits
elastufka Apr 13, 2022
fc37852
add numpy import
elastufka Apr 13, 2022
2d79c48
autoray
elastufka Apr 13, 2022
444aaef
add Gaussian quadrature methods
elastufka Apr 21, 2022
9e196d4
fix import
elastufka Apr 22, 2022
c08ddfa
change anp.inf to numpy.inf
elastufka Apr 25, 2022
25b0747
fix interval transformation and clean up
elastufka Apr 26, 2022
adae7d7
make sure tensors are on same device
elastufka Apr 26, 2022
24357bc
make sure tensors are on same devicepart 2
elastufka Apr 26, 2022
5c458a4
make sure tensors are on same devicepart 3
elastufka Apr 26, 2022
58d7de7
make sure tensors are on same devicepart 4
elastufka Apr 26, 2022
684a1b5
make sure tensors are on same devicepart 5
elastufka Apr 26, 2022
3ad979c
add special import
elastufka May 12, 2022
9a94a55
Merge remote-tracking branch 'esa/develop' into gaussian-quadrature
elastufka May 16, 2022
6d42ab8
add tests to /tests
elastufka May 16, 2022
3fd4a00
run autopep8, add docstring
elastufka May 30, 2022
50f3785
Merge branch 'multi_dim_integrand' into gaussian_quadrature
ilan-gold Jan 13, 2023
868c8f2
(feat): cache for roots.
ilan-gold Jan 13, 2023
5073116
(feat): refactor out grid integration procedure
ilan-gold Jan 13, 2023
5868263
(feat): gaussian integration refactored, some tests passing
ilan-gold Jan 13, 2023
2d4adf4
(fix): scaling constant
ilan-gold Jan 13, 2023
f0b0859
(chore): higher dim integrals testing
ilan-gold Jan 13, 2023
44f063b
(feat): weights correct for multi-dim integrands.
ilan-gold Jan 14, 2023
96f2af3
(fix): correct number of argument.
ilan-gold Jan 14, 2023
0c21d1e
(fix): remove non-legendre tests.
ilan-gold Jan 14, 2023
1bf58fb
Merge branch 'multi_dim_integrand' into gaussian_quadrature
ilan-gold Jan 14, 2023
1d2a53c
(fix): import GaussLegendre
ilan-gold Jan 14, 2023
c6a6a85
(fix): ensure grid and weights are correct type
ilan-gold Jan 14, 2023
b2ac2f3
(style): docstrings.
ilan-gold Jan 14, 2023
9a33275
(fix): default `grid_func`
ilan-gold Jan 14, 2023
f619485
(fix): `_cached_poi...` returns tuple, not ndarray
ilan-gold Jan 14, 2023
69c5354
(fix): propagate `backend` correctly.
ilan-gold Jan 14, 2023
5ab18ce
Merge branch 'develop' into gaussian-quadrature
ilan-gold Jan 16, 2023
738d18f
(chore): export base types
ilan-gold Jan 16, 2023
ce20c33
(feat): add jit for gausssian
ilan-gold Jan 17, 2023
40c5e58
(feat): backward diff
ilan-gold Jan 19, 2023
86c9c20
(fix): env issue
ilan-gold Jan 19, 2023
8c23d67
(feat): start small with special grid for thesis
ilan-gold Jan 22, 2023
6e95058
(fix): correct `_grid_func` for `GridIntegrator`
ilan-gold Jan 22, 2023
3b8277c
(fix): use `static_argnums`
ilan-gold Mar 6, 2023
a38132a
Merge branch 'develop' into special_grid
ilan-gold Apr 21, 2023
7674aa9
(feat): add boolean check
ilan-gold Apr 21, 2023
9e15dc0
(fix): small dev changes.
ilan-gold Apr 21, 2023
79d56ad
(chore): add docs
ilan-gold Apr 21, 2023
1a9235e
(feat): add basic test
ilan-gold May 2, 2023
3696052
(style): change integration domain check location
ilan-gold May 2, 2023
d04f3a2
Update torchquad/integration/utils.py
ilan-gold May 3, 2023
8de3e34
Merge remote-tracking branch 'trunk/special_grid' into special_grid
ilan-gold May 3, 2023
2f086dd
(fix): `linspace` doesn't take arrays for all backends
ilan-gold May 3, 2023
3e3ae82
(fix): grid creation from flat array
ilan-gold May 4, 2023
9d01e0b
(fix): no all for `tf`?
ilan-gold May 4, 2023
7825aca
(fix): tf has no flatten
ilan-gold May 8, 2023
0aaa7e3
(chore): black
ilan-gold May 8, 2023
5dc3d57
(fix): tf has no flatten (!?!?!?!)
ilan-gold May 8, 2023
28ba225
(fix): try `reshape` with -1
ilan-gold May 10, 2023
0986f6b
(chore): formatting
ilan-gold May 10, 2023
15 changes: 9 additions & 6 deletions torchquad/integration/gaussian.py

@@ -78,17 +78,18 @@ def _grid_func(self):
         function for generating a grid to be integrated over i.e., the polynomial roots, resized to the domain.
         """

-        def f(a, b, N, requires_grad, backend=None):
-            return self._resize_roots(a, b, self._roots(N, backend, requires_grad))
+        def f(integration_domain, N, requires_grad, backend=None):
+            return self._resize_roots(
+                integration_domain, self._roots(N, backend, requires_grad)
+            )

         return f

-    def _resize_roots(self, a, b, roots):  # scale from [-1,1] to [a,b]
+    def _resize_roots(self, integration_domain, roots):  # scale from [-1,1] to [a,b]
         """resize the roots based on domain of [a,b]

         Args:
-            a (backend tensor): lower bound
-            b (backend tensor): upper bound
+            integration_domain (backend tensor): domain
             roots (backend tensor): polynomial nodes

         Returns:
@@ -148,5 +149,7 @@ class GaussLegendre(Gaussian):
     def __init__(self):
         super().__init__()

-    def _resize_roots(self, a, b, roots):  # scale from [-1,1] to [a,b]
+    def _resize_roots(self, integration_domain, roots):  # scale from [-1,1] to [a,b]
+        a = integration_domain[0]
+        b = integration_domain[1]
         return ((b - a) / 2) * roots + ((a + b) / 2)
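The one-liner in `GaussLegendre._resize_roots` is the standard affine change of interval for Gauss-Legendre quadrature, with the matching Jacobian factor (b - a)/2 applied to the weights. A minimal NumPy sketch, independent of torchquad (`resize_roots` here is an illustrative stand-in, not the library's method):

```python
import numpy as np

def resize_roots(integration_domain, roots):
    # Affine map from the reference interval [-1, 1] to [a, b],
    # mirroring the rescaling used by GaussLegendre in this PR.
    a, b = integration_domain
    return ((b - a) / 2) * roots + ((a + b) / 2)

# 5-point Gauss-Legendre rule on [0, 2] applied to f(x) = x**3.
roots, weights = np.polynomial.legendre.leggauss(5)
a, b = 0.0, 2.0
x = resize_roots((a, b), roots)
integral = (b - a) / 2 * np.sum(weights * x**3)  # exact value is 4.0
```

A 5-point rule integrates polynomials up to degree 9 exactly, so the result matches the analytic value to machine precision.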
16 changes: 13 additions & 3 deletions torchquad/integration/grid_integrator.py

@@ -18,7 +18,9 @@ def __init__(self):

     @property
     def _grid_func(self):
-        def f(a, b, N, requires_grad=False, backend=None):
+        def f(integration_domain, N, requires_grad=False, backend=None):
+            a = integration_domain[0]
+            b = integration_domain[1]
             return _linspace_with_grads(a, b, N, requires_grad=requires_grad)

         return f
@@ -95,12 +97,18 @@ def calculate_result(self, function_values, dim, n_per_dim, hs, integration_domain):
         )
         return result

-    def calculate_grid(self, N, integration_domain):
+    def calculate_grid(
+        self,
+        N,
+        integration_domain,
+        disable_integration_domain_check=False,
+    ):
         """Calculate grid points, widths and N per dim

         Args:
             N (int): Number of points
             integration_domain (backend tensor): Integration domain
+            disable_integration_domain_check (bool): Disable the integration domain check (default False)

         Returns:
             backend tensor: Grid points
@@ -119,7 +127,9 @@ def calculate_grid(self, N, integration_domain):
         )

         # Create grid and assemble evaluation points
-        grid = IntegrationGrid(N, integration_domain, self._grid_func)
+        grid = IntegrationGrid(
+            N, integration_domain, self._grid_func, disable_integration_domain_check
+        )

         return grid.points, grid.h, grid._N
26 changes: 19 additions & 7 deletions torchquad/integration/integration_grid.py

@@ -10,7 +10,9 @@
 )


-def grid_func(a, b, N, requires_grad=False, backend=None):
+def grid_func(integration_domain, N, requires_grad=False, backend=None):
+    a = integration_domain[0]
+    b = integration_domain[1]
     return _linspace_with_grads(a, b, N, requires_grad=requires_grad)


@@ -23,15 +25,23 @@ class IntegrationGrid:
     _dim = None  # dimensionality of the grid
     _runtime = None  # runtime for the creation of the integration grid

-    def __init__(self, N, integration_domain, grid_func=grid_func):
+    def __init__(
+        self,
+        N,
+        integration_domain,
+        grid_func=grid_func,
+        disable_integration_domain_check=False,
+    ):
         """Creates an integration grid of N points in the passed domain. Dimension will be len(integration_domain)

         Args:
             N (int): Total desired number of points in the grid (will take next lower root depending on dim)
             integration_domain (list or backend tensor): Domain to choose points in, e.g. [[-1,1],[0,1]]. It also determines the numerical backend (if it is a list, the backend is "torch").
             grid_func (function): function for generating a grid of points over which to integrate (arguments: integration_domain, N, requires_grad, backend)
+            disable_integration_domain_check (bool): Disable the integration domain check (default False)
         """
         start = perf_counter()
-        self._check_inputs(N, integration_domain)
+        self._check_inputs(N, integration_domain, disable_integration_domain_check)
         backend = infer_backend(integration_domain)
         if backend == "builtins":
             backend = "torch"
@@ -64,8 +74,7 @@ def __init__(self, N, integration_domain, grid_func=grid_func):
         for dim in range(self._dim):
             grid_1d.append(
                 grid_func(
-                    integration_domain[dim][0],
-                    integration_domain[dim][1],
+                    integration_domain[dim],
                     self._N,
                     requires_grad=requires_grad,
                     backend=backend,
@@ -88,11 +97,14 @@

         self._runtime = perf_counter() - start

-    def _check_inputs(self, N, integration_domain):
+    def _check_inputs(self, N, integration_domain, disable_integration_domain_check):
         """Used to check input validity"""

         logger.debug("Checking inputs to IntegrationGrid.")
-        dim = _check_integration_domain(integration_domain)
+        if disable_integration_domain_check:
+            dim = len(integration_domain)
+        else:
+            dim = _check_integration_domain(integration_domain)

         if N < 2:
             raise ValueError("N has to be > 1.")
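The new `disable_integration_domain_check` flag only changes how `_check_inputs` derives the dimensionality: when disabled, the domain's outer length is trusted as-is. A standalone sketch of that branching, where the shape test is a simplified stand-in for torchquad's `_check_integration_domain`, not the library's actual validation:

```python
import numpy as np

def check_inputs(N, integration_domain, disable_integration_domain_check=False):
    # Sketch of IntegrationGrid._check_inputs after this PR; the shape
    # validation below is an assumed simplification of _check_integration_domain.
    if disable_integration_domain_check:
        dim = len(integration_domain)  # trust the caller's custom layout
    else:
        domain = np.asarray(integration_domain)
        if domain.ndim != 2 or domain.shape[1] != 2:
            raise ValueError("The integration_domain tensor has an invalid shape")
        dim = domain.shape[0]
    if N < 2:
        raise ValueError("N has to be > 1.")
    return dim

# A (1, 3, 2) "list of 1d domains", as in the new test, passes only
# when the check is bypassed; the checked path rejects the rank-3 tensor.
custom_domain = np.zeros((1, 3, 2))
dim_unchecked = check_inputs(500, custom_domain, disable_integration_domain_check=True)
```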
2 changes: 1 addition & 1 deletion torchquad/integration/utils.py

@@ -177,7 +177,7 @@ def _check_integration_domain(integration_domain):
                 integration_domain,
                 " does not specify a valid integration bound.",
             )
-        if bounds[0] > bounds[1]:
+        if anp.any(bounds[0] > bounds[1]):
             raise ValueError(
                 bounds,
                 " in ",
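The `anp.any(...)` wrapper matters once a bound can itself be an array: the comparison then yields a boolean tensor, and a bare `if` on a multi-element array raises an error in NumPy (and PyTorch behaves similarly). A NumPy illustration of the failure mode this one-line change avoids:

```python
import numpy as np

bounds = np.array([[0.0, 0.0, 0.5],   # lower bounds per dimension
                   [1.0, -1.0, 1.0]])  # upper bounds per dimension
mask = bounds[0] > bounds[1]  # element-wise: array([False, True, False])
# `if mask:` would raise "The truth value of an array with more than one
# element is ambiguous"; np.any reduces the mask to a single bool that
# flags whether any bound pair is inverted.
invalid = np.any(mask)
```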
75 changes: 74 additions & 1 deletion torchquad/tests/integration_grid_test.py

@@ -1,14 +1,50 @@
 import sys
+import pytest
+

 sys.path.append("../")

 from autoray import numpy as anp
 from autoray import to_backend_dtype

+import autoray as ar
 from integration.integration_grid import IntegrationGrid
+from integration.grid_integrator import GridIntegrator
+from integration.utils import _linspace_with_grads
 from helper_functions import setup_test_for_backend
+
+
+class MockIntegrator(GridIntegrator):
+    def __init__(self, disable_integration_domain_check, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.disable_integration_domain_check = disable_integration_domain_check
+
+    def integrate(self, fn, dim, N, integration_domain, backend, grid_check):
+        grid_points, _, _ = self.calculate_grid(
+            N,
+            integration_domain,
+            disable_integration_domain_check=self.disable_integration_domain_check,
+        )
+        grid_points = grid_points.reshape(N, -1)
+        assert grid_check(grid_points)
+
+    @property  # overridden to return multiple 1d grids, one per domain in the customized integration_domain
+    def _grid_func(self):
+        def f(integration_domain, N, requires_grad=False, backend=None):
+            b = integration_domain[:, 1]
+            a = integration_domain[:, 0]
+            grid = anp.stack(
+                [
+                    _linspace_with_grads(a[ind], b[ind], N, requires_grad=requires_grad)
+                    for ind in range(len(a))
+                ]
+            ).T
+            return anp.reshape(
+                grid, [-1]
+            )  # flatten, in a way that also works with TF, which has no flatten
+
+        return f


 def _check_grid_validity(grid, integration_domain, N, eps):
     """Check if a specific grid object contains illegal values"""
     assert grid._N == int(
@@ -64,6 +100,43 @@ def _run_integration_grid_tests(backend, dtype_name):
     grid = IntegrationGrid(N, integration_domain)
     _check_grid_validity(grid, integration_domain, N, eps)

+    mock_integrator_no_check = MockIntegrator(disable_integration_domain_check=True)
+    mock_integrator_check = MockIntegrator(disable_integration_domain_check=False)
+
+    # Bypassing the check, the output grid should have shape (N, 3) for 3 different 1d domains.
+    # Our custom _grid_func treats the integration_domain as a list of 1d domains.
+    # The domain shape is (1, 3, 2) so that IntegrationGrid recognizes it as a 1d integral; our
+    # custom handler does the rest without the check, and fails when the check is enabled.
+    N = 500
+    dim = 1
+
+    def grid_check(x):
+        has_right_shape = x.shape == (N, 3)
+        has_right_vals = anp.all(ar.to_numpy(x[0, :]) == 0) and anp.all(
+            ar.to_numpy(x[-1, :]) == 1
+        )
+        return has_right_shape and has_right_vals
+
+    mock_integrator_no_check.integrate(
+        lambda x: x,
+        dim,
+        N,
+        anp.array([[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]], like=backend),
+        backend,
+        grid_check,
+    )
+    # Without bypassing the check, the error raised should say the input domain
+    # is not compatible with the requested dimensions.
+    with pytest.raises(ValueError) as excinfo:
+        mock_integrator_check.integrate(
+            lambda x: x,
+            dim,
+            49,
+            anp.array([[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]], like=backend),
+            backend,
+            grid_check,
+        )
+    assert "The integration_domain tensor has an invalid shape" == str(excinfo.value)


 test_integration_grid_numpy = setup_test_for_backend(
     _run_integration_grid_tests, "numpy", "float64"
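The mock's `_grid_func` shows the customization point this PR exposes: build one 1d grid per domain row, stack them column-wise, then flatten with `reshape(..., [-1])` rather than `.flatten()`, since TensorFlow tensors have no `flatten` method (cf. the "(fix): tf has no flatten" commits). A backend-free NumPy sketch of the same idea (`multi_domain_grid` is an illustrative name, not torchquad API):

```python
import numpy as np

def multi_domain_grid(integration_domain, N):
    # One linspace per (lower, upper) row, stacked so that row i of the
    # (N, n_domains) matrix holds the i-th point of every domain.
    a = integration_domain[:, 0]
    b = integration_domain[:, 1]
    grid = np.stack([np.linspace(a[i], b[i], N) for i in range(len(a))]).T
    # reshape with [-1] flattens row-major and also works for TF tensors.
    return np.reshape(grid, [-1])  # shape (N * n_domains,)

domains = np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]])
flat = multi_domain_grid(domains, 500)
points = flat.reshape(500, -1)  # back to (N, 3), as grid_check expects
```

Because the flattening is row-major, `reshape(N, -1)` recovers the original (N, 3) layout exactly, with the first row at the lower bounds and the last row at the upper bounds.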