-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy path integration_grid.py
126 lines (105 loc) · 4.75 KB
/
integration_grid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from autoray import numpy as anp
from autoray import infer_backend, astype, to_backend_dtype
from time import perf_counter
from loguru import logger
from .utils import (
_check_integration_domain,
_setup_integration_domain,
_linspace_with_grads,
)
def grid_func(integration_domain, N, requires_grad=False, backend=None):
    """Generate the 1D grid points for a single dimension of the integration domain.

    Args:
        integration_domain: Bounds of one dimension; element 0 is the lower bound
            and element 1 the upper bound (e.g. ``[-1, 1]``).
        N (int): Number of points to generate.
        requires_grad (bool): Forwarded to ``_linspace_with_grads`` (default False).
        backend: Not used by this default implementation; accepted because
            ``IntegrationGrid`` calls every grid function with a ``backend`` keyword.

    Returns:
        Whatever ``_linspace_with_grads(lower, upper, N, requires_grad=...)`` returns
        (presumably N linearly spaced points between the bounds).
    """
    lower, upper = integration_domain[0], integration_domain[1]
    return _linspace_with_grads(lower, upper, N, requires_grad=requires_grad)
class IntegrationGrid:
    """Store a tensor-product integration grid for grid-based methods such as Trapezoid or Simpson.

    Every dimension gets the same number of points. ``points`` holds all grid
    points (one row per point) and ``h`` holds the mesh width of each dimension.
    """

    points = None  # integration points, stacked along axis 1 -> one row per grid point
    h = None  # mesh width per dimension (distance between the first two 1D grid points)
    _N = None  # number of mesh points per dimension
    _dim = None  # dimensionality of the grid
    _runtime = None  # runtime (seconds, via perf_counter) for the creation of the grid

    def __init__(
        self,
        N,
        integration_domain,
        grid_func=grid_func,
        disable_integration_domain_check=False,
    ):
        """Create an integration grid of (at most) N points in the passed domain.

        The dimensionality is ``len(integration_domain)``. The number of points per
        dimension is the ``dim``-th root of N rounded down, so the total number of
        grid points may be smaller than N.

        Args:
            N (int): Total desired number of points in the grid (will take next lower
                root depending on dim).
            integration_domain (list or backend tensor): Domain to choose points in,
                e.g. [[-1, 1], [0, 1]]. It also determines the numerical backend
                (if it is a list, the backend is "torch").
            grid_func (function): Function generating the 1D grid of points for one
                dimension (arguments: integration_domain, N, requires_grad, backend).
            disable_integration_domain_check (bool): Disable integration domain
                checks (default False).

        Raises:
            ValueError: If N < 2, or if N is too small to give at least two points
                per dimension. Errors raised by the domain check itself (when
                enabled) also propagate.
        """
        start = perf_counter()
        self._check_inputs(N, integration_domain, disable_integration_domain_check)

        integration_domain, backend = self._prepare_integration_domain(
            integration_domain
        )
        self._dim = integration_domain.shape[0]

        # TODO Add that N can be different for each dimension
        self._N = self._points_per_dim(N, self._dim)

        logger.opt(lazy=True).debug(
            "Creating {dim}-dimensional integration grid with {N} points over {dom}",
            dim=lambda: str(self._dim),
            N=lambda: str(N),
            dom=lambda: str(integration_domain),
        )

        # Propagate gradient tracking from the domain (if it has such a flag) to the grid.
        requires_grad = getattr(integration_domain, "requires_grad", False)

        # Determine the 1D grid points for each dimension.
        grid_1d = [
            grid_func(
                integration_domain[dim],
                self._N,
                requires_grad=requires_grad,
                backend=backend,
            )
            for dim in range(self._dim)
        ]

        # Mesh width per dimension: distance between the first two 1D grid points.
        self.h = anp.stack(
            [grid[1] - grid[0] for grid in grid_1d],
            like=integration_domain,
        )
        logger.opt(lazy=True).debug("Grid mesh width is {h}", h=lambda: str(self.h))

        # Combine the 1D grids into all tensor-product points, one row per point.
        # NOTE(review): point ordering follows the backend's default meshgrid
        # indexing (numpy defaults to "xy"); confirm it matches how callers
        # reshape function values on this grid.
        mesh = anp.meshgrid(*grid_1d)
        self.points = anp.stack(
            [mg.ravel() for mg in mesh], axis=1, like=integration_domain
        )

        logger.info("Integration grid created.")
        self._runtime = perf_counter() - start

    @staticmethod
    def _prepare_integration_domain(integration_domain):
        """Normalize the integration domain and determine its numerical backend.

        Python lists are passed through ``_setup_integration_domain`` with the torch
        backend. Backend tensors with an integer dtype are cast to float64.

        Returns:
            tuple: ``(integration_domain, backend)`` after normalization.
        """
        backend = infer_backend(integration_domain)
        if backend == "builtins":
            # Plain Python lists default to the torch backend.
            backend = "torch"
            integration_domain = _setup_integration_domain(
                len(integration_domain), integration_domain, backend=backend
            )
        elif "int" in str(integration_domain.dtype):
            # Convert the grid domain to float64 if it was int32/64;
            # it will cause problems otherwise, as in issue #180.
            dtype = to_backend_dtype("float64", like=backend)
            integration_domain = astype(integration_domain, dtype)
        return integration_domain, backend

    @staticmethod
    def _points_per_dim(N, dim):
        """Return the number of grid points per dimension for N total points in ``dim`` dimensions."""
        # A rounding error occurs for certain numbers with certain powers,
        # e.g. (4**3)**(1/3) = 3.99999... Because int() floors the number,
        # i.e. int(3.99999...) -> 3, a little error term is added before flooring.
        return int(N ** (1.0 / dim) + 1e-8)

    def _check_inputs(self, N, integration_domain, disable_integration_domain_check):
        """Validate the requested number of points and (optionally) the domain.

        Args:
            N (int): Total desired number of grid points.
            integration_domain (list or backend tensor): Domain to check.
            disable_integration_domain_check (bool): If True, skip the domain check
                and take the dimensionality from ``len(integration_domain)``.

        Raises:
            ValueError: If N < 2, or if N yields fewer than two points per dimension.
        """
        logger.debug("Checking inputs to IntegrationGrid.")
        if disable_integration_domain_check:
            dim = len(integration_domain)
        else:
            dim = _check_integration_domain(integration_domain)

        if N < 2:
            raise ValueError("N has to be > 1.")

        # Use the same points-per-dimension computation as __init__ so the check
        # and the grid actually built can never disagree near the boundary.
        if self._points_per_dim(N, dim) < 2:
            raise ValueError(
                f"Cannot create a {dim}-dimensional grid with {N} points. "
                "Too few points per dimension."
            )