Remove DataParallel container in SS-VAE model (#3227)
martinrohbeck authored Jun 8, 2023
1 parent 25b9bb0 commit 727aff7
Showing 2 changed files with 8 additions and 12 deletions.
16 changes: 8 additions & 8 deletions examples/vae/ss_vae_M2.py
@@ -63,7 +63,7 @@ def __init__(
         self.aux_loss_multiplier = aux_loss_multiplier
 
         # define and instantiate the neural networks representing
-        # the paramters of various distributions in the model
+        # the parameters of various distributions in the model
         self.setup_networks()
 
     def setup_networks(self):
@@ -142,7 +142,7 @@ def model(self, xs, ys=None):
             # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
             # where `decoder` is a neural network. We disable validation
             # since the decoder output is a relaxed Bernoulli value.
-            loc = self.decoder.forward([zs, ys])
+            loc = self.decoder([zs, ys])
             pyro.sample(
                 "x", dist.Bernoulli(loc, validate_args=False).to_event(1), obs=xs
             )
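The recurring change in this file replaces direct `.forward(...)` calls with calling the module itself. Calling a module routes through `nn.Module.__call__`, which runs any registered forward hooks that a bare `.forward()` call silently skips. A minimal sketch of the difference, using a throwaway `nn.Linear` rather than code from this repository:

import torch
import torch.nn as nn

# Illustration only: __call__ runs registered hooks around forward();
# invoking .forward() directly bypasses them.
layer = nn.Linear(4, 2)
layer.register_forward_hook(lambda mod, inp, out: print("hook fired"))

x = torch.randn(1, 4)
layer(x)          # prints "hook fired"
layer.forward(x)  # the hook is silently skipped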
@@ -168,12 +168,12 @@ def guide(self, xs, ys=None):
             # (and score) the digit with the variational distribution
             # q(y|x) = categorical(alpha(x))
             if ys is None:
-                alpha = self.encoder_y.forward(xs)
+                alpha = self.encoder_y(xs)
                 ys = pyro.sample("y", dist.OneHotCategorical(alpha))
 
             # sample (and score) the latent handwriting-style with the variational
             # distribution q(z|x,y) = normal(loc(x,y),scale(x,y))
-            loc, scale = self.encoder_z.forward([xs, ys])
+            loc, scale = self.encoder_z([xs, ys])
             pyro.sample("z", dist.Normal(loc, scale).to_event(1))
 
     def classifier(self, xs):
@@ -185,7 +185,7 @@ def classifier(self, xs):
         """
         # use the trained model q(y|x) = categorical(alpha(x))
         # compute all class probabilities for the image(s)
-        alpha = self.encoder_y.forward(xs)
+        alpha = self.encoder_y(xs)
 
         # get the index (digit) that corresponds to
         # the maximum predicted class probability
@@ -207,7 +207,7 @@ def model_classify(self, xs, ys=None):
         with pyro.plate("data"):
             # this here is the extra term to yield an auxiliary loss that we do gradient descent on
             if ys is not None:
-                alpha = self.encoder_y.forward(xs)
+                alpha = self.encoder_y(xs)
                 with pyro.poutine.scale(scale=self.aux_loss_multiplier):
                     pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
 
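For context, `pyro.poutine.scale` multiplies the log-probability of the sample sites inside it by `scale`, which is how the weighted auxiliary classification term from Kingma et al. enters the objective. A toy sketch with assumed values (not code from this commit), tracing a single scaled site:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def aux_model():
    # assumed toy values, for illustration only
    alpha = torch.tensor([0.7, 0.2, 0.1])
    ys = torch.tensor([1.0, 0.0, 0.0])
    with poutine.scale(scale=46.0):
        pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)

trace = poutine.trace(aux_model).get_trace()
trace.compute_log_prob()
print(trace.nodes["y_aux"]["log_prob"])  # 46.0 * log(0.7) ~= -16.4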
@@ -332,7 +332,7 @@ def main(args):
     # build a list of all losses considered
     losses = [loss_basic]
 
-    # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al)
+    # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al.)
     if args.aux_loss:
         elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
         loss_aux = SVI(
@@ -444,7 +444,7 @@ def main(args):
         "--aux-loss",
         action="store_true",
         help="whether to use the auxiliary loss from NIPS 14 paper "
-        "(Kingma et al). It is not used by default ",
+        "(Kingma et al.). It is not used by default ",
     )
     parser.add_argument(
         "-alm",
4 changes: 0 additions & 4 deletions examples/vae/utils/custom_mlp.py
@@ -120,10 +120,6 @@ def __init__(
             cur_linear_layer.weight.data.normal_(0, 0.001)
             cur_linear_layer.bias.data.normal_(0, 0.001)
 
-            # use GPUs to share data during training (if available)
-            if use_cuda:
-                cur_linear_layer = nn.DataParallel(cur_linear_layer)
-
             # add our linear layer
             all_modules.append(cur_linear_layer)
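For context, wrapping individual linear layers in `nn.DataParallel`, as the removed code did, is unusual; when data parallelism is wanted at all, the common pattern wraps the whole model once at the top level. A minimal sketch of that pattern with a placeholder model (not code from this repository):

import torch
import torch.nn as nn

# Placeholder model for illustration; the usual pattern wraps the
# entire module in nn.DataParallel once, rather than each nn.Linear.
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")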
