INN.BatchNorm2d

Zhang Yanbo edited this page Oct 26, 2022 · 1 revision

CLASS INN.BatchNorm2d(dim, requires_grad=True)

Implements batch normalization as PyTorch does. In the forward pass, INN.BatchNorm2d performs the same computation as nn.BatchNorm2d(*, affine=False).

  • dim: number of channels of the input feature (3 for an input of shape (N, 3, H, W))
  • requires_grad: if True, the var will have a gradient
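To make the comparison with nn.BatchNorm2d(*, affine=False) concrete, here is a minimal, dependency-free sketch of the training-mode normalization applied to a single channel. It is an illustration only, not the INN implementation: the function name batch_norm_channel is hypothetical, and eps=1e-5 is the nn.BatchNorm2d default.

```python
import math


def batch_norm_channel(values, eps=1e-5):
    """Normalize one channel's values, pooled over batch, height and width.

    Hypothetical helper for illustration; not part of INN.
    """
    n = len(values)
    mean = sum(values) / n
    # Biased variance (divide by n), as used for normalization in training mode.
    var = sum((v - mean) ** 2 for v in values) / n
    scale = 1.0 / math.sqrt(var + eps)
    # No learnable scale or shift, matching affine=False.
    return [(v - mean) * scale for v in values], mean, var


y, mean, var = batch_norm_channel([1.0, 2.0, 3.0, 4.0])
```

With affine=False there is no learnable scale or shift, so the output of each channel has approximately zero mean and unit variance.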

Methods

forward(input, log_p0=0, log_det_J_=0)

Computes the batch-normalized result y. If compute_p=True, it returns y, logp and log_detJ.

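import INN
import torch

# Batch normalization over 3 channels
model = INN.BatchNorm2d(3)

n_batch = 16
# Input of shape (batch, channels, height, width)
x = torch.randn(n_batch, 3, 16, 16)

# Returns y, logp and log_detJ
y, logp, logdet = model(x)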
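As a rough guide to what a log-determinant term looks like for this kind of layer, the sketch below computes log|det J| for y = (x - mean) / sqrt(var + eps) when the statistics are treated as constants. Under that assumption each of the height * width elements of channel c is scaled by 1 / sqrt(var_c + eps), so the per-sample value is -0.5 * height * width * sum_c log(var_c + eps). Whether INN computes log_detJ exactly this way is an assumption, and the function name log_det_jacobian is hypothetical.

```python
import math


def log_det_jacobian(channel_vars, height, width, eps=1e-5):
    """Per-sample log|det J| of a per-channel normalization with fixed statistics.

    Hypothetical helper for illustration; not part of INN.
    """
    # Each element of channel c contributes -0.5 * log(var_c + eps).
    per_element = sum(-0.5 * math.log(v + eps) for v in channel_vars)
    # Every channel holds height * width elements per sample.
    return height * width * per_element


# Two channels with variances 1 and 4 on a 2x2 feature map, eps set to 0 for clarity.
log_det = log_det_jacobian([1.0, 4.0], height=2, width=2, eps=0.0)
```

A channel with variance above 1 is shrunk by the normalization, so it contributes a negative term.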

inverse(y, **args)

Computes the inverse of y. **args is only a placeholder for consistency.

The inverse does not work in training mode, so call model.eval() before using inverse:

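import INN
import torch

model = INN.BatchNorm2d(3)

n_batch = 16
# Input of shape (batch, channels, height, width)
x = torch.randn(n_batch, 3, 16, 16)

# Forward pass in training mode (the default)
y, logp, logdet = model(x)
# Switch to evaluation mode before calling inverse
model.eval()

x_hat = model.inverse(y)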
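One plausible reason for the evaluation-mode requirement is that inversion needs fixed statistics: the batch mean and variance of x cannot be recovered from y alone, whereas stored statistics can be reused to undo the normalization. The dependency-free sketch below illustrates this with a single channel. It assumes the layer keeps running statistics in the way nn.BatchNorm2d does; the function names and the values of running_mean and running_var are hypothetical.

```python
import math


def forward_fixed(x, mean, var, eps=1e-5):
    """Normalize with stored (fixed) statistics, as in evaluation mode."""
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def inverse_fixed(y, mean, var, eps=1e-5):
    """Undo forward_fixed using the same stored statistics."""
    return [v * math.sqrt(var + eps) + mean for v in y]


x = [0.5, -1.0, 2.0]
running_mean, running_var = 0.25, 1.5  # hypothetical stored statistics

y = forward_fixed(x, running_mean, running_var)
x_hat = inverse_fixed(y, running_mean, running_var)
```

The round trip recovers x up to floating-point error because forward and inverse use the same statistics.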