diff --git a/torchquad/integration/gaussian.py b/torchquad/integration/gaussian.py index c76cc982..14828bf8 100644 --- a/torchquad/integration/gaussian.py +++ b/torchquad/integration/gaussian.py @@ -78,17 +78,18 @@ def _grid_func(self): function for generating a grid to be integrated over i.e., the polynomial roots, resized to the domain. """ - def f(a, b, N, requires_grad, backend=None): - return self._resize_roots(a, b, self._roots(N, backend, requires_grad)) + def f(integration_domain, N, requires_grad, backend=None): + return self._resize_roots( + integration_domain, self._roots(N, backend, requires_grad) + ) return f - def _resize_roots(self, a, b, roots): # scale from [-1,1] to [a,b] + def _resize_roots(self, integration_domain, roots): # scale from [-1,1] to [a,b] """resize the roots based on domain of [a,b] Args: - a (backend tensor): lower bound - b (backend tensor): upper bound + integration_domain (backend tensor): domain roots (backend tensor): polynomial nodes Returns: @@ -148,5 +149,7 @@ class GaussLegendre(Gaussian): def __init__(self): super().__init__() - def _resize_roots(self, a, b, roots): # scale from [-1,1] to [a,b] + def _resize_roots(self, integration_domain, roots): # scale from [-1,1] to [a,b] + a = integration_domain[0] + b = integration_domain[1] return ((b - a) / 2) * roots + ((a + b) / 2) diff --git a/torchquad/integration/grid_integrator.py b/torchquad/integration/grid_integrator.py index ac0428bd..2243399d 100644 --- a/torchquad/integration/grid_integrator.py +++ b/torchquad/integration/grid_integrator.py @@ -18,7 +18,9 @@ def __init__(self): @property def _grid_func(self): - def f(a, b, N, requires_grad=False, backend=None): + def f(integration_domain, N, requires_grad=False, backend=None): + a = integration_domain[0] + b = integration_domain[1] return _linspace_with_grads(a, b, N, requires_grad=requires_grad) return f @@ -95,12 +97,18 @@ def calculate_result(self, function_values, dim, n_per_dim, hs, integration_doma ) return result - def calculate_grid(self, N, integration_domain): + def calculate_grid( + self, + N, + integration_domain, + disable_integration_domain_check=False, + ): """Calculate grid points, widths and N per dim Args: N (int): Number of points integration_domain (backend tensor): Integration domain + disable_integration_domain_check (bool): Disbaling integration domain checks (default False) Returns: backend tensor: Grid points @@ -119,7 +127,9 @@ def calculate_grid(self, N, integration_domain): ) # Create grid and assemble evaluation points - grid = IntegrationGrid(N, integration_domain, self._grid_func) + grid = IntegrationGrid( + N, integration_domain, self._grid_func, disable_integration_domain_check + ) return grid.points, grid.h, grid._N diff --git a/torchquad/integration/integration_grid.py b/torchquad/integration/integration_grid.py index def7be3b..c6df3cee 100644 --- a/torchquad/integration/integration_grid.py +++ b/torchquad/integration/integration_grid.py @@ -10,7 +10,9 @@ ) -def grid_func(a, b, N, requires_grad=False, backend=None): +def grid_func(integration_domain, N, requires_grad=False, backend=None): + a = integration_domain[0] + b = integration_domain[1] return _linspace_with_grads(a, b, N, requires_grad=requires_grad) @@ -23,15 +25,23 @@ class IntegrationGrid: _dim = None # dimensionality of the grid _runtime = None # runtime for the creation of the integration grid - def __init__(self, N, integration_domain, grid_func=grid_func): + def __init__( + self, + N, + integration_domain, + grid_func=grid_func, + disable_integration_domain_check=False, + ): """Creates an integration grid of N points in the passed domain. Dimension will be len(integration_domain) Args: N (int): Total desired number of points in the grid (will take next lower root depending on dim) integration_domain (list or backend tensor): Domain to choose points in, e.g. [[-1,1],[0,1]]. It also determines the numerical backend (if it is a list, the backend is "torch"). + grid_func (function): function for generating a grid of points over which to integrate (arguments: integration_domain, N, requires_grad, backend) + disable_integration_domain_check (bool): Disbaling integration domain checks (default False) """ start = perf_counter() - self._check_inputs(N, integration_domain) + self._check_inputs(N, integration_domain, disable_integration_domain_check) backend = infer_backend(integration_domain) if backend == "builtins": backend = "torch" @@ -64,8 +74,7 @@ def __init__(self, N, integration_domain, grid_func=grid_func): for dim in range(self._dim): grid_1d.append( grid_func( - integration_domain[dim][0], - integration_domain[dim][1], + integration_domain[dim], self._N, requires_grad=requires_grad, backend=backend, @@ -88,11 +97,14 @@ def __init__(self, N, integration_domain, grid_func=grid_func): self._runtime = perf_counter() - start - def _check_inputs(self, N, integration_domain): + def _check_inputs(self, N, integration_domain, disable_integration_domain_check): """Used to check input validity""" logger.debug("Checking inputs to IntegrationGrid.") - dim = _check_integration_domain(integration_domain) + if disable_integration_domain_check: + dim = len(integration_domain) + else: + dim = _check_integration_domain(integration_domain) if N < 2: raise ValueError("N has to be > 1.") diff --git a/torchquad/integration/utils.py b/torchquad/integration/utils.py index 28944064..c2ab3fcc 100644 --- a/torchquad/integration/utils.py +++ b/torchquad/integration/utils.py @@ -177,7 +177,7 @@ def _check_integration_domain(integration_domain): integration_domain, " does not specify a valid integration bound.", ) - if bounds[0] > bounds[1]: + if anp.any(bounds[0] > bounds[1]): raise ValueError( bounds, " in ", diff --git a/torchquad/tests/integration_grid_test.py b/torchquad/tests/integration_grid_test.py index 1ac1e4c4..1f29134b 100644 --- a/torchquad/tests/integration_grid_test.py +++ b/torchquad/tests/integration_grid_test.py @@ -1,14 +1,50 @@ import sys +import pytest + sys.path.append("../") from autoray import numpy as anp from autoray import to_backend_dtype - +import autoray as ar from integration.integration_grid import IntegrationGrid +from integration.grid_integrator import GridIntegrator +from integration.utils import _linspace_with_grads from helper_functions import setup_test_for_backend +class MockIntegrator(GridIntegrator): + def __init__(self, disable_integration_domain_check, *args, **kwargs): + super().__init__(*args, **kwargs) + self.disable_integration_domain_check = disable_integration_domain_check + + def integrate(self, fn, dim, N, integration_domain, backend, grid_check): + grid_points, _, _ = self.calculate_grid( + N, + integration_domain, + disable_integration_domain_check=self.disable_integration_domain_check, + ) + grid_points = grid_points.reshape(N, -1) + assert grid_check(grid_points) + + @property # need to override in order to handle the grid so that we return a multiple 1d grids for each domain in the customized integration_domain + def _grid_func(self): + def f(integration_domain, N, requires_grad=False, backend=None): + b = integration_domain[:, 1] + a = integration_domain[:, 0] + grid = anp.stack( + [ + _linspace_with_grads(a[ind], b[ind], N, requires_grad=requires_grad) + for ind in range(len(a)) + ] + ).T + return anp.reshape( + grid, [-1] + ) # flatten, but it works with TF as well which has no flatten + + return f + + def _check_grid_validity(grid, integration_domain, N, eps): """Check if a specific grid object contains illegal values""" assert grid._N == int( @@ -64,6 +100,43 @@ def _run_integration_grid_tests(backend, dtype_name): grid = IntegrationGrid(N, integration_domain) _check_grid_validity(grid, integration_domain, N, eps) + mock_integrator_no_check = MockIntegrator(disable_integration_domain_check=True) + mock_integrator_check = MockIntegrator(disable_integration_domain_check=False) + + # Bypassing check, the output grid should be shape (N, 3) for 3 different 1d domains. + # Our custom _grid_func treats the integration_domain as a list of 1d domains + # That is why the domain shape is (1, 3, 2) so that the IntegrationGrid recognizes it as a 1d integral but our + # custom handler does the rest without the check, and fails with the check. + N = 500 + dim = 1 + + def grid_check(x): + has_right_shape = x.shape == (N, 3) + has_right_vals = anp.all(ar.to_numpy(x[0, :]) == 0) and anp.all( + ar.to_numpy(x[-1, :]) == 1 + ) + return has_right_shape and has_right_vals + + mock_integrator_no_check.integrate( + lambda x: x, + dim, + N, + anp.array([[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]], like=backend), + backend, + grid_check, + ) + # Without bypassing check, the error raised should be that the input domain is not compatible with the requested dimensions + with pytest.raises(ValueError) as excinfo: + mock_integrator_check.integrate( + lambda x: x, + dim, + 49, + anp.array([[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]], like=backend), + backend, + grid_check, + ) + assert "The integration_domain tensor has an invalid shape" == str(excinfo.value) + test_integration_grid_numpy = setup_test_for_backend( _run_integration_grid_tests, "numpy", "float64"