Skip to content

Commit

Permalink
implement fix for #412; ordering of dictionaries
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Jan 20, 2017
1 parent d5f4f80 commit b2dc0be
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
8 changes: 5 additions & 3 deletions edward/inferences/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import six
import tensorflow as tf

from collections import OrderedDict
from edward.inferences.monte_carlo import MonteCarlo
from edward.models import Normal, RandomVariable, Uniform
from edward.util import copy
Expand Down Expand Up @@ -65,9 +66,10 @@ def build_update(self):
"""
old_sample = {z: tf.gather(qz.params, tf.maximum(self.t - 1, 0))
for z, qz in six.iteritems(self.latent_vars)}
old_sample = OrderedDict(old_sample)

# Sample momentum.
old_r_sample = {}
old_r_sample = OrderedDict()
for z, qz in six.iteritems(self.latent_vars):
event_shape = qz.get_event_shape()
normal = Normal(mu=tf.zeros(event_shape), sigma=tf.ones(event_shape))
Expand Down Expand Up @@ -153,8 +155,8 @@ def _log_joint(self, z_sample):


def leapfrog(z_old, r_old, step_size, log_joint):
z_new = {}
r_new = {}
z_new = z_old.copy()
r_new = r_old.copy()

grad_log_joint = tf.gradients(log_joint(z_old), list(six.itervalues(z_old)))
for i, key in enumerate(six.iterkeys(z_old)):
Expand Down
6 changes: 2 additions & 4 deletions edward/inferences/sgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,8 @@ def build_update(self):
grad_log_joint = tf.gradients(self._log_joint(old_sample),
list(six.itervalues(old_sample)))
sample = {}
for z, qz, grad_log_p in \
zip(six.iterkeys(self.latent_vars),
six.itervalues(self.latent_vars),
grad_log_joint):
for z, grad_log_p in zip(six.iterkeys(old_sample), grad_log_joint):
qz = self.latent_vars[z]
event_shape = qz.get_event_shape()
normal = Normal(mu=tf.zeros(event_shape),
sigma=learning_rate * tf.ones(event_shape))
Expand Down

0 comments on commit b2dc0be

Please sign in to comment.