diff --git a/Webpage/03-pytorch-gan.md b/Webpage/03-pytorch-gan.md index 850e938..7f6dbed 100644 --- a/Webpage/03-pytorch-gan.md +++ b/Webpage/03-pytorch-gan.md @@ -104,7 +104,7 @@ def test_discriminator(): critic = Discriminator((1, 28, 28), dropout=0.5, base_c=32, dnoise=0.1, num_classes=2) X = torch.randn(64, 1, 28, 28) out = critic(X) - assert(out.shape == torch.Size([64])) + assert(out.shape == torch.Size([64, 2])) ```