INN.BatchNorm2d
Zhang Yanbo edited this page Oct 26, 2022
Implements batch normalization as in PyTorch. In the forward pass, INN.BatchNorm2d does the same thing as nn.BatchNorm2d(*, affine=False).
- dim: dimension (number of channels) of the input feature
- requires_grad: var will have a gradient if requires_grad=True
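For reference, the normalization that nn.BatchNorm2d(affine=False) applies in training mode can be sketched in NumPy. Each channel is standardized with the mean and biased variance taken over the batch and spatial axes, and the running statistics are updated with PyTorch's defaults (eps=1e-5, momentum=0.1, unbiased variance for the running estimate). This is a sketch of PyTorch's documented behaviour, not INN's source; the function name batchnorm2d_train is hypothetical.

```python
import numpy as np

def batchnorm2d_train(x, running_mean, running_var, eps=1e-5, momentum=0.1):
    """Hypothetical NumPy sketch of nn.BatchNorm2d(affine=False) in training mode.

    x has shape (N, C, H, W); running_mean and running_var have shape (C,).
    Returns the normalized output and the updated running statistics.
    """
    n = x.shape[0] * x.shape[2] * x.shape[3]      # number of elements per channel
    mean = x.mean(axis=(0, 2, 3))                 # per-channel batch mean
    var = x.var(axis=(0, 2, 3))                   # biased variance, used for normalizing
    y = (x - mean[None, :, None, None]) / np.sqrt(var[None, :, None, None] + eps)
    # The running estimates use the unbiased variance, as PyTorch does.
    new_mean = (1 - momentum) * running_mean + momentum * mean
    new_var = (1 - momentum) * running_var + momentum * var * n / (n - 1)
    return y, new_mean, new_var

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=3.0, size=(16, 3, 16, 16))
y, rm, rv = batchnorm2d_train(x, np.zeros(3), np.ones(3))
print(y.mean(axis=(0, 2, 3)))  # close to 0 for every channel
print(y.var(axis=(0, 2, 3)))   # close to 1 for every channel
```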
Compute the batch-normalized result y. If compute_p=True, it returns y, logp and log_detJ.
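This page does not say how log_detJ is computed. A common choice for batch normalization in normalizing flows treats the per-channel mean and variance as constants. Each of the H·W positions in channel c is then scaled by 1/sqrt(var_c + eps), giving a per-sample log-determinant of -(H·W/2)·Σ_c log(var_c + eps). The NumPy sketch below assumes that convention (it is not taken from INN's source; batchnorm_logdet is a hypothetical helper) and checks the formula against a brute-force Jacobian on a tiny input.

```python
import numpy as np

def batchnorm_logdet(var, h, w, eps=1e-5):
    """Hypothetical helper: per-sample log|det J| of y = (x - mean) / sqrt(var + eps).

    Assumption: mean and var (shape (C,)) are treated as constants.
    """
    # Every one of the h*w positions in channel c is scaled by 1 / sqrt(var[c] + eps).
    return -0.5 * h * w * np.sum(np.log(var + eps))

# Brute-force check on one small sample (C=2, H=W=2, so an 8x8 Jacobian).
rng = np.random.default_rng(1)
c, h, w, eps = 2, 2, 2, 1e-5
mean = rng.normal(size=c)
var = rng.uniform(0.5, 2.0, size=c)

def f(x_flat):
    """Normalize one flattened sample of shape (c*h*w,) with fixed statistics."""
    x = x_flat.reshape(c, h, w)
    return ((x - mean[:, None, None]) / np.sqrt(var[:, None, None] + eps)).ravel()

x0 = rng.normal(size=c * h * w)
# f is affine, so a unit step along each basis vector gives an exact Jacobian column.
jac = np.stack([f(x0 + e) - f(x0) for e in np.eye(c * h * w)], axis=1)
sign, logabsdet = np.linalg.slogdet(jac)
print(logabsdet, batchnorm_logdet(var, h, w))  # the two values agree
```

For a batch of N samples under the same assumption, the summed log-determinant is N times the per-sample value.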
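```python
import INN
import torch

model = INN.BatchNorm2d(3)            # 3 input channels
n_batch = 16
x = torch.randn(n_batch, 3, 16, 16)   # shape (batch, channels, height, width)
y, logp, logdet = model(x)            # returns y, logp and log_detJ
```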
Compute the inverse of y. **args is only a placeholder for consistency.
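Inverting batch normalization means undoing the per-channel standardization: x = y·sqrt(var + eps) + mean. The NumPy sketch below assumes that the inverse reuses PyTorch-style running statistics (running_mean, running_var), which is what eval mode normalizes with; the function names bn_eval_forward and bn_eval_inverse are hypothetical and not part of INN.

```python
import numpy as np

def bn_eval_forward(x, running_mean, running_var, eps=1e-5):
    """Hypothetical eval-mode batch norm without affine parameters, x of shape (N, C, H, W)."""
    m = running_mean[None, :, None, None]
    s = np.sqrt(running_var[None, :, None, None] + eps)
    return (x - m) / s

def bn_eval_inverse(y, running_mean, running_var, eps=1e-5):
    """Undo bn_eval_forward: x = y * sqrt(var + eps) + mean (assumes running statistics)."""
    m = running_mean[None, :, None, None]
    s = np.sqrt(running_var[None, :, None, None] + eps)
    return y * s + m

rng = np.random.default_rng(2)
running_mean = rng.normal(size=3)
running_var = rng.uniform(0.5, 2.0, size=3)
x = rng.normal(size=(16, 3, 16, 16))
y = bn_eval_forward(x, running_mean, running_var)
x_hat = bn_eval_inverse(y, running_mean, running_var)
print(np.abs(x_hat - x).max())  # tiny, at floating-point rounding level
```

Because the statistics are fixed vectors rather than properties of the current batch, the inverse needs only y and the stored buffers.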
The inverse does not work in training mode, so call model.eval() before using inverse:
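```python
import INN
import torch

model = INN.BatchNorm2d(3)
n_batch = 16
x = torch.randn(n_batch, 3, 16, 16)
y, logp, logdet = model(x)   # forward pass in training mode (the default)
model.eval()                 # switch to eval mode before calling inverse
x_hat = model.inverse(y)
```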
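In the example above, y is produced in training mode, which normalizes with the current batch statistics, and is then inverted in eval mode. If INN.BatchNorm2d keeps nn.BatchNorm2d's running-statistics behaviour (an assumption; this page does not say), the eval-mode inverse uses running_mean and running_var, which after one step differ from the batch statistics, so x_hat only approximates x. The NumPy sketch below mimics that sequence with PyTorch's defaults (eps=1e-5, momentum=0.1, initial running stats 0 and 1); it is an illustration, not INN's code.

```python
import numpy as np

eps, momentum = 1e-5, 0.1
rng = np.random.default_rng(3)
x = rng.normal(loc=1.0, scale=2.0, size=(16, 3, 16, 16))
axes = (0, 2, 3)

def expand(v):
    """Reshape a per-channel vector of shape (C,) for broadcasting over (N, C, H, W)."""
    return v[None, :, None, None]

# Training-mode forward (assumed PyTorch semantics): normalize with batch statistics,
# then update the running statistics starting from their initial values 0 and 1.
batch_mean, batch_var = x.mean(axis=axes), x.var(axis=axes)
y_train = (x - expand(batch_mean)) / np.sqrt(expand(batch_var) + eps)
running_mean = (1 - momentum) * np.zeros(3) + momentum * batch_mean
running_var = (1 - momentum) * np.ones(3) + momentum * x.var(axis=axes, ddof=1)

def inverse(y):
    """Eval-mode inverse using the running statistics (assumption about INN)."""
    return y * np.sqrt(expand(running_var) + eps) + expand(running_mean)

# Inverting the training-mode output with running statistics does not recover x...
print(np.abs(inverse(y_train) - x).max())   # clearly nonzero
# ...but an eval-mode forward followed by the inverse round-trips.
y_eval = (x - expand(running_mean)) / np.sqrt(expand(running_var) + eps)
print(np.abs(inverse(y_eval) - x).max())    # tiny, at floating-point rounding level
```

Under this assumption, calling model.eval() before the forward pass as well gives an exact round trip.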