diff --git a/eeyore/samplers/am.py b/eeyore/samplers/am.py index 7116fa5..4b01cfa 100644 --- a/eeyore/samplers/am.py +++ b/eeyore/samplers/am.py @@ -61,6 +61,9 @@ def set_recursive_cov(self, n, offset=0): def draw(self, x, y, savestate=False, offset=0): proposed = {key : None for key in self.keys} + if self.counter.num_batches != 1: + self.current['target_val'] = self.model.log_target(self.current['sample'].clone().detach(), x, y) + randn_sample = torch.randn(self.model.num_params(), dtype=self.model.dtype, device=self.model.device) if (self.counter.idx + 1 - offset > self.t0): if torch.rand(1, dtype=self.model.dtype, device=self.model.device) < self.l: @@ -76,7 +79,8 @@ def draw(self, x, y, savestate=False, offset=0): if torch.log(torch.rand(1, dtype=self.model.dtype, device=self.model.device)) < log_rate: self.current['sample'] = proposed['sample'].clone().detach() - self.current['target_val'] = proposed['target_val'].clone().detach() + if self.counter.num_batches == 1: + self.current['target_val'] = proposed['target_val'].clone().detach() self.current['accepted'] = 1 if (self.counter.idx > 0): self.num_accepted = self.num_accepted + 1 diff --git a/eeyore/samplers/hmc.py b/eeyore/samplers/hmc.py index 649cf3d..0cc5731 100644 --- a/eeyore/samplers/hmc.py +++ b/eeyore/samplers/hmc.py @@ -126,6 +126,10 @@ def leapfrog(self, position0, momentum0, x, y): def draw(self, x, y, savestate=False): proposed = {key : None for key in self.keys} + if self.counter.num_batches != 1: + self.current['target_val'], self.current['grad_val'] = \ + self.model.upto_grad_log_target(self.current['sample'].clone().detach(), x, y) + proposed['sample'] = self.current['sample'].clone().detach() proposed['momentum'] = torch.randn(self.model.num_params(), dtype=self.model.dtype, device=self.model.device) @@ -143,8 +147,9 @@ def draw(self, x, y, savestate=False): if torch.rand(1, dtype=self.model.dtype, device=self.model.device) < rate: self.current['sample'] = proposed['sample'].clone().detach() - self.current['target_val'] = proposed['target_val'].clone().detach() - self.current['grad_val'] = proposed['grad_val'].clone().detach() + if self.counter.num_batches == 1: + self.current['target_val'] = proposed['target_val'].clone().detach() + self.current['grad_val'] = proposed['grad_val'].clone().detach() self.current['accepted'] = 1 else: self.model.set_params(self.current['sample'].clone().detach()) diff --git a/eeyore/samplers/mala.py b/eeyore/samplers/mala.py index fa54cb4..cf4e315 100644 --- a/eeyore/samplers/mala.py +++ b/eeyore/samplers/mala.py @@ -46,6 +46,10 @@ def set_kernel(self, state): def draw(self, x, y, savestate=False): proposed = {key : None for key in self.keys} + if self.counter.num_batches != 1: + self.current['target_val'], self.current['grad_val'] = \ + self.model.upto_grad_log_target(self.current['sample'].clone().detach(), x, y) + proposed['sample'] = self.kernel.sample() proposed['target_val'], proposed['grad_val'] = \ @@ -61,8 +65,9 @@ def draw(self, x, y, savestate=False): if torch.log(torch.rand(1, dtype=self.model.dtype, device=self.model.device)) < log_rate: self.current['sample'] = proposed['sample'].clone().detach() - self.current['target_val'] = proposed['target_val'].clone().detach() - self.current['grad_val'] = proposed['grad_val'].clone().detach() + if self.counter.num_batches == 1: + self.current['target_val'] = proposed['target_val'].clone().detach() + self.current['grad_val'] = proposed['grad_val'].clone().detach() self.current['accepted'] = 1 else: self.model.set_params(self.current['sample'].clone().detach()) diff --git a/eeyore/samplers/ram.py b/eeyore/samplers/ram.py index 5fd8c58..d99984a 100644 --- a/eeyore/samplers/ram.py +++ b/eeyore/samplers/ram.py @@ -38,6 +38,9 @@ def set_all(self, theta, data=None, cov=None): def draw(self, x, y, savestate=False, offset=0): proposed = {key : None for key in self.keys} + if self.counter.num_batches != 1: + self.current['target_val'] = self.model.log_target(self.current['sample'].clone().detach(), x, y) + randn_sample = torch.randn(self.model.num_params(), dtype=self.model.dtype, device=self.model.device) proposed['sample'] = self.current['sample'].clone().detach() + self.chol_cov @ randn_sample proposed['target_val'] = self.model.log_target(proposed['sample'].clone().detach(), x, y) @@ -46,7 +49,8 @@ def draw(self, x, y, savestate=False, offset=0): if torch.log(torch.rand(1, dtype=self.model.dtype, device=self.model.device)) < log_rate: self.current['sample'] = proposed['sample'].clone().detach() - self.current['target_val'] = proposed['target_val'].clone().detach() + if self.counter.num_batches == 1: + self.current['target_val'] = proposed['target_val'].clone().detach() self.current['accepted'] = 1 else: self.model.set_params(self.current['sample'].clone().detach())