Skip to content

Commit

Permalink
Made all samplers operational with mini-batching
Browse files Browse the repository at this point in the history
  • Loading branch information
papamarkou committed Aug 20, 2021
1 parent 8e3910a commit 0956fdb
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 6 deletions.
6 changes: 5 additions & 1 deletion eeyore/samplers/am.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def set_recursive_cov(self, n, offset=0):
def draw(self, x, y, savestate=False, offset=0):
proposed = {key : None for key in self.keys}

if self.counter.num_batches != 1:
self.current['target_val'] = self.model.log_target(self.current['sample'].clone().detach(), x, y)

randn_sample = torch.randn(self.model.num_params(), dtype=self.model.dtype, device=self.model.device)
if (self.counter.idx + 1 - offset > self.t0):
if torch.rand(1, dtype=self.model.dtype, device=self.model.device) < self.l:
Expand All @@ -76,7 +79,8 @@ def draw(self, x, y, savestate=False, offset=0):

if torch.log(torch.rand(1, dtype=self.model.dtype, device=self.model.device)) < log_rate:
self.current['sample'] = proposed['sample'].clone().detach()
self.current['target_val'] = proposed['target_val'].clone().detach()
if self.counter.num_batches == 1:
self.current['target_val'] = proposed['target_val'].clone().detach()
self.current['accepted'] = 1
if (self.counter.idx > 0):
self.num_accepted = self.num_accepted + 1
Expand Down
9 changes: 7 additions & 2 deletions eeyore/samplers/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def leapfrog(self, position0, momentum0, x, y):
def draw(self, x, y, savestate=False):
proposed = {key : None for key in self.keys}

if self.counter.num_batches != 1:
self.current['target_val'], self.current['grad_val'] = \
self.model.upto_grad_log_target(self.current['sample'].clone().detach(), x, y)

proposed['sample'] = self.current['sample'].clone().detach()
proposed['momentum'] = torch.randn(self.model.num_params(), dtype=self.model.dtype, device=self.model.device)

Expand All @@ -143,8 +147,9 @@ def draw(self, x, y, savestate=False):

if torch.rand(1, dtype=self.model.dtype, device=self.model.device) < rate:
self.current['sample'] = proposed['sample'].clone().detach()
self.current['target_val'] = proposed['target_val'].clone().detach()
self.current['grad_val'] = proposed['grad_val'].clone().detach()
if self.counter.num_batches == 1:
self.current['target_val'] = proposed['target_val'].clone().detach()
self.current['grad_val'] = proposed['grad_val'].clone().detach()
self.current['accepted'] = 1
else:
self.model.set_params(self.current['sample'].clone().detach())
Expand Down
9 changes: 7 additions & 2 deletions eeyore/samplers/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def set_kernel(self, state):
def draw(self, x, y, savestate=False):
proposed = {key : None for key in self.keys}

if self.counter.num_batches != 1:
self.current['target_val'], self.current['grad_val'] = \
self.model.upto_grad_log_target(self.current['sample'].clone().detach(), x, y)

proposed['sample'] = self.kernel.sample()

proposed['target_val'], proposed['grad_val'] = \
Expand All @@ -61,8 +65,9 @@ def draw(self, x, y, savestate=False):

if torch.log(torch.rand(1, dtype=self.model.dtype, device=self.model.device)) < log_rate:
self.current['sample'] = proposed['sample'].clone().detach()
self.current['target_val'] = proposed['target_val'].clone().detach()
self.current['grad_val'] = proposed['grad_val'].clone().detach()
if self.counter.num_batches == 1:
self.current['target_val'] = proposed['target_val'].clone().detach()
self.current['grad_val'] = proposed['grad_val'].clone().detach()
self.current['accepted'] = 1
else:
self.model.set_params(self.current['sample'].clone().detach())
Expand Down
6 changes: 5 additions & 1 deletion eeyore/samplers/ram.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def set_all(self, theta, data=None, cov=None):
def draw(self, x, y, savestate=False, offset=0):
proposed = {key : None for key in self.keys}

if self.counter.num_batches != 1:
self.current['target_val'] = self.model.log_target(self.current['sample'].clone().detach(), x, y)

randn_sample = torch.randn(self.model.num_params(), dtype=self.model.dtype, device=self.model.device)
proposed['sample'] = self.current['sample'].clone().detach() + self.chol_cov @ randn_sample
proposed['target_val'] = self.model.log_target(proposed['sample'].clone().detach(), x, y)
Expand All @@ -46,7 +49,8 @@ def draw(self, x, y, savestate=False, offset=0):

if torch.log(torch.rand(1, dtype=self.model.dtype, device=self.model.device)) < log_rate:
self.current['sample'] = proposed['sample'].clone().detach()
self.current['target_val'] = proposed['target_val'].clone().detach()
if self.counter.num_batches == 1:
self.current['target_val'] = proposed['target_val'].clone().detach()
self.current['accepted'] = 1
else:
self.model.set_params(self.current['sample'].clone().detach())
Expand Down

0 comments on commit 0956fdb

Please sign in to comment.