Skip to content

Commit

Permalink
update edward/inferences/
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Mar 8, 2017
1 parent 3e644d2 commit 8e49c38
Show file tree
Hide file tree
Showing 13 changed files with 339 additions and 553 deletions.
10 changes: 5 additions & 5 deletions edward/inferences/gan_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def __init__(self, data, discriminator):
Notes
-----
``GANInference`` does not support model wrappers or latent
variable inference. Note that GAN-style training also samples from
the prior: this does not work well for latent variables that are
shared across many data points (global variables).
``GANInference`` does not support latent variable inference. Note
that GAN-style training also samples from the prior: this does not
work well for latent variables that are shared across many data
points (global variables).
In building the computation graph for inference, the
discriminator's parameters can be accessed with the variable scope
Expand All @@ -55,7 +55,7 @@ def __init__(self, data, discriminator):
raise NotImplementedError()

self.discriminator = discriminator
super(GANInference, self).__init__(None, data, model_wrapper=None)
super(GANInference, self).__init__(None, data)

def initialize(self, optimizer=None, optimizer_d=None,
global_step=None, global_step_d=None, var_list=None,
Expand Down
48 changes: 22 additions & 26 deletions edward/inferences/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,32 +125,28 @@ def _log_joint(self, z_sample):
z_sample : dict
Latent variable keys to samples.
"""
if self.model_wrapper is None:
self.scope_iter += 1
scope = 'inference_' + str(id(self)) + '/' + str(self.scope_iter)
# Form dictionary in order to replace conditioning on prior or
# observed variable with conditioning on a specific value.
dict_swap = z_sample.copy()
for x, qx in six.iteritems(self.data):
if isinstance(x, RandomVariable):
if isinstance(qx, RandomVariable):
qx_copy = copy(qx, scope=scope)
dict_swap[x] = qx_copy.value()
else:
dict_swap[x] = qx

log_joint = 0.0
for z in six.iterkeys(self.latent_vars):
z_copy = copy(z, dict_swap, scope=scope)
log_joint += tf.reduce_sum(z_copy.log_prob(dict_swap[z]))

for x in six.iterkeys(self.data):
if isinstance(x, RandomVariable):
x_copy = copy(x, dict_swap, scope=scope)
log_joint += tf.reduce_sum(x_copy.log_prob(dict_swap[x]))
else:
x = self.data
log_joint = self.model_wrapper.log_prob(x, z_sample)
self.scope_iter += 1
scope = 'inference_' + str(id(self)) + '/' + str(self.scope_iter)
# Form dictionary in order to replace conditioning on prior or
# observed variable with conditioning on a specific value.
dict_swap = z_sample.copy()
for x, qx in six.iteritems(self.data):
if isinstance(x, RandomVariable):
if isinstance(qx, RandomVariable):
qx_copy = copy(qx, scope=scope)
dict_swap[x] = qx_copy.value()
else:
dict_swap[x] = qx

log_joint = 0.0
for z in six.iterkeys(self.latent_vars):
z_copy = copy(z, dict_swap, scope=scope)
log_joint += tf.reduce_sum(z_copy.log_prob(dict_swap[z]))

for x in six.iterkeys(self.data):
if isinstance(x, RandomVariable):
x_copy = copy(x, dict_swap, scope=scope)
log_joint += tf.reduce_sum(x_copy.log_prob(dict_swap[x]))

return log_joint

Expand Down
2 changes: 1 addition & 1 deletion edward/inferences/implicit_klqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(self, latent_vars, data=None, discriminator=None,
self.discriminator = discriminator
self.global_vars = global_vars
# call grandparent's method; avoid parent (GANInference)
super(GANInference, self).__init__(latent_vars, data, model_wrapper=None)
super(GANInference, self).__init__(latent_vars, data)

def initialize(self, ratio_loss='log', *args, **kwargs):
"""Initialization.
Expand Down
Loading

0 comments on commit 8e49c38

Please sign in to comment.