Skip to content

Commit

Permalink
VEGAS+: Replace torch. with autoray.numpy. operations
Browse files Browse the repository at this point in the history
After this commit it still only works with torch and should indirectly perform the same torch operations as before.
  • Loading branch information
FHof committed Mar 14, 2022
1 parent 7bccdf5 commit 23258a4
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 99 deletions.
73 changes: 41 additions & 32 deletions torchquad/integration/vegas.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import torch
from autoray import numpy as anp
from autoray import infer_backend
from loguru import logger


from .base_integrator import BaseIntegrator
from .utils import _setup_integration_domain
from .utils import _setup_integration_domain, RNG
from .vegas_map import VEGASMap
from .vegas_stratification import VEGASStratification


class VEGAS(BaseIntegrator):
"""VEGAS Enhanced in torch. Refer to https://arxiv.org/abs/2009.05112 .
"""VEGAS Enhanced. Refer to https://arxiv.org/abs/2009.05112 .
Implementation inspired by https://github.com/ycwu1030/CIGAR/ .
EQ <n> refers to equation <n> in the above paper.
"""
Expand All @@ -29,6 +30,7 @@ def integrate(
eps_abs=0,
max_iterations=20,
use_warmup=True,
backend="torch",
):
"""Integrates the passed function on the passed domain using VEGAS.
Expand All @@ -43,6 +45,7 @@ def integrate(
eps_abs (float, optional): Absolute error to abort at. Defaults to 0.
max_iterations (int, optional): Maximum number of vegas iterations to perform. Defaults to 32.
use_warmup (bool, optional): If a warmup should be used to initialize the map. Defaults to True.
backend (string, optional): Numerical backend. This argument is ignored if the backend can be inferred from integration_domain. Defaults to "torch".
Raises:
ValueError: If len(integration_domain) != dim
Expand Down Expand Up @@ -71,11 +74,12 @@ def integrate(
self._starting_N = N // self._max_iterations
self._N_increment = N // self._max_iterations
self._fn = fn
self._integration_domain = _setup_integration_domain(
dim, integration_domain, "torch"
self._integration_domain = integration_domain = _setup_integration_domain(
dim, integration_domain, backend
)
if seed is not None:
torch.random.manual_seed(seed)
self.backend = infer_backend(integration_domain)
self.dtype = integration_domain.dtype
self.rng = RNG(backend=self.backend, seed=seed)

# Initialize the adaptive VEGAS map,
# Note that a larger number of intervals may lead to problems if only few evals are allowed
Expand All @@ -87,7 +91,9 @@ def integrate(

# Initialize VEGAS' stratification
# Paper section III
self.strat = VEGASStratification(self._N_increment, dim=self._dim)
self.strat = VEGASStratification(
self._N_increment, dim=self._dim, rng=self.rng, backend=self.backend
)

logger.debug("Starting VEGAS")

Expand Down Expand Up @@ -125,8 +131,8 @@ def integrate(
chi2 = self._get_chisq()
acc = err / res

if torch.isnan(acc): # capture 0 error
acc = torch.tensor(0.0)
if anp.isnan(acc): # capture 0 error
acc = anp.array(0.0, like=acc)

# Abort if errors acceptable
logger.debug(f"Iteration {self.it},Chi2={chi2:.4e}")
Expand All @@ -136,9 +142,12 @@ def integrate(
# Adjust number of evals if Chi square indicates instability
# EQ 32
if chi2 / 5.0 < 1.0:
self._starting_N = torch.minimum(
torch.tensor(self._starting_N + self._N_increment),
self._starting_N * torch.sqrt(acc / (eps_rel + 1e-8)),
self._starting_N = anp.minimum(
anp.array(
self._starting_N + self._N_increment,
like=acc,
),
self._starting_N * anp.sqrt(acc / (eps_rel + 1e-8)),
)
self.results = [] # reset sample results
self.sigma2 = [] # reset sample results
Expand All @@ -165,8 +174,6 @@ def _warmup_grid(self, warmup_N_it=5, N_samples=1000):
f"Running Map Warmup with warmup_N_it={warmup_N_it}, N_samples={N_samples}..."
)

yrnd = torch.zeros(self._dim) # sample points
x = torch.zeros(self._dim) # transformed sample points
alpha_start = 0.5 # initial alpha value
# TODO in the original paper this is adjusted over time
self.alpha = alpha_start
Expand All @@ -181,8 +188,12 @@ def _warmup_grid(self, warmup_N_it=5, N_samples=1000):
jf = 0 # jacobians * function
jf2 = 0

# Sample points yrnd and transformed sample points x
# Multiplying by 0.99999999 as the edge case of y=1 leads to an error
yrnd = torch.rand(size=[N_samples, self._dim]) * 0.999999
yrnd = (
self.rng.uniform(size=[N_samples, self._dim], dtype=self.dtype)
* 0.999999
)
x = self.map.get_X(yrnd)
f_eval = self._eval(x).squeeze()
jac = self.map.get_Jac(yrnd)
Expand All @@ -198,7 +209,7 @@ def _warmup_grid(self, warmup_N_it=5, N_samples=1000):
self.sigma2[-1] += sig2 / N_samples # store results
self.map.update_map() # adapt the map
# TODO fix for integrals close to 0
acc = torch.sqrt(
acc = anp.sqrt(
self.sigma2[-1] / self.results[-1]
) # compute estimated accuracy,
logger.debug(
Expand All @@ -212,15 +223,13 @@ def _run_iteration(self):
"""Runs one iteration of VEGAS including stratification and updates the VEGAS map if use_grid_improve is set.
Returns:
float: Estimated accuracy.
backend tensor float: Estimated accuracy.
"""
y = torch.zeros(self._dim) # stratified sampling points
x = torch.zeros(self._dim) # transformed sample points

neval = self.strat.get_NH(self._starting_N) # Evals per strat cube
self.starting_N = torch.sum(neval) / self.strat.N_cubes # update real neval
self.starting_N = anp.sum(neval) / self.strat.N_cubes # update real neval
self._nr_of_fevals += neval.sum() # Locally track function evals

# Stratified sampling points y and transformed sample points x
y = self.strat.get_Y(neval)
x = self.map.get_X(y) # transform, EQ 8+9
f_eval = self._eval(x).squeeze() # eval integrand
Expand All @@ -233,19 +242,19 @@ def _run_iteration(self):
self.map.accumulate_weight(y, jf_vec2) # EQ 25
jf, jf2 = self.strat.accumulate_weight(neval, jf_vec) # update strat

ih = torch.divide(jf, neval) * self.strat.V_cubes # Compute integral per cube
ih = (jf / neval) * self.strat.V_cubes # Compute integral per cube

# Collect results
sig2 = torch.divide(jf2, neval) * (self.strat.V_cubes ** 2) - pow(ih, 2)
sig2 = (jf2 / neval) * (self.strat.V_cubes ** 2) - pow(ih, 2)
self.results[-1] = ih.sum() # store results
self.sigma2[-1] = torch.divide(sig2, neval).sum()
self.sigma2[-1] = (sig2 / neval).sum()

if self.use_grid_improve: # if on, update adaptive map
logger.debug("Running grid improvement")
self.map.update_map()

self.strat.update_DH() # update stratification
acc = torch.sqrt(self.sigma2[-1] / (self.results[-1])) # estimate accuracy
acc = anp.sqrt(self.sigma2[-1] / (self.results[-1])) # estimate accuracy

return acc

Expand All @@ -254,37 +263,37 @@ def _get_result(self):
"""Computes mean of results to estimate integral, EQ 30.
Returns:
float: Estimated integral.
backend tensor float: Estimated integral.
"""
res_num = 0
res_den = 0
for idx, res in enumerate(self.results):
res_num += res / self.sigma2[idx]
res_den += 1.0 / self.sigma2[idx]

if torch.isnan(res_num / res_den): # if variance is 0 just return mean result
return torch.mean(torch.tensor(self.results))
if anp.isnan(res_num / res_den): # if variance is 0 just return mean result
return anp.mean(anp.array(self.results, like=res_num))
else:
return res_num / res_den

def _get_error(self):
"""Estimates error from variance , EQ 31.
Returns:
float: Estimated error.
backend tensor float: Estimated error.
"""
res = 0
for sig in self.sigma2:
res += 1.0 / sig

return 1.0 / torch.sqrt(res)
return 1.0 / anp.sqrt(res)

def _get_chisq(self):
"""Computes chi square from estimated integral and variance, EQ 32.
Returns:
float: Chi squared.
backend tensor float: Chi squared.
"""
I_final = self._get_result()
chi2 = 0
Expand Down
Loading

0 comments on commit 23258a4

Please sign in to comment.