From 20dc37c60406581ee7b7389b7405f469893548f6 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sat, 10 Aug 2024 18:01:00 +1000 Subject: [PATCH 001/152] ok --- .github/workflows/runner.yml | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/.github/workflows/runner.yml b/.github/workflows/runner.yml index 1355103..fe6802d 100644 --- a/.github/workflows/runner.yml +++ b/.github/workflows/runner.yml @@ -11,14 +11,5 @@ jobs: runs-on: self-hosted steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.x' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.txt - name: Run tests - run: python -m unittest discover tests \ No newline at end of file + run: python test.py \ No newline at end of file From 7786d726f86c2e9277d0f39fbd9396d6eacff838 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sat, 10 Aug 2024 18:10:12 +1000 Subject: [PATCH 002/152] ok --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 33030b1..98fb5a9 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,6 @@ The frame_skip in VideoDataset - is only at 1. - **xformers + cuda ** https://github.com/facebookresearch/xformers From c606246d5df75df0fe86139494e0956c727848dc Mon Sep 17 00:00:00 2001 From: John Pope Date: Sat, 10 Aug 2024 18:11:39 +1000 Subject: [PATCH 003/152] ok --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index a4bcd07..6f4e31e 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ recon_epoch_1.png *.png __pycache__/vit.cpython-311.pyc __pycache__/helper.cpython-311.pyc +actions-runner/* \ No newline at end of file From 7fae70793bd1f8119f00718842683d66eec49f5c Mon Sep 17 00:00:00 2001 From: John Pope Date: Sat, 10 Aug 2024 18:13:34 +1000 Subject: [PATCH 004/152] ok --- .github/workflows/runner.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/runner.yml b/.github/workflows/runner.yml index fe6802d..9ebe8a8 100644 --- a/.github/workflows/runner.yml +++ b/.github/workflows/runner.yml @@ -12,4 +12,5 @@ jobs: steps: - name: Run tests - run: python test.py \ No newline at end of file + run: python test.py + working-directory: ../ \ No newline at end of file From 37c56af0ed0397b11695c674d8c1104dd78c014f Mon Sep 17 00:00:00 2001 From: John Pope Date: Sat, 10 Aug 2024 18:14:49 +1000 Subject: [PATCH 005/152] ok --- .github/workflows/runner.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/runner.yml b/.github/workflows/runner.yml index 9ebe8a8..88578b4 100644 --- a/.github/workflows/runner.yml +++ b/.github/workflows/runner.yml @@ -13,4 +13,5 @@ jobs: steps: - name: Run tests run: python test.py - working-directory: ../ \ No newline at end of file + working-directory: ../ + From 40e25fd21da7cfa7b091d7567ad36f2cbf31dbbf Mon Sep 17 00:00:00 2001 From: John Pope Date: Sat, 10 Aug 2024 18:15:42 +1000 Subject: [PATCH 006/152] ok --- .github/workflows/runner.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/runner.yml b/.github/workflows/runner.yml index 88578b4..ec7aebc 100644 --- a/.github/workflows/runner.yml +++ b/.github/workflows/runner.yml @@ -15,3 +15,5 @@ jobs: run: python test.py working-directory: ../ + + \ No newline at end of file From ddcc182b4276b213dc655fdb4a4335a55e434cbd Mon Sep 17 00:00:00 2001 From: John Pope Date: Sat, 10 Aug 2024 18:19:01 +1000 Subject: [PATCH 007/152] ok --- 
.github/workflows/runner.yml | 3 +- __pycache__/vit_mlgffn.cpython-311.pyc | Bin 23160 -> 23160 bytes test.py | 58 +------------------------ 3 files changed, 2 insertions(+), 59 deletions(-) diff --git a/.github/workflows/runner.yml b/.github/workflows/runner.yml index ec7aebc..bba6a1e 100644 --- a/.github/workflows/runner.yml +++ b/.github/workflows/runner.yml @@ -13,7 +13,6 @@ jobs: steps: - name: Run tests run: python test.py - working-directory: ../ + working-directory: /media/oem/12TB/IMF - \ No newline at end of file diff --git a/__pycache__/vit_mlgffn.cpython-311.pyc b/__pycache__/vit_mlgffn.cpython-311.pyc index c41b8c8c00a79a423dad219d7c9918d5dea7f45f..6abf115401ad788bdc67364083affe539d774170 100644 GIT binary patch delta 21 bcmeydh4IH0My}<&yj%=G@K1IlS6&nVRtyHt delta 21 bcmeydh4IH0My}<&yj%=Gz`?wcD=!KFPsRnl diff --git a/test.py b/test.py index 36cf7ae..ee76cbd 100644 --- a/test.py +++ b/test.py @@ -7,7 +7,7 @@ # Add the directory containing the module to the Python path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from vit import PositionalEncoding, AdaptiveLayerNorm, TransformerBlock, ImplicitMotionAlignment, CrossAttentionModule +from vit import PositionalEncoding, TransformerBlock, ImplicitMotionAlignment, CrossAttentionModule from model import LatentTokenEncoder, LatentTokenDecoder, DenseFeatureEncoder, FrameDecoder class TestNeuralNetworkComponents(unittest.TestCase): @@ -29,33 +29,8 @@ def test_positional_encoding(self): output = pe(x) self.assertEqual(output.shape, (100, 1, self.motion_dim)) - def test_adaptive_layer_norm(self): - aln = AdaptiveLayerNorm(dim=self.feature_dim).to(self.device) - x = torch.randn(self.B, 10, self.feature_dim).to(self.device) - output = aln(x) - self.assertEqual(output.shape, (self.B, 10, self.feature_dim)) - def test_transformer_block(self): - tb = TransformerBlock(feature_dim=self.feature_dim, heads=self.heads, dim_head=self.dim_head, mlp_dim=self.mlp_dim).to(self.device) - x = torch.randn(self.B, self.C_f, self.H, self.W).to(self.device) - output = tb(x) - self.assertEqual(output.shape, (self.B, self.C_f, self.H, self.W)) - def test_cross_attention_module(self): - cam = CrossAttentionModule(motion_dim=self.motion_dim, feature_dim=self.feature_dim, heads=self.heads, dim_head=self.dim_head).to(self.device) - ml_c = torch.randn(self.B, self.C_m, self.H, self.W).to(self.device) - ml_r = torch.randn(self.B, self.C_m, self.H, self.W).to(self.device) - fl_r = torch.randn(self.B, self.C_f, self.H, self.W).to(self.device) - output = cam(ml_c, ml_r, fl_r) - self.assertEqual(output.shape, (self.B, self.C_f, self.H, self.W)) - - def test_implicit_motion_alignment(self): - ima = ImplicitMotionAlignment(feature_dim=self.feature_dim, motion_dim=self.motion_dim, heads=self.heads, dim_head=self.dim_head, mlp_dim=self.mlp_dim).to(self.device) - ml_c = torch.randn(self.B, self.C_m, self.H, self.W).to(self.device) - ml_r = torch.randn(self.B, self.C_m, self.H, self.W).to(self.device) - fl_r = torch.randn(self.B, self.C_f, self.H, self.W).to(self.device) - output = ima(ml_c, ml_r, fl_r) - self.assertEqual(output.shape, (self.B, self.C_f, self.H, self.W)) def test_latent_token_encoder(self): et = LatentTokenEncoder(dm=self.dm).to(self.device) @@ -101,37 +76,6 @@ def test_latent_token_decoder(self): for m_r_x, m_c_x in zip(m_r, m_c): self.assertEqual(m_r_x.shape, m_c_x.shape, "Shapes of reference and current outputs should match") - def test_frame_decoder(self): - f_c1 = torch.randn(self.B, 512, 8, 8).to(self.device) - f_c2 = 
torch.randn(self.B, 512, 16, 16).to(self.device) - f_c3 = torch.randn(self.B, 512, 32, 32).to(self.device) - f_c4 = torch.randn(self.B, 256, 64, 64).to(self.device) - - model = FrameDecoder().to(self.device) - output = model([f_c4, f_c3, f_c2, f_c1]) - - self.assertEqual(output.shape, (self.B, 3, 256, 256), f"Expected output shape (self.B, 3, 256, 256), but got {output.shape}") - - def test_invalid_input_shapes(self): - ima = ImplicitMotionAlignment(feature_dim=self.feature_dim, motion_dim=self.motion_dim, heads=self.heads, dim_head=self.dim_head, mlp_dim=self.mlp_dim).to(self.device) - - with self.assertRaises(RuntimeError): - ml_c = torch.randn(self.B + 1, self.C_m, self.H, self.W).to(self.device) - ml_r = torch.randn(self.B, self.C_m, self.H, self.W).to(self.device) - fl_r = torch.randn(self.B, self.C_f, self.H, self.W).to(self.device) - ima(ml_c, ml_r, fl_r) - - with self.assertRaises(RuntimeError): - ml_c = torch.randn(self.B, self.C_m, self.H, self.W).to(self.device) - ml_r = torch.randn(self.B, self.C_m + 1, self.H, self.W).to(self.device) - fl_r = torch.randn(self.B, self.C_f, self.H, self.W).to(self.device) - ima(ml_c, ml_r, fl_r) - - with self.assertRaises(RuntimeError): - ml_c = torch.randn(self.B, self.C_m, self.H, self.W).to(self.device) - ml_r = torch.randn(self.B, self.C_m, self.H, self.W).to(self.device) - fl_r = torch.randn(self.B, self.C_f, self.H + 1, self.W).to(self.device) - ima(ml_c, ml_r, fl_r) if __name__ == '__main__': unittest.main() \ No newline at end of file From 08c0bb1e351a85dd596ad9d9fbaa5e8a5b8cda9b Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 04:01:27 +1000 Subject: [PATCH 008/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model.py | 193 +--------------------------- resblock.py | 362 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 364 insertions(+), 191 deletions(-) create mode 100644 resblock.py diff --git a/model.py b/model.py index d68f334..9a26f3a 100644 --- a/model.py +++ b/model.py @@ -10,6 +10,8 @@ from vit import ImplicitMotionAlignment from stylegan import EqualConv2d,EqualLinear from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights +from resnet import UpConvResBlock,DownConvResBlock,FeatResBlock,ModulatedConv2d,StyledConv,ResBlock + # from common import DownConvResBlock,UpConvResBlock import colored_traceback.auto # makes terminal show color coded output when crash @@ -107,105 +109,6 @@ def forward(self, x): return features -class UpConvResBlock(nn.Module): - def __init__(self, in_channels, out_channels): - super().__init__() - - self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.bn1 = nn.BatchNorm2d(out_channels) - self.relu = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.feat_res_block = FeatResBlock(out_channels) - - self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None - - - # Add a 1x1 convolution for the residual connection if channel sizes differ - self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None - - def forward(self, x): - residual = self.upsample(x) - if self.residual_conv: - residual = self.residual_conv(residual) - - out = self.upsample(x) - out = self.conv1(out) - out = self.bn1(out) - out = self.relu(out) 
- out = self.conv2(out) - out = out + residual - out = self.relu(out) - out = self.feat_res_block(out) - return out - - -class DownConvResBlock(nn.Module): - def __init__(self, in_channels, out_channels): - super().__init__() - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.bn1 = nn.BatchNorm2d(out_channels) - self.relu = nn.ReLU(inplace=True) - self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.feat_res_block1 = FeatResBlock(out_channels) - self.feat_res_block2 = FeatResBlock(out_channels) - - def forward(self, x): - debug_print(f"DownConvResBlock input shape: {x.shape}") - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - debug_print(f"After conv1, bn1, relu: {out.shape}") - out = self.avgpool(out) - debug_print(f"After avgpool: {out.shape}") - out = self.conv2(out) - debug_print(f"After conv2: {out.shape}") - out = self.feat_res_block1(out) - debug_print(f"After feat_res_block1: {out.shape}") - out = self.feat_res_block2(out) - debug_print(f"DownConvResBlock output shape: {out.shape}") - return out - - -''' -DenseFeatureEncoder -It starts with a Conv-64-k7-s1-p3 layer, followed by BatchNorm and ReLU, as shown in the first block of the image. -It then uses a series of DownConvResBlocks, which perform downsampling (indicated by ↓2 in the image) while increasing the number of channels: - -DownConvResBlock-64 -DownConvResBlock-128 -DownConvResBlock-256 -DownConvResBlock-512 -DownConvResBlock-512 - - -It outputs multiple feature maps (f¹ᵣ, f²ᵣ, f³ᵣ, f⁴ᵣ) as shown in the image. These are collected in the features list and returned. - -Each DownConvResBlock performs downsampling using a strided convolution, maintains a residual connection, and applies BatchNorm and ReLU activations, which is consistent with typical ResNet architectures.''' - - -class FeatResBlock(nn.Module): - def __init__(self, channels): - super().__init__() - - self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) - self.bn1 = nn.BatchNorm2d(channels) - self.relu1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) - self.bn2 = nn.BatchNorm2d(channels) - self.relu2 = nn.ReLU(inplace=True) - - def forward(self, x): - residual = x - out = self.conv1(x) - out = self.bn1(out) - out = self.relu1(out) - out = self.conv2(out) - out = self.bn2(out) - out += residual - out = self.relu2(out) - return out class DenseFeatureEncoder(nn.Module): @@ -331,52 +234,6 @@ def forward(self, x): Our latent tokens are directly learned by the encoder, rather than being restricted to coordinates with a limited value range. 
''' -class ModulatedConv2d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, demodulate=True): - super().__init__() - self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) - self.stride = stride - self.padding = padding - self.demodulate = demodulate - - def forward(self, x, style): - batch, in_channel, height, width = x.shape - style = style.view(batch, 1, in_channel, 1, 1) - - # Weight modulation - weight = self.weight.unsqueeze(0) * style - - if self.demodulate: - demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) - weight = weight * demod.view(batch, self.weight.size(0), 1, 1, 1) - - weight = weight.view( - batch * self.weight.size(0), in_channel, self.weight.size(2), self.weight.size(3) - ) - - x = x.view(1, batch * in_channel, height, width) - out = F.conv2d(x, weight, padding=self.padding, groups=batch) - _, _, height, width = out.shape - out = out.view(batch, self.weight.size(0), height, width) - - return out - -class StyledConv(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, style_dim, upsample=False, demodulate=True): - super().__init__() - self.conv = ModulatedConv2d(in_channels, out_channels, kernel_size, demodulate=demodulate) - self.style = nn.Linear(style_dim, in_channels) - self.upsample = upsample - self.activation = nn.LeakyReLU(0.2) - - def forward(self, x, latent): - if self.upsample: - x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) - - style = self.style(latent) - x = self.conv(x, style) - x = self.activation(x) - return x class LatentTokenDecoder(nn.Module): def __init__(self, latent_dim=32, const_dim=32): @@ -494,53 +351,7 @@ def forward(self, features): ReLU activations are applied both after adding the residual and at the end of the block. The FeatResBlock is now a subclass of ResBlock with downsample=False, as it doesn't change the spatial dimensions. 
''' -class ResBlock(nn.Module): - def __init__(self, in_channels, out_channels, downsample=False): - super().__init__() - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2 if downsample else 1, padding=1) - self.bn1 = nn.BatchNorm2d(out_channels) - self.relu1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.bn2 = nn.BatchNorm2d(out_channels) - self.relu2 = nn.ReLU(inplace=True) - - if downsample or in_channels != out_channels: - self.shortcut = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2 if downsample else 1, padding=0), - nn.BatchNorm2d(out_channels) - ) - else: - self.shortcut = nn.Identity() - - self.downsample = downsample - self.in_channels = in_channels - self.out_channels = out_channels - - def forward(self, x): - debug_print(f"ResBlock input shape: {x.shape}") - debug_print(f"ResBlock parameters: in_channels={self.in_channels}, out_channels={self.out_channels}, downsample={self.downsample}") - residual = self.shortcut(x) - debug_print(f"After shortcut: {residual.shape}") - - out = self.conv1(x) - debug_print(f"After conv1: {out.shape}") - out = self.bn1(out) - out = self.relu1(out) - debug_print(f"After bn1 and relu1: {out.shape}") - - out = self.conv2(out) - debug_print(f"After conv2: {out.shape}") - out = self.bn2(out) - debug_print(f"After bn2: {out.shape}") - - out += residual - debug_print(f"After adding residual: {out.shape}") - - out = self.relu2(out) - debug_print(f"ResBlock output shape: {out.shape}") - - return out class TokenManipulationNetwork(nn.Module): diff --git a/resblock.py b/resblock.py new file mode 100644 index 0000000..feaac08 --- /dev/null +++ b/resblock.py @@ -0,0 +1,362 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import matplotlib.pyplot as plt + + +DEBUG = False +def debug_print(*args, **kwargs): + if DEBUG: + print(*args, **kwargs) + +class UpConvResBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.bn1 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.feat_res_block = FeatResBlock(out_channels) + + self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None + + + # Add a 1x1 convolution for the residual connection if channel sizes differ + self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None + + def forward(self, x): + residual = self.upsample(x) + if self.residual_conv: + residual = self.residual_conv(residual) + + out = self.upsample(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = out + residual + out = self.relu(out) + out = self.feat_res_block(out) + return out + + +class DownConvResBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.bn1 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.feat_res_block1 = FeatResBlock(out_channels) 
+ self.feat_res_block2 = FeatResBlock(out_channels) + + def forward(self, x): + debug_print(f"DownConvResBlock input shape: {x.shape}") + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + debug_print(f"After conv1, bn1, relu: {out.shape}") + out = self.avgpool(out) + debug_print(f"After avgpool: {out.shape}") + out = self.conv2(out) + debug_print(f"After conv2: {out.shape}") + out = self.feat_res_block1(out) + debug_print(f"After feat_res_block1: {out.shape}") + out = self.feat_res_block2(out) + debug_print(f"DownConvResBlock output shape: {out.shape}") + return out + + + +class FeatResBlock(nn.Module): + def __init__(self, channels): + super().__init__() + + self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + self.bn1 = nn.BatchNorm2d(channels) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + self.bn2 = nn.BatchNorm2d(channels) + self.relu2 = nn.ReLU(inplace=True) + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + out = self.conv2(out) + out = self.bn2(out) + out += residual + out = self.relu2(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels, downsample=False): + super().__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, + stride=2 if downsample else 1, padding=1) + self.bn1 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.bn2 = nn.BatchNorm2d(out_channels) + + if downsample or in_channels != out_channels: + self.shortcut = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, + stride=2 if downsample else 1, bias=False), + nn.BatchNorm2d(out_channels) + ) + else: + self.shortcut = nn.Identity() + + self.downsample = downsample + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x): + identity = self.shortcut(x) + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += identity + out = self.relu(out) + + return out + + +class ModulatedConv2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, demodulate=True): + super().__init__() + self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + self.stride = stride + self.padding = padding + self.demodulate = demodulate + + def forward(self, x, style): + batch, in_channel, height, width = x.shape + style = style.view(batch, 1, in_channel, 1, 1) + + # Weight modulation + weight = self.weight.unsqueeze(0) * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.weight.size(0), 1, 1, 1) + + weight = weight.view( + batch * self.weight.size(0), in_channel, self.weight.size(2), self.weight.size(3) + ) + + x = x.view(1, batch * in_channel, height, width) + out = F.conv2d(x, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.weight.size(0), height, width) + + return out + +class StyledConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, style_dim, upsample=False, demodulate=True): + super().__init__() + self.conv = ModulatedConv2d(in_channels, out_channels, kernel_size, demodulate=demodulate) + self.style = nn.Linear(style_dim, 
in_channels) + self.upsample = upsample + self.activation = nn.LeakyReLU(0.2) + + def forward(self, x, latent): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + + style = self.style(latent) + x = self.conv(x, style) + x = self.activation(x) + return x + + + + + +def test_upconvresblock(block, input_shape): + print("\nTesting UpConvResBlock") + x = torch.randn(input_shape) + output = block(x) + print(f"Input shape: {x.shape}, Output shape: {output.shape}") + assert output.shape[2:] == tuple(2*x for x in x.shape[2:]), "UpConvResBlock should double spatial dimensions" + assert output.shape[1] == block.conv2.out_channels, "Output channels should match block's out_channels" + + # Test gradient flow + output.sum().backward() + for name, param in block.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + print(f"{name} gradient shape: {param.grad.shape}") + +def test_downconvresblock(block, input_shape): + print("\nTesting DownConvResBlock") + x = torch.randn(input_shape) + output = block(x) + print(f"Input shape: {x.shape}, Output shape: {output.shape}") + assert output.shape[2:] == tuple(x//2 for x in x.shape[2:]), "DownConvResBlock should halve spatial dimensions" + assert output.shape[1] == block.conv2.out_channels, "Output channels should match block's out_channels" + + # Test gradient flow + output.sum().backward() + for name, param in block.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + print(f"{name} gradient shape: {param.grad.shape}") + +def test_featresblock(block, input_shape): + print("\nTesting FeatResBlock") + x = torch.randn(input_shape) + output = block(x) + print(f"Input shape: {x.shape}, Output shape: {output.shape}") + assert output.shape == x.shape, "FeatResBlock should maintain input shape" + + # Test gradient flow + output.sum().backward() + for name, param in block.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + print(f"{name} gradient shape: {param.grad.shape}") + +def test_modulatedconv2d(conv, input_shape, style_dim): + print("\nTesting ModulatedConv2d") + x = torch.randn(input_shape) + style = torch.randn(input_shape[0], style_dim) + output = conv(x, style) + print(f"Input shape: {x.shape}, Style shape: {style.shape}, Output shape: {output.shape}") + assert output.shape[1] == conv.weight.shape[0], "Output channels should match conv's out_channels" + + # Test gradient flow + output.sum().backward() + assert conv.weight.grad is not None, "No gradient for weight" + print(f"Weight gradient shape: {conv.weight.grad.shape}") + +def test_styledconv(conv, input_shape, latent_dim): + print("\nTesting StyledConv") + x = torch.randn(input_shape) + latent = torch.randn(input_shape[0], latent_dim) + output = conv(x, latent) + print(f"Input shape: {x.shape}, Latent shape: {latent.shape}, Output shape: {output.shape}") + expected_shape = list(input_shape) + expected_shape[1] = conv.conv.weight.shape[0] + if conv.upsample: + expected_shape[2] *= 2 + expected_shape[3] *= 2 + assert list(output.shape) == expected_shape, f"Output shape {output.shape} doesn't match expected {expected_shape}" + + # Test gradient flow + output.sum().backward() + for name, param in conv.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + print(f"{name} gradient shape: {param.grad.shape}") + +def test_resblock(resblock, input_shape): + print("\nTesting ResBlock") + x = torch.randn(input_shape) + output = resblock(x) + print(f"Input shape: 
{x.shape}, Output shape: {output.shape}") + + expected_output_shape = list(input_shape) + expected_output_shape[1] = resblock.out_channels + if resblock.downsample: + expected_output_shape[2] //= 2 + expected_output_shape[3] //= 2 + assert tuple(output.shape) == tuple(expected_output_shape), f"Expected shape {expected_output_shape}, got {output.shape}" + + # Test gradient flow + output.sum().backward() + for name, param in resblock.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + print(f"{name} gradient shape: {param.grad.shape}") + + # Test residual connection + resblock.eval() + with torch.no_grad(): + residual_output = resblock(x) + identity = resblock.shortcut(x) + main_path = resblock.bn2(resblock.conv2(resblock.relu(resblock.bn1(resblock.conv1(x))))) + direct_output = resblock.relu(main_path + identity) + assert torch.allclose(residual_output, direct_output, atol=1e-6), "Residual connection not working correctly" + + print("ResBlock test passed successfully!") + + + +def test_block_with_dropout(block, input_shape, block_name): + print(f"\nTesting {block_name}") + x = torch.randn(input_shape) + + # Test in training mode + block.train() + output_train = block(x) + print(f"Training mode - Input shape: {x.shape}, Output shape: {output_train.shape}") + + # Test in eval mode + block.eval() + with torch.no_grad(): + output_eval = block(x) + print(f"Eval mode - Input shape: {x.shape}, Output shape: {output_eval.shape}") + + # Check if outputs are different in train and eval modes + assert not torch.allclose(output_train, output_eval), f"{block_name} outputs should differ between train and eval modes due to dropout" + + # Test gradient flow + block.train() # Set back to train mode for gradient check + output_train.sum().backward() + for name, param in block.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + print(f"{name} gradient shape: {param.grad.shape}") + + print(f"{block_name} test passed successfully!") +def visualize_feature_maps(block, input_shape, num_channels=4, latent_dim=None): + x = torch.randn(input_shape) + if isinstance(block, StyledConv): + latent = torch.randn(input_shape[0], latent_dim) + output = block(x, latent) + else: + output = block(x) + + fig, axs = plt.subplots(1, num_channels, figsize=(20, 5)) + for i in range(num_channels): + axs[i].imshow(output[0, i].detach().cpu().numpy(), cmap='viridis') + axs[i].axis('off') + axs[i].set_title(f'Channel {i}') + plt.tight_layout() + plt.show() + + +# Run all tests +upconv = UpConvResBlock(64, 128) +test_upconvresblock(upconv, (1, 64, 56, 56)) +visualize_feature_maps(upconv, (1, 64, 56, 56)) + +downconv = DownConvResBlock(128, 256) +test_downconvresblock(downconv, (1, 128, 56, 56)) +visualize_feature_maps(downconv, (1, 128, 56, 56)) + +featres = FeatResBlock(256) +test_featresblock(featres, (1, 256, 28, 28)) +visualize_feature_maps(featres, (1, 256, 28, 28)) + +modconv = ModulatedConv2d(64, 128, 3) +test_modulatedconv2d(modconv, (1, 64, 56, 56), 64) + +styledconv = StyledConv(64, 128, 3, 32, upsample=True) +test_styledconv(styledconv, (1, 64, 56, 56), 32) +visualize_feature_maps(styledconv, (1, 64, 56, 56), num_channels=4, latent_dim=32) + + +resblock = ResBlock(64, 128, downsample=True) +test_resblock(resblock, (1, 64, 56, 56)) +visualize_feature_maps(resblock, (1, 64, 56, 56)) + +# Usage +resblock = ResBlock(64, 64) +test_resblock(resblock, (1, 64, 56, 56)) \ No newline at end of file From 9f99616c7c65c1a31d3ea5fc01b52dd82faf29ac Mon Sep 17 00:00:00 2001 From: John 
Pope Date: Sun, 11 Aug 2024 04:01:44 +1000 Subject: [PATCH 009/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model.py b/model.py index 9a26f3a..c09ab3e 100644 --- a/model.py +++ b/model.py @@ -10,7 +10,7 @@ from vit import ImplicitMotionAlignment from stylegan import EqualConv2d,EqualLinear from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights -from resnet import UpConvResBlock,DownConvResBlock,FeatResBlock,ModulatedConv2d,StyledConv,ResBlock +from resblock import UpConvResBlock,DownConvResBlock,FeatResBlock,StyledConv,ResBlock # from common import DownConvResBlock,UpConvResBlock From 64440bed5c9d75d19d5de0eb185ec712de170fbd Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 04:03:04 +1000 Subject: [PATCH 010/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/resblock.cpython-311.pyc | Bin 0 -> 25184 bytes resblock.py | 45 ++++++++++++++------------- 2 files changed, 24 insertions(+), 21 deletions(-) create mode 100644 __pycache__/resblock.cpython-311.pyc diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98a166573fd605e493fcd4559d920251734a5959 GIT binary patch literal 25184 zcmeHvdu$t5y5Eo-l0#A?MT^wi5@pMh8Ozoqw&U1}ZTS_)55;!uI8kCVG>4LCQx6X* zKNRX#349yc)!lGY^@h9G!E(3RtbEtF5Q?-{ZU3mT*j{*Vx9yBDfP?`A7;OX8KnuvQ zw_W6qqThFh9L@|WS?N06dx1L|eVjS-oipdmob&r0=R5NYr_)Zsw)~fW-T$lY6!l;6 zB)?g62TymJDC%>HrC39Z>XUyBL&iR%K|VGO8N+ZqME99iD zt8kSQE*EgQSK;!o<PGOHquD?H`W=~z~c_BpIv z4r_Re3bZ`Ns`MTS=c+`0uwicaf7$`$b1Fd+Wx}}<%7CD#Bo(EShA$02FyO#?i~Qd+ z^ahNQ`Ox75=U$i0qg-@20q1mx8;DDmE7$Qij(y4R|F*GxD8fcV?IV$)_Rg-e2io5_ zd8D0-#4p80!dJG9-jp0{XObLO zBHVB!7K}&7Ba$Va;G%3qvW|vWHaa|T_MSnqMu$gZp>QN%;3^?Yz%7~MW1|sHVuHcw za5NDNO2*+~i8(hK4-JjRA`&x}KeAv~?qVf#7=Ak?$AM5HJlH$J4MCde(r~9lpNX6} zC(&FaHYPK5No7YOp+vs>ORj!63qn(ZWVlLBIE=Dmp;!=$yhlrPJQC}d9MR!mcrY{! 
zl0Py5^>{x#w#Sn|AOIb^6!#lU*o8;`0gu4(pj`Ufd@8w=&+uDhtH(?K5eb?FkNrT- z!ZHc;uETPxU&9Z1+mM77JtQIQ_6GcG}3}EHk3&&hP zWR^VO`p2Fci(Z0HMdOiVk$<^=jrYT`A%E-2LZ)3K}n#a)hnKA)2%=EiYF#TEK;sHYqG8dCHH+}O7ravoO-0el)Vv_Efe)S2aKPz0^){MNx zB*jfZxsd*>aIx1=kIcm+eQ7%S1k;}tE_S;Ou!NwPq}NT?Ji+v5g^Mqh86aF$A@#}; zNvn6@RfW_#414i3M`}HIRUs|)h7PP^j_il$wgaDAj%;@s-0*EK4QZzA)u{#u;XNqW z{G!c2X<4Kyc&Z|OB=e>~wTV<4Pqi&ll{{6MzLxn%0@W^3?L5`~2xpw@G6%C2f@72D z*fhz={Ukle`#S`xQ=~e1sxz1I2H(^vP+cO`#Zz66?3JmubX>64iT1il>mudhDNnkU zuiK)(-*t6SkL)$6_RNM%INK@MH;ML5nAOWu-t>9!O>l+comHtmY!7@$;=rtzckUFN zJ4NTtN&6z@o^+?}nKx(72vmnib?{V2u7cy)eGfbjIi7+mXGH1@96-M<^QLyC4`fU; zH1F6dIJSz8t=LTGq%$>`aSPODk=o2toAZs#HY)8dWHt(pI+&d=l@GGQPKM+E1<6|r A>i_@% literal 0 HcmV?d00001 diff --git a/resblock.py b/resblock.py index feaac08..1355e14 100644 --- a/resblock.py +++ b/resblock.py @@ -332,31 +332,34 @@ def visualize_feature_maps(block, input_shape, num_channels=4, latent_dim=None): plt.show() -# Run all tests -upconv = UpConvResBlock(64, 128) -test_upconvresblock(upconv, (1, 64, 56, 56)) -visualize_feature_maps(upconv, (1, 64, 56, 56)) -downconv = DownConvResBlock(128, 256) -test_downconvresblock(downconv, (1, 128, 56, 56)) -visualize_feature_maps(downconv, (1, 128, 56, 56)) +if __name__ == "__main__": -featres = FeatResBlock(256) -test_featresblock(featres, (1, 256, 28, 28)) -visualize_feature_maps(featres, (1, 256, 28, 28)) + # Run all tests + upconv = UpConvResBlock(64, 128) + test_upconvresblock(upconv, (1, 64, 56, 56)) + visualize_feature_maps(upconv, (1, 64, 56, 56)) -modconv = ModulatedConv2d(64, 128, 3) -test_modulatedconv2d(modconv, (1, 64, 56, 56), 64) + downconv = DownConvResBlock(128, 256) + test_downconvresblock(downconv, (1, 128, 56, 56)) + visualize_feature_maps(downconv, (1, 128, 56, 56)) -styledconv = StyledConv(64, 128, 3, 32, upsample=True) -test_styledconv(styledconv, (1, 64, 56, 56), 32) -visualize_feature_maps(styledconv, (1, 64, 56, 56), num_channels=4, latent_dim=32) + featres = FeatResBlock(256) + test_featresblock(featres, (1, 256, 28, 28)) + visualize_feature_maps(featres, (1, 256, 28, 28)) + modconv = ModulatedConv2d(64, 128, 3) + test_modulatedconv2d(modconv, (1, 64, 56, 56), 64) -resblock = ResBlock(64, 128, downsample=True) -test_resblock(resblock, (1, 64, 56, 56)) -visualize_feature_maps(resblock, (1, 64, 56, 56)) + styledconv = StyledConv(64, 128, 3, 32, upsample=True) + test_styledconv(styledconv, (1, 64, 56, 56), 32) + visualize_feature_maps(styledconv, (1, 64, 56, 56), num_channels=4, latent_dim=32) -# Usage -resblock = ResBlock(64, 64) -test_resblock(resblock, (1, 64, 56, 56)) \ No newline at end of file + + resblock = ResBlock(64, 128, downsample=True) + test_resblock(resblock, (1, 64, 56, 56)) + visualize_feature_maps(resblock, (1, 64, 56, 56)) + + # Usage + resblock = ResBlock(64, 64) + test_resblock(resblock, (1, 64, 56, 56)) \ No newline at end of file From 4dcd1b50acfb8f5a5357f06710725c602790c88a Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 04:13:59 +1000 Subject: [PATCH 011/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/resblock.cpython-311.pyc | Bin 25184 -> 25237 bytes resblock.py | 128 ++++++++++++++++++--------- 2 files changed, 87 insertions(+), 41 deletions(-) diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index 98a166573fd605e493fcd4559d920251734a5959..151d23b6acdf499f0d8cc7fbf5eab8daebbc4a02 100644 GIT binary patch delta 1035 
zcmZ`$O-~b16z$Ahi9tF;f!fZrWm+gg6B4NXkkkbP4H6QAh&0*+8p#`t@hjGiaWj+X zqC0CElUdFRNgxJg;nsz2{K^~i7L9JYC2rsk81JjCwFX}1&b&GIoO9p3ul`!a6j2@zq&<6f*H874BLWuVn zArkJy1rOrt5xAFxNTicr2qVAW2$MmMSLa2gIxPlG8F>QM!}zlb5#$XR5u$hMoAnOI zd-=IZbVEWXjHc5dI{A%35++(_rs3t|ni&!$F*3A0ym1Vr7)_~;(vyac(%2zVwUG&s z*mk^1MZd|54{eWhG7C}k2pG{L26s5F;Gwk6#*g+P!m@{lm>At%mwsKl`u%vZc=!6P zd&T133^uZ7;VS{90sNFca$t(@KJ%Ri2i{73YN3r(3;n`rVV?S0eYKhtV%7TU8?ltF zK3luo{R44+We=b1{#-i{Oz0^;{wAqZQun(&a=5%^&pMLpjWciDnyW#8Lcs2F6c;8L zOj4Mvfkr{Iv(B^&DF!JDsk)+5MRz1;*;NuuNm%nWP${VP7>%c0$S}xI$TWphbR^?K zmO++6wl14m`Hr1(Wu3{ob)g1+3jWGXs!#l1@XWtnbvaVLZ?8MU&I4CYFgd}cK?*@T z;7shCZu|Q}6^%FKENzNASEw)N`f|*dvleR5Z!K0998=lpb0N(jO(ESpV$Lz&Ez#u_ z3V9du4Dz_a;q(T}k8IICw|Sl_NmofSCCNMWSr;o~_M!{p48|#pxB7I7e_L;hXI({S Xir&^m+CtRgvwzlu6PDD{3B07=TY_f$# z%9S8&nJNZN;dzCvh{E{=Gcd(smG&ArTMO}RK8!0Fc>9L`e}3=Hn|=OK7<~|YFJxH+ zT*oip-g|%XnJ*?(_p{f9Y=*?13Jt{t%@v}Mc;5x^3S+Vwanw&W<3#PQy2BSRJADDr zmr0x?oO(k=J=j!9(&h@7o&K2fAthipjA^$7Qlah-q1axv??O_2H~JV;KZ$<0nItMv zoX)sjV8RfMq=)PM6G&oA$u%S+%{3&o1*9@|qC^{>?t75)oS%W?!v1mL^ze)mZKSYh zw3(Vb&Q;6{hO}A5sn0p5E&vg3z{c5wmizauYe$#y4o5nEHax8mrq(B0*tB>c2{ z&8*Keb!^SlulBKoZGGz=?vw?s`)lV-cmH-x{Enji?0feQ#5+HQwixTInJ#miEo*#P z)0Zp=QV1GX%o>9{hdhP61rZ7n+O^+mqgbN;xf|XEb)7UcI0}qu7Oe%1xU=J$ma%b1{GstnsQOM24 a?4|xrK8IaQQn{oq;t-`?$h9Q=9e)6G8|ZKV diff --git a/resblock.py b/resblock.py index 1355e14..839aa07 100644 --- a/resblock.py +++ b/resblock.py @@ -10,76 +10,78 @@ def debug_print(*args, **kwargs): print(*args, **kwargs) class UpConvResBlock(nn.Module): - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels, dropout_rate=0.1): super().__init__() - self.upsample = nn.Upsample(scale_factor=2, mode='nearest') self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.feat_res_block = FeatResBlock(out_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + self.dropout = nn.Dropout2d(dropout_rate) - self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None - + self.feat_res_block1 = FeatResBlock(out_channels) + self.feat_res_block2 = FeatResBlock(out_channels) + self.feat_res_block3 = FeatResBlock(out_channels) - # Add a 1x1 convolution for the residual connection if channel sizes differ - self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None + self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) def forward(self, x): - residual = self.upsample(x) - if self.residual_conv: - residual = self.residual_conv(residual) + x = self.upsample(x) + residual = self.residual_conv(x) - out = self.upsample(x) - out = self.conv1(out) + out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) + out = self.bn2(out) out = out + residual out = self.relu(out) - out = self.feat_res_block(out) + out = self.dropout(out) + out = self.feat_res_block1(out) + out = self.feat_res_block2(out) + out = self.feat_res_block3(out) return out class DownConvResBlock(nn.Module): - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels, dropout_rate=0.1): super().__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) self.avgpool = 
nn.AvgPool2d(kernel_size=2, stride=2) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.bn2 = nn.BatchNorm2d(out_channels) + self.dropout = nn.Dropout2d(dropout_rate) self.feat_res_block1 = FeatResBlock(out_channels) self.feat_res_block2 = FeatResBlock(out_channels) + self.feat_res_block3 = FeatResBlock(out_channels) def forward(self, x): - debug_print(f"DownConvResBlock input shape: {x.shape}") out = self.conv1(x) out = self.bn1(out) out = self.relu(out) - debug_print(f"After conv1, bn1, relu: {out.shape}") out = self.avgpool(out) - debug_print(f"After avgpool: {out.shape}") out = self.conv2(out) - debug_print(f"After conv2: {out.shape}") + out = self.bn2(out) + out = self.relu(out) + out = self.dropout(out) out = self.feat_res_block1(out) - debug_print(f"After feat_res_block1: {out.shape}") out = self.feat_res_block2(out) - debug_print(f"DownConvResBlock output shape: {out.shape}") + out = self.feat_res_block3(out) return out - class FeatResBlock(nn.Module): - def __init__(self, channels): + def __init__(self, channels, dropout_rate=0.1): super().__init__() - self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) self.bn1 = nn.BatchNorm2d(channels) self.relu1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) self.bn2 = nn.BatchNorm2d(channels) + self.dropout = nn.Dropout2d(dropout_rate) self.relu2 = nn.ReLU(inplace=True) def forward(self, x): @@ -89,45 +91,46 @@ def forward(self, x): out = self.relu1(out) out = self.conv2(out) out = self.bn2(out) + out = self.dropout(out) out += residual out = self.relu2(out) return out class ResBlock(nn.Module): - def __init__(self, in_channels, out_channels, downsample=False): + def __init__(self, in_channels, out_channels, downsample=False, dropout_rate=0.1): super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels # Add this line + self.downsample = downsample + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2 if downsample else 1, padding=1) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) self.bn2 = nn.BatchNorm2d(out_channels) + self.dropout = nn.Dropout2d(dropout_rate) if downsample or in_channels != out_channels: self.shortcut = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=1, - stride=2 if downsample else 1, bias=False), + nn.Conv2d(in_channels, out_channels, kernel_size=3, + stride=2 if downsample else 1, padding=1), nn.BatchNorm2d(out_channels) ) else: self.shortcut = nn.Identity() - self.downsample = downsample - self.in_channels = in_channels - self.out_channels = out_channels - def forward(self, x): - identity = self.shortcut(x) + residual = self.shortcut(x) out = self.conv1(x) out = self.bn1(out) out = self.relu(out) - out = self.conv2(out) out = self.bn2(out) - - out += identity + out = self.dropout(out) + out += residual out = self.relu(out) return out @@ -315,6 +318,7 @@ def test_block_with_dropout(block, input_shape, block_name): print(f"{name} gradient shape: {param.grad.shape}") print(f"{block_name} test passed successfully!") + def visualize_feature_maps(block, input_shape, num_channels=4, latent_dim=None): x = torch.randn(input_shape) if isinstance(block, StyledConv): @@ -323,16 +327,44 @@ def visualize_feature_maps(block, input_shape, num_channels=4, latent_dim=None): else: output = block(x) - fig, axs = plt.subplots(1, 
num_channels, figsize=(20, 5)) - for i in range(num_channels): - axs[i].imshow(output[0, i].detach().cpu().numpy(), cmap='viridis') - axs[i].axis('off') - axs[i].set_title(f'Channel {i}') + # Determine the number of intermediate outputs + if isinstance(block, UpConvResBlock): + intermediate_outputs = [ + block.conv2(block.relu(block.bn1(block.conv1(x)))), + block.feat_res_block1(block.dropout(block.relu(block.bn2(block.conv2(block.relu(block.bn1(block.conv1(x))))) + block.residual_conv(x)))), + block.feat_res_block2(block.feat_res_block1(block.dropout(block.relu(block.bn2(block.conv2(block.relu(block.bn1(block.conv1(x))))) + block.residual_conv(x))))), + output + ] + titles = ['After Conv2', 'After FeatResBlock1', 'After FeatResBlock2', 'Final Output'] + elif isinstance(block, DownConvResBlock): + intermediate_outputs = [ + block.conv2(block.avgpool(block.relu(block.bn1(block.conv1(x))))), + block.feat_res_block1(block.dropout(block.relu(block.bn2(block.conv2(block.avgpool(block.relu(block.bn1(block.conv1(x))))))))), + block.feat_res_block2(block.feat_res_block1(block.dropout(block.relu(block.bn2(block.conv2(block.avgpool(block.relu(block.bn1(block.conv1(x)))))))))), + output + ] + titles = ['After Conv2', 'After FeatResBlock1', 'After FeatResBlock2', 'Final Output'] + else: + intermediate_outputs = [output] + titles = ['Output'] + + fig, axs = plt.subplots(len(intermediate_outputs), num_channels, figsize=(20, 5 * len(intermediate_outputs))) + if len(intermediate_outputs) == 1: + axs = [axs] # Make it 2D for consistency + + for i, out in enumerate(intermediate_outputs): + for j in range(num_channels): + axs[i][j].imshow(out[0, j].detach().cpu().numpy(), cmap='viridis') + axs[i][j].axis('off') + if j == 0: + axs[i][j].set_title(f'{titles[i]}\nChannel {j}') + else: + axs[i][j].set_title(f'Channel {j}') + plt.tight_layout() plt.show() - if __name__ == "__main__": # Run all tests @@ -362,4 +394,18 @@ def visualize_feature_maps(block, input_shape, num_channels=4, latent_dim=None): # Usage resblock = ResBlock(64, 64) - test_resblock(resblock, (1, 64, 56, 56)) \ No newline at end of file + test_resblock(resblock, (1, 64, 56, 56)) + + + # dropout + upconv = UpConvResBlock(64, 128, dropout_rate=0.1) + test_block_with_dropout(upconv, (1, 64, 56, 56), "UpConvResBlock") + + downconv = DownConvResBlock(128, 256, dropout_rate=0.1) + test_block_with_dropout(downconv, (1, 128, 56, 56), "DownConvResBlock") + + featres = FeatResBlock(256, dropout_rate=0.1) + test_block_with_dropout(featres, (1, 256, 28, 28), "FeatResBlock") + + resblock = ResBlock(64, 128, downsample=True, dropout_rate=0.1) + test_block_with_dropout(resblock, (1, 64, 56, 56), "ResBlock") From 4d0cd4f504f5fca409c1116135f70ff35f6dbeb9 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 04:19:46 +1000 Subject: [PATCH 012/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/resblock.cpython-311.pyc | Bin 25237 -> 29886 bytes config.yaml | 2 +- helper.py | 46 +++++++++++++++++++++++++++ train.py | 8 +++-- 4 files changed, 53 insertions(+), 3 deletions(-) diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index 151d23b6acdf499f0d8cc7fbf5eab8daebbc4a02..a91bc991b41f9a90db8524492d624994e27d249c 100644 GIT binary patch delta 8539 zcmb_hdvKFicK^Ol@3;KIvMgcQ#>N6{Fvi5j#$X--;zzJ~1>0B%OEv;Uwm-=Rykg~S z(^eran7yP+UX6xn6WF915@$ASCJ(x8C)-W4)u$@8t0r4!JIt1yovghj-87w@_MH1k ze&IOzV?X5IJ?GqWALoAO+evb-|8J`xU4sEHtFkSkFYs* 
zx9JJCBhO=Y=bh$IW1QuWOwJSHC1-m{;IWd^+->&g98=TimjR zH7o-)I%s2EqLqcVvba~e6-%_zr`o_QUcs4!=~x|B+Q{lIaI2PRWS|o)g2BpOl0rlH z6Q2$nZD)fiY8JP1vj-Ubv#+^_%B_x`R3lo*01fa*E=rQ*wIVaA%D0&HhX!```-j^- zfgQg7p3}d4_8EMBw5>-|V4~jAtilm^GNqg#+7a)^GK9c+7xu9jwCfgL;dC$$jFC^BD*%#gOhg! zJxP_*+2eBs0#0WD7sdw9D6ayFw-SZ2kYiaglbg?eN*u~FMOisl8lvljto4bk^%3cq z^3&YBE1_t!kXxC^tptJgE^!zO$RVRK`>wU%`J>Tq-E97eD`wpVSlv8r# ze+$YbY=|Mk>RfWM!YJcHOq3(9tL$Xd79ww|*OJH8)vSVKl0VtGB$YS)CVIL|tOq%QSxdOSX#(+LXI zAPukp+r@XJLsbMVRjo$u-M986a$JbSDsxdPfgZ^AQc8z`<_<8t6lQH#FOHHuXy3A- zVwdohs8oZn1t5Tj1Z)zb|Dl(Tm=-%lMQ=ZU#>Kn0GK`W5yaPI;3@J4iLHrmr^TLL- zK}@}?$qdhvedD-UBH(62iLZw-c^IDIcL2Wo3 z>fG|Bo!hQoT6YhAQX1#aZ_it2WwUw*xbi%AS)hzu)5v@SA`Yw-PB(^LPAM zu36p!vcMZDH3G`Yp1i- z_`a3{@qMj@_vI*6rcYCcxU1oLn11j?M9=@Pi?GSxi%ZW)0eL@hi&e`NCBmoZ1 zqfbynV(r=Fqx_Ow#hffnG0Z6?DU|2uV*m6BYSI+csY#v1s+pCf`Z{B==4#Vg3-p!b zvcBZ#tUmNGo0Wv5FTuI-3Z5PGfjMFlB+` z#YE31RV#}hvCGAMYqlYehvPiYmjer)KbP$QKa!F5fC{J6>-PqoP7jC!H8ck@P34zH zqzR>I%2I}9-jsP2`{ujsFa1mn7d zaovnUYSrIo03y4_S|af8b1eg@Q`0jo&$pySwS*QA z@&^WoJf0DcIEha3y|PjrHnVF|r!V0lqlt#9Je+B~)fSD;V)@cD;rkSxrtE2<}@_zl6YL{i@D z^?O`Ae;OFxM((V$HqtC%F4}N$ld3@QoX_KQdmj@M_2rkiq1+_@J~Rus@a&m5S6;!I zaAQa+Pk4QFJO^YWnK)z?dA}m3*M_7QH8fs-(8CY( z%Hk6UbVo7|;Dp_n;HGn!*U6Uke8G~bL;%M?xB)AfOMbC_CmWc^-#o~X?`_<|8cBTP zb^{vFdV{&aCAo?Xd<`kA+QI%UaaVcR9P-nup2AM-qYL2#KvEVQ!i|k?X}I0-eWZ0$ zJzGPb+SDNXa}c~Y@v}|;D$$H$-Qx&G@^#w%9}~a)`0fzim;c=D0cBtq4F;?>Rf*=?Qw=PHMV>zk*dylMfta zvTH!zpSbThF3%Z7=>rq?wpiujgvnUT>F}bJdJ)@Df1#(`7_v*uG%58AyL?H7zu(!% zyWHXlLO+tJ!_gCD6}Bw`NGe@EUyrXp;IV6xO5PKocM3dRc{w$WkWg#>D$xpTU;riY zwgIjx28M!3rPmG7<_(@BKWyD5`vK@ECcbQ4C1)Qe1&4al_CyY?ZlARWB7k>!{WKc* zzrwk^gzz%L5*C3wUdAe%RZ>l7==8fD^FXArhYRGX6)ruzUiKzvte)sNY~n1(f$YJp zI{n~gMU%Vj8glW`ch}BK{6b6+kAu|n5(E~#yP`bYqz{u{@(-_t^^(kspo-}_-9-R2@^HY$$f=+%dE)PH6`1cYrfw2Z_*3& zqz^B&)h`F0J+E*}0F}YQw3U!R&nB$Fk~wKeu_&45+)}q}IgJrm$zJYNdFratlj^Gr z9O}I4yn4B&ZiRSp%S%FPn4UU_!J2v&Ph=N!Y?}6|Ij;@c@gA}WME`bP7gC2bm{EvN zh-Mce9nGQsrggI+T`xDM6I}uFY+~#%u(`98iBiYw^&8yES5*t0fMu2$B#*dCa%$BJ zmknCz+-%Ca@?~0^=@gaoO?ZxRYeG7b(UfC-YuT1Cu3a0_J;%JjEgpUBld+ivFFPMT zQ}T54U%v$;rZD{`*R}HSwjiMx(KPl#o?*i@VdF@a$jrZutz1gDaYXCz`dvQTBjV;L z7t4Fu(C0mjGdr79^m_XOUdRLStV$||y}Z}$4cN)vPF+}~-YI^++eVaW%5I%Ih?8n> z!0QhLUH%>q%}VIKpqNT5+}Xt&EqcL4H~;l8n0Pt53#mOgwq(Eu`*&dI#DK3q7)VNd zI*&i8^7w~fZ-pF+T4?YhJ z0G1~uuCsw8>rExMrvmu#no7h|1-*4jt{CLeo{*i|S|9(dw6Luhm{Z^jhclJ8xy&=zg_ZD6UTw z*9)et3DZ_V+mO&UOldPFwS{qQVYKAs@@wTcWpAn9P~WP4zxM6g+lM~r{BdWj?eJt< zXS}U5*5#V)^2ED5vEI{@y@T;^?_i7%2))5XZ&2tON^}hgZNrJSVPW@~#O^af-PuIl z*=dFev5nmCuy};~tyq9?Ki9_1AlJ^_7YpF(+fi0|07Za67av3#5NO>Yqyf`24#Siu ziyn#>J#uC79gxWg$y9j#ix-51X|}q8ZbPi&74-r zZ0eZOGOcAYoopwAu&R;92)+JQyX5dB6mh_s>_mc2~ZqglZ>yA0y zlTL5k>5ZN0pF9M`;sx7~eD&AUwi7io-W?M{zD#0N2QMAcqBT!)zyVSO9mD z?LrO<;JVmj$YB9oJKK#M7QpRcoxp`@32g0gtOEqP*o8D8(7F>y1Ae8WMC5mDrB}~_ z15IU`$ITJ>T|?$r=!Kmz+a|$Mov>8j4E^Lpto^{phmXcO{KDb>#NmFSy+7765Hk#h z1;b#%F!=c1(Vda3PZjEzra(~G5(-uDseIjgKKWVLsTdM?Xb;4RbX|0P}>jdl8gmvrfeKD;?(6%PDtubxuyxuOg zUTD2qEoci9+QOK&aLQmDZ@e%Z8J;p`jW=GAj}Kl^O&W{i#^UJkb=O~>i>=xQzgvxB zDHyjWjN4=McgL8uZc&TY>x0qOXsckXN*JqR^f#>ob3QjRG9}r6?4dCIDoemFevh6P z|AG_)8#;Wq)2fd_z@wBxS~`hp^nPLrtFn=`0UWfmJeF|n@*xy_6lqqx`@_=&=uKE~ z>FYo%WsZZa)m}&s)1)%k&-a|H@cZF@QPW+*dP!w>0s9_l>(1Xn56HEkO|QD>HHvuF z;r%-l;5AL+>(R355oAyE44EyzfUXW z6gQCHIX^4dyp}agGYg@MWu)53NL-R9Oh{Z|Hax9l^p;5X)g4jE&ElADv!L6Y&}|-d zOiA-&()_EA=p(V}-GbDSkUC;g#~njK%uo=OL=W5y#0+(Up)O&l8{L~y-Wx039IJU) zkTxZxO)+WH9lbdM|LuDHKuljL=qnTY%F#V3)jhG2>YLvdqp9wgb3~BtNJw|Y(!c2=V2e2>$)oS({{bjLSK0so delta 5804 zcmbVQdr+Ilwf{anKLL_3ddmoez<_Lw0l)DJu!HRcr`WNRnAj1rB(Q8D$tz*U1P5|& 
znkX(!h)-J6PTG<-b>o`2_K?1&GoAj?&h(M?N?grAzcv|9+Wz5xTHNH`_K!C0IV;4= zH0gBp>8JDUIlFsy&)KtQ^~1a3=YAp@zMY)8O zvC9#11|4S=B*E@Yj)?Dx1wru@BL>Ci zq%LJp@tn}5N(oPiU21Hnlnuw?dt=e_M?#6VNGxzRprroiiWYzA{R6WCUa_43mF6F! z14gwj=q}e44T^(`vs|3n;5}_!X%xpA6nut!^W_3kUH}YIL}NuUuMvi&C+myaL;hr5 z1KH4nZ$1`E_)d{VvL4BBu-_l?1yIrJDmb9?ROo!ZK*XO&_nEy3-1+4wOUq=V$4fWK~PDshM=0D7U8;ztpkt#;J%yqGAzuP zY*%+(*)_f@W2(%WDpS%-p)p(NniWLdrWxDnX`6S-=6$m^W80XuZA=}wV<>sK?{eQ* zYsOHPHI%{YrB$G}8sS!z9^ThGY}1B{DMQ70WyVmQHB_gRGscos0V{+~!&(ikC))(W z28**E$pp=+CAZXiVi6QUX1!&|pdJusg@1u=z1lP@d>pXCe0KpP=F7kB(;F8#b@9H~ zh3F@{^#qLCy5V=mzs-Af!5no`t7{1gxIv_cnF)1aW&#VsQRX1hh~QN)a$;s9+Q?QD z&6mqXlp8KMTrOSf`QFZ0EYcK|wf^(HaU^qAtw9-3={y#;7OCY<45v%VGkDesedWY7Ils{5q6yB)B>ylQD@>ifjTPp%-QSuQp-s6Z=dA08-MZc=wR8t+fK z6Q}+0P_ui$-rAE4F*k*X#ya<@C;|&b`jKPIapqPZEA$JZP55fdy)KP9jmRj+A%~W@ zkvrC(Tr{Xx$0%USYOcic&ASI#Ia{|7)r#O%vUNl^5^O|BkR>sAODWnP>w> z!1|sTyWnTRZPcI`?L92KE1A;f-M0?E{TTk>lacmhr9EkB&kPJ#)eUiW^P7o@i)nK+ zJ{f68R@#x4cFbTTh+K}0ZOje`BQ>7Tp(7uuaa%JVaFzR3XS}^7KaTYqIiGv?*87H0N!y6gcgVMnjPV;}JzkA$#Gx zggXwW<{_*~yq-HLea?D#$k~Qp6Wn%|i^Y&}{uX8`i{SanT6niew@I_e)!Ru6eom1I z*a3nD>P>}!o}G7XI8^G;93w^yBBhVQ|CV~HR=;3Gat8SFP8qBb$93~3%sm%w{=UCE#_ zp;?r|97w96d#xER+8s7UP&ue%il72B8f8#BAzf1o!VNW?vs=JzFTgI0Q0`s@>$EyW z)DFd~tZ=>F0JYkpp+V)c6iM*WS;EVp0(r&xH|`v|Ip-jm<%du^8hHKyl2xa|{)Bhc zCpZl#OH2y3AAawy^vMN5JSG$VzIY_W>SzGkN?vs?M$K2SBgEpagR|r|U7TyJ%CC04 z&liq{lRn=mluFcyAUZ4HrHV^X>Zvz5sQMYEQN!iXfm9OL+J;%QROpF3c9&0|{1J&p#nkghSqpV!ERixLf(u3qy9 ztv>lq8caYM)reOZe z9G$GFdtO}2=fKx-W#c`H9W#Oy8m~)y6xY3dlGQuJAzG&1FMff_&-Qo=7HuyQmsON- zFiQAjO-DEy@-ucCUwBF{>pDXIvlowqI*!Q&{y;K(-k%J|qAOcQ)+Le`BOzZf+?Tr# z~iUzyp8^rWqak?f?G}Mbj6eDEmR5-%*py;rQR4Jiqy+0D}jp7Z8MMEs% z)#S$LAy#2HikVY9MqZOso>8-Yge6-@Ljutx7g-+NQFg#w@(9g<8z8^@H^J4!%2$l@ zaT%w(~tf+5bjiB1nCjJQaZwZNJ zcxy|b;yCruMZmpGo$RMj#!n*(Wh@AdTX%|$5Z$_44OsDm$=h3htt`Amyk`iE@VOnk zONWVkmf*h-66EXobL%=C#u9t zkU3y2<_pE=&JQ6I{N=#<$ikG0R+^Hmii8u%+(=vre#Kp3o|r<c`eAujJN3j)r z`S2d~H<3_HzJIt%BR&J##{+pSN&|k4;0=Nm^hBR}iCCn7 zT)_V_@kRZ8A>1x^A1}}AVK}})eG4V3CfknN6`CiJ45;AQdb8#-y8BfD-g&|fzwj90 z^(U^t;cAOJh?m5mazb@YiJ1+4bL9jIU(#NZl6w3b$Z6Uqoc#PH-Jmq6bL1t!-))yP zcoWt2C>$`}ydGxvS3${1D>OW2s0d=Rbq)Ktf!|yag{93~Puj|Z+H08T-=Jilm*q0z zlQvZl2gMt%#dtMIiQQZ|qlee)t$0UTQfZ({izwtXDjJ+8Q(grado3NZ_I#LygW-f% z#DL!CkITwfPtSm~H_}F|BS~3G^q-1HV#x%&(|d4F z!6K6sMf>}F^BMS;Xy`@0%{%A`_sUBDQwdoNC+I>6&z!C;{3#N)7C}$I`}K8RM?3ao4j4hWDmS@9IlNcV_hNtlm8=xBp0&(r2Ux zv(kfU>A`$2O@)^WM<2*YPW(G8Ejj0^rzQK8WFIRVZyhTeE6Yf=S*bS7pLbPSn22s2 z;`{RFIG4NI5ZF6J2{!X?1-x%3LrMTO;6Y(B%Ka3uE_4qxi|8RlXCOSy?ei&u1Uc5*3A1*MP%_Mt2j5fWx{T|hS$e9|B v^>FswAImlwL?hmVmq3-EDoPDZDIMv_r`V8aN~`MdncX3%oN1+l=cNAxxqWJ# diff --git a/config.yaml b/config.yaml index 0d6e927..1deaa61 100644 --- a/config.yaml +++ b/config.yaml @@ -64,7 +64,7 @@ logging: output_dir: "./samples" visualize_every: 100 # Visualize latent tokens every 100 batches print_model_details: False - + log_every: 100 # Accelerator settings accelerator: mixed_precision: "fp16" # Options: "no", "fp16", "bf16" diff --git a/helper.py b/helper.py index f67207f..89de74d 100644 --- a/helper.py +++ b/helper.py @@ -10,8 +10,54 @@ import torch import matplotlib.pyplot as plt import os +import io +def log_grad_flow(named_parameters, step): + ave_grads = [] + layers = [] + for n, p in named_parameters: + if p.requires_grad and "bias" not in n: + layers.append(n) + ave_grads.append(p.grad.abs().mean().item()) + + # Create the matplotlib figure + plt.figure(figsize=(12, 6)) + plt.plot(ave_grads, alpha=0.3, color="b") + plt.hlines(0, 0, len(ave_grads), linewidth=1, color="k") + plt.xticks(range(0, 
len(ave_grads), 1), layers, rotation="vertical") + plt.xlim(left=0, right=len(ave_grads)) + plt.xlabel("Layers") + plt.ylabel("Average Gradient") + plt.title("Gradient Flow") + plt.grid(True) + plt.tight_layout() + + # Save the figure to a bytes buffer + buf = io.BytesIO() + plt.savefig(buf, format='png') + buf.seek(0) + + # Create a wandb.Image from the buffer + img = wandb.Image(buf) + + # Close the matplotlib figure to free up memory + plt.close() + + # Create a wandb.Table with the data + data = [[layer, grad] for layer, grad in zip(layers, ave_grads)] + table = wandb.Table(data=data, columns=["layer", "average_gradient"]) + + # Log the image and table + wandb.log({ + "gradient_flow_plot": img, + "gradient_flow_data": table, + "max_gradient": np.max(ave_grads), + "min_gradient": np.min(ave_grads), + "mean_gradient": np.mean(ave_grads), + "median_gradient": np.median(ave_grads) + }, step=step) + def count_model_params(model, trainable_only=False, verbose=False): """ Count the number of parameters in a PyTorch model. diff --git a/train.py b/train.py index 4eef807..0892e11 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,7 @@ from VideoDataset import VideoDataset from EMODataset import EMODataset,gpu_padded_collate from torchvision.utils import save_image -from helper import count_model_params,normalize,visualize_latent_token, add_gradient_hooks, sample_recon +from helper import log_grad_flow,count_model_params,normalize,visualize_latent_token, add_gradient_hooks, sample_recon from torch.optim import AdamW from omegaconf import OmegaConf import lpips @@ -308,7 +308,11 @@ def train(config, model, discriminator, train_dataloader, accelerator): "lr_d": optimizer_d.param_groups[0]['lr'] }) - + # Log gradient flow for generator and discriminator + if accelerator.is_main_process and batch_idx % config.logging.log_every == 0: + log_grad_flow(model.named_parameters()) + log_grad_flow(discriminator.named_parameters()) + progress_bar.close() From 0c99d8e40eaca1ad85624ae2048d09479ec5d809 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 04:20:20 +1000 Subject: [PATCH 013/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 12058 -> 14848 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 6b2a828829cb70a94c21afd372bce4b8b312ab6f..8e5dffec0b859989b126fcc941e2178037b22629 100644 GIT binary patch delta 3523 zcmaJ@e@q+K9lx`m&o-QGei&i{#wNcFfiNXm6OECi{0M|5T@q5MGz~kr!(2K0$h(6; zWNLR}8{OCny>)Xkx<>0l$wIYBXkDaDnl|l^O3NR1<|)>jqBhzzEz|rY9;J$wE`##^F@B6;*Zd^|4?hjn9G6o+1&we=;+_UWVvAg5$!Om6)Qw1%yx*CVmbW!L>{<)(FL?lbOWsy*$Yg#YWdS$r$;SR_tS^k zvBLX=w=v8Qm`(=3+aR^yYswca>dWf|rU zy!nS^;ugAqK+0nLw!F~@Rd`rrJlea<_{_h{_ZIb#4{hze7VXAC@UWI}3c#F52~)r!(Vk`>Q)dhv-}{TZORiz?*-zNtYM+ zy1y`D7^=tkWPjZW(NStTZWYVd`BL6Cg~gByO1$f?RY>xk;+9YWqcZGGi0)^m!BcQZ z^gI(+QN&d|6IWTpRhl^emi;WrXq8d9ueJi*+*go*?ehY2hm`}@zf$F`i*sVtdiD0W z;@r(zC<1SvX#boi2e;_*6~=L{6gOutal}s@fl3m+^NfC^AjB|nYur+5yPH{sor)ZjuF9f5j%f)J$FziXc~nrsEOB5}7j#*T0S+nBm`-e1 zj*KUC!o48ON?5yAaT4G{L<*h4f+$Nd{hF&V3iT@L)j?vNj72~dH>TpKppO=72W%mD ztKoe@(?%OioO#u0HWqfy5@+LriA(A6281o!6QaH8M;A!(-q^TZf`TNLnWh3XpaX)Iva#OuD=~|9^cVCEXlc zxVY54aHhz70vHC*bk(#yJgKPqcAhE2 zM}?W9Cc;PM*#Bp)s3dHedZLmjZ(%@2k)}(N;U?n3(v>M0OIpH|CY&%iDaAxglS-_@ zh(-|AMUbpZQP>Uzb`R=g3QGi;u5shE1TMlz>75HH|MKPUU0%C#*P2Eh8Pt(Q9WbDRr@u3`bnabis{gus#l6O6Jk428^IGrS 
zp)~5tpuQ~XgMn#bsp;KAsoCqDE1hZ7oI%Z5)C>bt$HGho`Lf8DMm}?%_B3VC&TInh zOrxEY)4!l*P-7M~rctA*Y+>+?v-4+@XJL{>-VCbEqS`d7h4{){i_LEySUiyIqQHSA zYpMs7P<5uZ9I}72?`wSv=iV5YA4m>Rfr`bJrO`}TUAC+)*_F>wzqy(i*yI?#hlPXt z;f0^q-t}i%yRxlaFwK;AXUn^j-4FT7#evk|+Nt!BQ|XqzG~b`$`?Gw1vWMy^oA1vx zHsu<3J@nPBaH-jhFP!y-ODF#Nl^r+z8GlRG-}10NuyX9?%bEIv+4_T^u6I&g2!!Y08Si;U|v_Xc({VJwZRG5lj#i*D#O-*B?;8~<&*blyw` zZF%?>C2f?nP-1##leHMMK&ud&#ndAL3^_K2dntnE0H30SdX2k)P#263gY8_5a;_ME z3BJJIGTK8s0;VsOQ8gy(DlWEb4Sxi{A8v>tm1a4+vxn`tO!aW4eSDj$`6&=m2FX#i zVMU&hF#YVHT@V5bvu3|8?7nevcZj`bOz+m(&8(ZZp9P!V{0gKrdWnagZkByqRVRkW zV))m<{Da}&)5Hdidu4uOaL@NI7K?=56A`%kBs@&^9Cfcd#EAZ#D1Sy$mb( zZAuJE=vNkTXj3Ef3exhKKV|%>`HNnBjIJM|q?3~8ff)N*F0uh*uBE}2gXyDu74imfl`So|_I?Oq~fg;2Gl|2aZ;lcFY|<0c`wj<8bRM7pYy`K?!a0=HW0)+^i^b z7tPF=abjxAJeuZtH4o)~9361FrWC0Y{}op0Y1Pgdr1hiv9OF!yf16wj3kSYPhFIh6 Jws}Oo{U4a_k^lez delta 1071 zcmY*YU1(fI6yDj*y}LZcEzT*iyNzAsca%6oOSW4W-2v#Xh9cWzw15TypQ+ z^~_C;RV2m5M==b73W}goghC!}>AT`X9~68ZU;5fUh~R@D2%a-V!GU|`oHOV5JLm4D zpI@)usZ=I%cz*o*eji%<)h4^Q|IDac72b$#ObdUM+c+!^i4saPq9!Kreo9QDwjs($ zo1%jBh+wbhg8Ke%C$eh!8?XL|VZ`rp0kdpZ6K~dwbLQvz!puAhniB5i9=!B`|MbdU z?vI?QPi@WFn&Pn(y&(s@QwiEuK)$sTg4DeMaA7`}`jt(Z-gMKWD2KB5rk~BTiYa%F zWIxQTGgdMajf?Y(_-4mq$(1eWROkZA?RFa`Nk6UldY-sVbdTJJ*4 z2hvQMDd_cO$~)0tQc5`aeNE%;Hudp zyF8MA=T_WGlI>n!!EIE2VID8cxxpX!gbOsy_nqlgdI|~-+$T{2RQ!MZQ(TDu8tPF@ z4I{L=Yif%d*wpsoCxsK6s2Jdf=ixbA;bK1=xY-wPh7gKSL&+TLuCl(_=stb&RWyUk zM6MvQg<+K1lHS=IMhQYjXhUMY>MqS%FAdWOHz<}ZTE{{gJ4BO(9* From 6aeba62d2618aa812851b7b37e1677f3136c4038 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 04:21:41 +1000 Subject: [PATCH 014/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index 1deaa61..2be640e 100644 --- a/config.yaml +++ b/config.yaml @@ -63,7 +63,7 @@ logging: sample_size: 1 # for images on wandb output_dir: "./samples" visualize_every: 100 # Visualize latent tokens every 100 batches - print_model_details: False + print_model_details: True log_every: 100 # Accelerator settings accelerator: From 471b1778f5a2988696bb473ff8d604a68aaf0c56 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 04:24:40 +1000 Subject: [PATCH 015/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- helper.py | 2 +- train.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/helper.py b/helper.py index 89de74d..f4f572d 100644 --- a/helper.py +++ b/helper.py @@ -13,7 +13,7 @@ import io -def log_grad_flow(named_parameters, step): +def log_grad_flow(named_parameters, step=1): ave_grads = [] layers = [] for n, p in named_parameters: diff --git a/train.py b/train.py index 0892e11..b9f7b30 100644 --- a/train.py +++ b/train.py @@ -310,8 +310,8 @@ def train(config, model, discriminator, train_dataloader, accelerator): # Log gradient flow for generator and discriminator if accelerator.is_main_process and batch_idx % config.logging.log_every == 0: - log_grad_flow(model.named_parameters()) - log_grad_flow(discriminator.named_parameters()) + log_grad_flow(model.named_parameters()),global_step + log_grad_flow(discriminator.named_parameters(),global_step) progress_bar.close() From 4c15186b9b55c715da7326545aa1427c55bbb9a4 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 04:34:19 +1000 Subject: [PATCH 016/152] ok --- __pycache__/helper.cpython-311.pyc | Bin 14848 -> 14855 bytes 
resblock.py | 4 ---- 2 files changed, 4 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 8e5dffec0b859989b126fcc941e2178037b22629..20ec559aad586a1879fc60363dfec46eae6ba0e8 100644 GIT binary patch delta 166 zcmZoDX)obj&dbZi00e0pwx=mh}8onnJJmYMXZ};O>3FB7{P{VGP${JUSiI~$QU+xxyAX( Hg_ivQUT7?w delta 146 zcmZoKX(-`c&dbZi00ffjx2MTZ Date: Sun, 11 Aug 2024 04:35:38 +1000 Subject: [PATCH 017/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- helper.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/helper.py b/helper.py index f4f572d..7660b4c 100644 --- a/helper.py +++ b/helper.py @@ -17,10 +17,15 @@ def log_grad_flow(named_parameters, step=1): ave_grads = [] layers = [] for n, p in named_parameters: - if p.requires_grad and "bias" not in n: + if p.requires_grad and "bias" not in n and p.grad is not None: layers.append(n) ave_grads.append(p.grad.abs().mean().item()) + if not ave_grads: # If no valid gradients were found + print("No valid gradients found for logging.") + return + + # Create the matplotlib figure plt.figure(figsize=(12, 6)) plt.plot(ave_grads, alpha=0.3, color="b") From c85754c10f81344ef90073d620ed33f55d1a223a Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 04:38:51 +1000 Subject: [PATCH 018/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 14855 -> 14986 bytes __pycache__/resblock.cpython-311.pyc | Bin 29886 -> 29620 bytes helper.py | 5 ++--- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 20ec559aad586a1879fc60363dfec46eae6ba0e8..1fd54fd33418e2a8c58d781177daed18269ac1ad 100644 GIT binary patch delta 1188 zcmZ8fTWAwO6rI^uHf^`9X_G!;ZK?^{il((D^+C|qM=cc*3V!&YrcN5R#5C@1i`7k` zDn1c}`{9pDDH1;#L0!c!`yt|&A1Ns4f+z@rpdg}tiZcmYO)uOtvvZbfg_CtFt~p z{b#G3cAQmBqB%F)+yv>->{%GH;pU$Paec#TVG+#8Vr1LQ@3zO zdmH21qN=OuQLUdbH!#@8V1U7HnCf_jXRV&~Pk1~G-!>k>i!czL!v**nUZf@qokRG8 z<>@*rcsr@Umi38|OlsWFOZt5<7wM_p#;i6o*uh{Ie2Glq&(?|PH7pgo_QKDei4o02%?uLVU zvPJp50>!PZw^Rb%S#_D+bA@Ft3P!wkpL%zm2G<-7prW;jg4m#L8L+dK;?wqkS8uQd zSj&ges&X+S>ng5R`6>*VmS)Qe?G1Ucod3WORdzUrw}}N7q{=+QATta~zB@$kzYOz2L^dEB#M9M2S7XIVOB_waQs$C;7N7wTB;Vmnoel()Rl29MlCs(Ovb0F z#vLt3Y&2zwQ}NieB_{Ovl=;h92ks&EshDNz6o%aGIP5+y{FcqP?)M4|<9ulbPM}g& zr&v-_OHAw7QuY~oOxY)jFm8GRa|#UNLQPwMG6}7N+sN_jI`{13ohW{(*#pNALhCSP zehmF^F}<~>##+v|q(mw)J*M?@=3Wj190oZY!M{zf;j#Iu`Go^6V=_Dp>$o0Xf)6;; z@`;H=ThD-HK5V@y`uS}wXDU6KoSe~9x%mOCj~pmI$fepi^l~_ilaU+n&D@CG2YJ)A e53hBOc({Pp&0)y=-su5fGnD=$`EUW@(EL9qY}FC~ diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index a91bc991b41f9a90db8524492d624994e27d249c..9283553816254d62699993a4b425c861e082a079 100644 GIT binary patch delta 2458 zcmaKueN0nV6u`M{--EVLkPq7`Xv<(&p)Dd(L{LT%VVfU_iGsM=(Fat@Yq_uZ(QTob z%hb&&a@b;wnRA&Moo;h)$rh8je`fZPn8kglYqI?@muzNu@m zpL{F(eFB%}xLDEy)pk0V;ZUYG|tS^74ImR6$1gILZWX&G!b8`Yzxjv;1g4(}4ZVR7e1 zL23$xdj?#+VYZw$MsCps(pP-4G~wdtu7$TaODMwqrU_SGt<4m*b`5 zq-0hmW^}Y5tcY!hYA{sJhmJBU*)I>4{iU7T&9n*quqX`;3qnj#y}BS)?&Q_hfD7f$ zHHsG(5C+2mrB4#XC*djMwoN`WTE@IK92gGyqe1}trYex^n8UKY;s{Ay)n?2Nv5!^N zMK(jcv_+;aD%1fyHZ|#rL4!^h7ho4a24=ie0Qq z*I)t(gW4&+KyXma;{}{G2BU5HJfJuZ=h_x=Cz&`Q-)P&;Cl9M(Lj|v91mO(>Yi8B` zCQen3T)9?W#bJt65;|e0Ka@xp!@fT156}*J`VKCA3UT@&Nw9VpMU%Q8&*G=TC;A5k>Vex^ 
z7A99Wb-Fp3xjiN~by_sUCvV;QkI_86rXeT=MJehRdjv{B?5oJZFCBFvln+M>oa-B0 zG`&)$I6UZ^y6sP(pHlPW0j%nGljRWZ_i)#kX{G#be*-`A9TRWYqU#SRrLm6_d+!TS z@qJROn3CG-=vIiP8;E=y+88!D2*qEAhJMX{C@#-b%3+yQd8ukvRFksa$8_r)cG?Az za5d?NQ{jcre9Cl`L*W0I1h3qo0?k5=Qg}H^%H_WqCFM53m!INmN{=D1<0xwBsM!A5QJVjvX2p`l~X-FQ%>Cb7O!aG(Wz3&e^kP zzjOBNZZ6&@S00e`p|mt@68qi#_)<@J!G-j9)rznJ_Lp};!yK)o;v@o>4Kqk45HkZ^x3oEM5;zZE0}&(MXTUk>O!uNDdKGyN&Ny&rL~*u znHo!TxjX^4D0+lIz~wqiXf5c{D_YXk%gDo||3zM}Ha<+69$Q*r-=Xhg7sz94IcPVIb{G#%u@QcE#eDDl|` z6j?|jmg3py4hDsQNNX@oEnLkhC(la5Idc@P&!D9kp#;H((1tLdfrwkpNGDh`t>3hm zbxo%rHe=6J^n~b(Zt7hLYx56SrsrRZ5fR!XeZg)eHxKn{c&EspCzmO%qp=g=G0i-_ zG+AJ}=)k7Tmwfzb1?j?$vrWg&FqzqQSUvk>(Ue8)z@-qpSX8LAEFvAiXWOJf%L~LZ ziI*9ja}aV7)*;j$vggCob91?UOcT4Xm!v&&Zz;3aGCih=2E|QLLFgCa-fJLiE9c@} z0Uz1Q(_}BTSLh3QsT1zl3P?SSG6cz7dXP-bxf|C-*ei4Xxoi>XM$Ry3DUddn=aOn< z=6mlL{;oAywX^_d6egy{RGv_9Go8Wo7}w?x3T{e!;bO%s-GqGM*NR*&k)fTOlPc}I ziQA8JbpBvOpy5!!%|5*tHHxr>ft8(olpPFW96AIevddIml({^zH6VCf7@0&5T>xVh zojPj~<071uv;MSd5fPUq!>wB}_K#w;O z-a^=mFaS_rsy@cV>(bTw#<8$mQ$_8Vt2Kpk==@)%)Q7hH2nP_}Mi>BpQ?dFu6Jyem zCWq3zi;4aEGF~nmFK%>)1t-sV|JE!vO^Sz&H6bX*vvTo`I)p=85pbL0I|y*J#il;N z^!w76Eku>Gi;4Y)GQLgT!wGBs(%P_dJeky~l8LDTelZeXOyex)$L_moFNG9 zy{(f9^8rqkdqKWt@Yqg8Ev6DSy92T0V94d8ZZGWub4PWVTspaGGOw&LUZH3t9F4^3 zCD?85SgStI!Y50&J1iV|8U9}3(T$a^KM&5g=dYSvK-RLlV#FQv z2vh`bPnp4vLCX>3XQvegd-_VoS4s^=`&?r$${_m5Rf%sI)cWki34(9F`g=CiEPd!} z Date: Sun, 11 Aug 2024 04:46:48 +1000 Subject: [PATCH 019/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 14986 -> 15083 bytes helper.py | 62 +++++++++++++++++++---------- train.py | 2 +- 3 files changed, 43 insertions(+), 21 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 1fd54fd33418e2a8c58d781177daed18269ac1ad..fc48ac018e58e323340d0db39c77ac7034ca3eb8 100644 GIT binary patch delta 1426 zcmY*YT}&KR6rQ`w%(4p$bYWrnAE7@iAiI>Gf+f&uOQrH)R1$;6abX4+7-s01Sy~85 zt3e+OF@Xaysi|6RYGX}&*asSaV&hXE8jFb!b`sGSdGQ4w^~rPZ1`=*&zI*R?&OP`1 z-4C*VXQN+6qY;V5gWUGi?s~ME?W~_#8-qG(1DU9%{Jkna;s7y#%FWMfFHI@+qi1#-VBuDKgvy3gKHs+b>qLG4pxUPgKSWn0GP>}Tcc&ZhHAheBp)!D&C5j$|#Z0hQ;M7OXIWU-!%wd zVv|nR_Oy9NH!N)-1DSaj7D@AbPN`@QFcsI+)w$8iON=SFS@lx;GPlw};0Ik#cZ0N| z=<;!eZ3pHvmbq-eb>io+DDfe=d^1s}ur8QTwL>?PoL`$U-b6uihac$6^0;WCGd1j!OG zG1!(?9L@mtw~bbFyKV`=VU@*#>hVcimAp-YPjlb4Ih4y>YsW=4@HR0?McfZx(Hia( z_z}v1<7$>UXTaB-kS8q)EXja7e^}!CDV6QumEY03RdHoy{g_w|*Dt3|;@udU%q-y0ft1qKQH;L9VU;+czgD&B_Z zZqaJM%PZ0*k|1ssMFEuyxQG804$$MuTE-yk(U!kecLv-0e{A_pitil#_6t%hwl5up N31rW`SD@0l{{cNzYYYGY delta 1353 zcmY*ZOKcle6!o(^;}3TJN$teJ&IeVKl*Iky(~?lyCTXIupsheC1so?cai$s1w0CTp zKnY#6EC5l$RX1S+RBeGS5E$_br~-l11q&)6pd+Lrgg}T*BqStIQMvb-M)0I}=K0+F z?*I9>(zB)5rC2QDqxHu-@6Ua<5=)5BSGGJqT&*jo#Z-e{qX+5RsE1DbrkeC>Jxsqg z>$UnF^o{FLViS6duvxDnY|+JGpH{bWvm^PdAe*y(w*2p6F9F)W-&YP6DI z%zAP?6)aZYKKpz}n=z-fiUE9Dxkn^%P3gstlt^vxWVkr|AFUJ~R1;zt|5WZG990=p zYA2pm&on8d=ZOW)?}aVXHcps&aUQC$E4V5W_(jkXeYic;BU&&QI#<{3#lSdTGQn`O zdC2LI#8CKR&{HM~xx9fNhW85%{|;~2)y7tpVs6GVTuv#yv7nVjG0VwI3)wl#IRS@> ztD)@zJ{oydgvz%f4I=d7?6pYq*O6xIh<+ea7V2kg6s;vFLJL($I>lbvqiT z*#J8`<-5A~YVOOCQ$9%Hk9DRxNhEaP!TPl5MW|mBt=Q8LS4u?hD^E6T^^^6*rfDG* zyuNvI>k_NqP2dIGqV94#DEvy`VcUU1&N81cV3~LeABev|w{OIgs#wOGcRofoj->L9 zi+nw3J2FXclgk$-@Y~dmcrQtypT(-NSTt-M_A;{%+tVE_W6VbnmF%KfFh&6!(5hgb z{}}JqHI^gOi4Fh6#od0``yN-5J$p<~S~Q-`~e}-712OuF6BDLbV0QHj56( z%uT~Cy6_&UzEZx`?Qo%;o>4KBB}P+aMDPZk;Y|WBLPc`itYt13aG42Z)}mB30i4c8 zv)6mm;u_Xw+)RqE%4zU}jBkzg_&we3S8(2O7P50Td`#q5_->|6r0|=}TZ5IEAd*Kq 
zZ@NVTvaBl=wLqolm7Lk;#t`5q+v3U;H9D#kw0t35}4gM=MRB-f?W!wut5yvmkeF&f2b){)t&4{{n-~U0! LncZhL@$G*A=ATf{ diff --git a/helper.py b/helper.py index 2d719fa..9256e9d 100644 --- a/helper.py +++ b/helper.py @@ -13,28 +13,32 @@ import io from PIL import Image -def log_grad_flow(named_parameters, step=1): - ave_grads = [] +def log_grad_flow(named_parameters,_global_step): + # global _global_step + # _global_step += 1 + + grads = [] layers = [] for n, p in named_parameters: if p.requires_grad and "bias" not in n and p.grad is not None: layers.append(n) - ave_grads.append(p.grad.abs().mean().item()) + grads.append(p.grad.abs().mean().item()) - if not ave_grads: # If no valid gradients were found + if not grads: print("No valid gradients found for logging.") return + # Normalize gradients + max_grad = max(grads) + normalized_grads = [g / max_grad for g in grads] + # Create the matplotlib figure plt.figure(figsize=(12, 6)) - plt.plot(ave_grads, alpha=0.3, color="b") - plt.hlines(0, 0, len(ave_grads), linewidth=1, color="k") - plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical") - plt.xlim(left=0, right=len(ave_grads)) + plt.bar(range(len(grads)), normalized_grads, alpha=0.5) + plt.xticks(range(len(grads)), layers, rotation="vertical") plt.xlabel("Layers") - plt.ylabel("Average Gradient") - plt.title("Gradient Flow") - plt.grid(True) + plt.ylabel("Normalized Gradient Magnitude") + plt.title(f"Normalized Gradient Flow (Step {_global_step})") plt.tight_layout() # Save the figure to a bytes buffer @@ -45,23 +49,41 @@ def log_grad_flow(named_parameters, step=1): # Create a wandb.Image from the buffer img = wandb.Image(Image.open(buf)) - # Close the matplotlib figure to free up memory plt.close() # Create a wandb.Table with the data - data = [[layer, grad] for layer, grad in zip(layers, ave_grads)] - table = wandb.Table(data=data, columns=["layer", "average_gradient"]) + data = [[layer, grad, norm_grad] for layer, grad, norm_grad in zip(layers, grads, normalized_grads)] + table = wandb.Table(data=data, columns=["layer", "average_gradient", "normalized_gradient"]) - # Log the image and table + # Calculate statistics + stats = { + "max_gradient": max_grad, + "min_gradient": min(grads), + "mean_gradient": np.mean(grads), + "median_gradient": np.median(grads), + "gradient_variance": np.var(grads), + } + + # Log everything wandb.log({ "gradient_flow_plot": img, "gradient_flow_data": table, - "max_gradient": np.max(ave_grads), - "min_gradient": np.min(ave_grads), - "mean_gradient": np.mean(ave_grads), - "median_gradient": np.median(ave_grads) - }, step=step) + **stats, + "gradient_issues": check_gradient_issues(grads, layers) + }, step=_global_step) +def check_gradient_issues(grads, layers): + issues = [] + mean_grad = np.mean(grads) + std_grad = np.std(grads) + + for layer, grad in zip(layers, grads): + if grad > mean_grad + 3 * std_grad: + issues.append(f"Potential exploding gradient in {layer}: {grad:.2e}") + elif grad < mean_grad - 3 * std_grad: + issues.append(f"Potential vanishing gradient in {layer}: {grad:.2e}") + + return "\n".join(issues) if issues else "No significant gradient issues detected" def count_model_params(model, trainable_only=False, verbose=False): """ Count the number of parameters in a PyTorch model. 
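A minimal standalone sketch of the two ideas the reworked log_grad_flow above relies on: scaling per-layer gradient magnitudes by their maximum before plotting, and flagging layers more than three standard deviations from the mean. The layer names and magnitudes below are invented for illustration and are not taken from any commit in this series.

import numpy as np

layers = [f"block{i}.weight" for i in range(19)] + ["proj.weight"]
grads = [1e-3] * 19 + [5e-1]  # one unusually large gradient magnitude

max_grad = max(grads)
normalized_grads = [g / max_grad for g in grads]  # values in (0, 1], as drawn in the bar chart

mean_grad, std_grad = np.mean(grads), np.std(grads)
for layer, grad in zip(layers, grads):
    if grad > mean_grad + 3 * std_grad:
        print(f"Potential exploding gradient in {layer}: {grad:.2e}")
    elif grad < mean_grad - 3 * std_grad:
        print(f"Potential vanishing gradient in {layer}: {grad:.2e}")
# With this toy data only 'proj.weight' is flagged; the vanishing branch can only
# trigger once the spread of magnitudes is tight enough that mean - 3*std is positive.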
diff --git a/train.py b/train.py index b9f7b30..684b94a 100644 --- a/train.py +++ b/train.py @@ -310,7 +310,7 @@ def train(config, model, discriminator, train_dataloader, accelerator): # Log gradient flow for generator and discriminator if accelerator.is_main_process and batch_idx % config.logging.log_every == 0: - log_grad_flow(model.named_parameters()),global_step + log_grad_flow(model.named_parameters(),global_step), log_grad_flow(discriminator.named_parameters(),global_step) From c3c52784fe4584c6749e103317d6dc34ce1274d4 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 04:54:20 +1000 Subject: [PATCH 020/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 15083 -> 16236 bytes helper.py | 13 +++++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index fc48ac018e58e323340d0db39c77ac7034ca3eb8..00619877994d7ed8a91251d60e6ef5da15133035 100644 GIT binary patch delta 3790 zcma)9drVu`89&$eVH;n+0Arh{lQc;Ult*|ZBw0w>kdWrFq*>@nnh}S4F&B)@&b^S7 z*~OJm3n^_uC(%etSGuWbs$?n~?OJK7)M=WONu6Y8?uxF)+C91|+Ww$QD|KC^?K{_) zz#iJH%kQ51o$s9Uo!|GJ?>qcu=&z4EUbWlv5d1E``Bu*l?wxk{n9TInnSJ~i>MW83 zK(FMMEMrWkPqM<7U&@zk@Kr1oNOpLaNDh#eN>0Es$pu(0F(*;5aQdt2(ry!-dF5w; zHw%Av4)6#)B`zFfb~5A;>k3>H$Hg8-V#h3T(|&XnRIaZKK`0?;Zt}9Jkd?R_?{vt~yram_b@jjEbnS{BG+Y)BYom&c_zQSBP zZk8N3oL7O1>#KY)C~ltfkl%A2!g2YV?#t=IxYccGwQHn4lHTm28+0Ws>z46@3I2Zf z7}A>Vs>j*5DK2EUl#~7&wN`EIvL!jaxHXKV{O>b23f7yP$eR$g7KF$N!O!d>ZNg^2 z4|xyi6kPUQxzVtJ3bK#&%Z0Oi!nSUOKu#xaiRVe4SySBh3L_P159H(sfz|B#RTbtH zCeGaOu9q@MLJ(v=@KA@osekkR>2WUx1+TZUzW8ivV*9y*nVY1FMF2@YVEqyRKlD6%~Y8&5C^Q-X~%)DK~HdY@@hJeg%GOBaXf z@zePTwN>XRU%E94t7NI85;<@0OGK0+OZuP>J|VX(bk)N(!thio_TIw89z0( ze9#Jo#^JfVm1C|noCqJ_6_ULZ9IOV2mIE0l@v|A3E$b=5r#{I2_8QL6?O>iUNmmcV3jQW&xv2Wf5F8NP%NAsJCfgPP z4lKF?%ZTNh44B}?neqBXn>g9Jj7-+1nc-_>GA}Ude(uP{BQx&VqL+%FFTPSbTl#^& z@|M3c?XS-Gt3UA9-}2X|{S6s^Lt1Rih>a<+@r#daE|`LME?Mj;M^oC;oUt^gEX|7+ z$GGshfLsMrO%o%NBgyu~+)0BIc1;FIdmA$#;|f8Iw~ReFb)E<|*TI9&v|ZjbB~Mpg zs!WSz8L=!?Q9G}uxLs*(SBBdKjU~=J_TV!^mru@kraLcnrk#O|GcdPherJl?ljinh zxINGy&)U3OYSLV7hO13+wRFs(l%sWC1js&VZf}O$o8tD8_iT@heoPgfI6QVZd3aes zqQXGr{-WXC;hRILj-K=b;miYJkfd$B8C!31|B~pQJUsJ_R8`l?Bc4r*XEWm2WE+s~ z->~E>nQ6O{H=DQQFPZhud9IYtmM@i*&9=;)N|#h+N~)I1DrP(8+OC|LJ+o{R?84|W z0!+4jo{#vvu|tdH8y3qqE%UrnSVj~C1U#E*%Yv-M@0Zlgd)^k(C9Ro~)@AFeRvxnZ zmu);>{3!zXi~`n)yT$(o$CQ{)oBEPC}xe~0lCJ*T%9gji;CisL97vikM#JLg3P9rd)pJo$;%i^Xvjuy(Vnc_UuUKR>G z_X77KcNLbqzPcoXCTrHs0Ew#KSQnS(X4uOAj~O10o3KX`NXlMBZziKrMlwU0u^hWi zi^53Ycua$Yql6*>`NCi%CMnTyAeU_dN;L4LX=h-Bt=T58RUHmR6}9(YRdptBWb3h5 zKvf{I_dqp(B!71*BXI>rB}f(Bnk)qc+)lZu1JGGjlW-gTtf8M=UrdPx&G_q7a-2$p zY}V1O19B*8oO#^>V~iH2#SE7NsMIklNg|uCkI*g(hiUiw2#rF*HZsDJEB}&thM#s^ zawHEeIo*@PQ}RS&GLdZm*jqf)JlB-=1~cB^GUE7$m@H1%#`$q=(dL{A&V=Ufy%I_1 zSKhK!rfijqV*YbSFCI-5AAkE406gzX**9QXJdqJkq>N|r5xaZp&>S<@{<4@ZtiENh zPT8yfztLZO6xfooof>a^uH|COR7cuYnz5Cp;Hf^9;!a`b!f(Y_dGBo+J=wy>=*$I` zenRac&*wLIhA0QL*d@`;eK$kz?N1`?B!9?PgrgwD_mL9?0cIP)1wUo%q^8gZ_hp10 zUN~G>ZDO1Y7mCg@f=m@V$?v=es-L1w>OFi0VC9ODEtA>$Y1B9TG;}^ej{BY!XbHsf z!iT<6fq9yIRsJnTCZ{XHo+m-Bo1?L;TRD2ezz4|(6%AmEEwHnQx=t4#8iFVtklQeh z;h+PbC#M5*K3W)bQ4I|YM&vFmcgLdmbr3zZP_W@iD>F}?udLwlX^@b2Dq9{hc1hy| z|CmBO1sZSIuz*+8C@p-U=wo;rNOce>(Vn|eiW{izK?*cf4I|pC>VyQ9B+27dN1Euo zD!n7*DpmHKEm!V=7~FSa*w`m6r7He8sJ^-IkE$3YI9Aif?2dv2js~>||AfL@0D2x! 
z5>vY(O23TXrUJe@0-Fvp*|5(!dAT;g{F%7x)H;J}<2Go*hJ80-lu9p6D>oV4+#8Ga zclAW^FG2V*d82Lv<0QYSdwoYX7U&{ywFoOv%5fLfH5B0ra~`^~b0(p6*YD(=dLRED zh(05FeJAXvY1?JiKToykzYRDaKsOI7SdkPJ%jA=$dI$DXA+IUnUJbuN?rGk5&-bYA zOBCiPGy=d;A36uQ2?#UTZyXJ9 z1eoDCgAh!2jOo(Gp)(4M;qEUjy9Z*@P(@LeL3pT12^H>Li0b?UKycTR4Vp_@6H6EB@nK10&JA(~$ zT}X?ft>i@5H&Jm!6jzN1DO9b5)KDdrBK1$CRO+s?lF770mikbwRt=R~sY?I!+*z+{ zu9$PreD|Ak&OP_sv;6&gf1C5X=XN_N7`r%#c=5r^9%aw5NN-2nC$gNN(XwT_v5Qfg-f)H zj&0`#7~A3qDd;L1(?dm#d{CfQiuQuN;!>IsUEA*e`@ma_gLy)>ZF@E#^eqdqi(U10 zi|93@`!D05eWrGt1+avu&eEbuvD91|X8k^z(ql!KqFye;KR#9d{7qi>J_D8)^)xRY z*?z9zv53)Bn!o4A=i17;Y#WfvEs{@jfaw!O9ZUy{y53yW0IewMX&(QB4STJJJcmd) zl{6)mtx$SI1T3()%hkG{#cQyS4~B#5e45^_IP3<(rl86~_;tQ43pvPktgA zY}X~E%M*eUqYWEUbwQWa6!2I@n$!&r$;m0*;Li&S64I6yedmBKB&BE{5=2=_>5J~7 z7VTBk8zTlYn@WNyeo{pUwElPHY; zV6S~~R)H))y62E&u(OH^Z&M_bn!zYi%HZcAhtnG5OGqVwF|Wuo04l<`q!|1{VdivM zS70%TWD%cp3@<5a(y*jv4Q57|2O)!<8A+6dRLo)cQUd5>vjT!yhvcI=622Y6H{;R_ z-pq?;&?$ok8{>161|!cTQ4Qg7x-hOt8tEjEttSJKA~oFL)--{KoP6sDKt6c1ML39M zYR^}?oY-@FmwWeo(F2O%+D-98HbZAl>^WVxx(+Cdy&b>iJi1i&z*TwQRk=32{;hXL z-x|GpdHwQ(aO3@OV=f%ehvN^zE%(DMx$voc_*Bl-mUp#fU2Xq)=zI<=aNd2JJL_r5 z*;@0q)~u~{&*sVS|K=&qwlc7KBg`wm-p6f z_U?{m+5Q~cpJ)3)F+HqSzummHaQE!`*(_U^W9#y49Vn)a%zTaw)$4Vdp)Z*`8X z$=9$oS+<5CiHw?KBY8HGWh17s%*ZQ4D?`ge_&2UgONCt}UKv~&Tpm2&DOV2-h556K zzbM}g<{DneH@pC^IcIm?*}dGo?<%=9xHhucmu>0G9`Db(5;<2Q?@BE95HpUI#9pL& zFLHE0P_fRhE#v~Rd?0rCB^X-w-wWn~$MeDC`=Ri9Tl>9BxlmI+)O6ssdiev22oL2h zIq+J!(tlDQ{zU{VJS$Lir?2Ea&TqDk!gZt#dR`+JI*KBg^Ea1`k-8xzwoH9;?i1m+ z-OCh;;tPJ6KMx=}jz988={Vl^e;6?D@1su{}jPi~{WsEdgU1e^~iCnD|DxlBbuja`Xz^ z?oW><%NEGz)|6_NJX~_}!2s?%)(nw-_gF`HA6a+31cAvjQV&8Zig}QXKRtFYUICcheJgNP;~>WMG|;AAZ& zqzpz#Euc7{1^{wu67>@X*&#|0F+fBQ2s4IXHN|}X$fr;fSXWY z&oLU2?L0kk;bfG~;_si-;!%JLC1F8kd>;`W^bd%v5RG+8Rj0=$Q|Q+K{tnl+RMS!1 z*YfkWLSX<%+H6vWmW##+*CfUmB*Rj{CIYS4?V;WpTLaw30DXZ6TSs9(8?A3L;dcQ@ z!UN-B7jnX6Sg*^_US$pOc%nl12y^#Efp-qfCYL49j62e5hK?M2V!mw#`%IuMuGG{@=4{g diff --git a/helper.py b/helper.py index 9256e9d..45ee750 100644 --- a/helper.py +++ b/helper.py @@ -64,14 +64,19 @@ def log_grad_flow(named_parameters,_global_step): "gradient_variance": np.var(grads), } + # Check for gradient issues + issues = check_gradient_issues(grads, layers) + # Log everything wandb.log({ "gradient_flow_plot": img, "gradient_flow_data": table, **stats, - "gradient_issues": check_gradient_issues(grads, layers) }, step=_global_step) + # Log gradient issues separately + wandb.log({"gradient_issues": wandb.Html(issues)}, step=_global_step) + def check_gradient_issues(grads, layers): issues = [] mean_grad = np.mean(grads) @@ -83,7 +88,11 @@ def check_gradient_issues(grads, layers): elif grad < mean_grad - 3 * std_grad: issues.append(f"Potential vanishing gradient in {layer}: {grad:.2e}") - return "\n".join(issues) if issues else "No significant gradient issues detected" + if issues: + return "
".join(issues) + else: + return "No significant gradient issues detected" + def count_model_params(model, trainable_only=False, verbose=False): """ Count the number of parameters in a PyTorch model. From af949d4b5e05c7d579c0164547815e5c39f3b187 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 05:18:02 +1000 Subject: [PATCH 021/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 16236 -> 16358 bytes helper.py | 138 ++++++++++++++++++++++++++--- train.py | 56 ++++++------ 3 files changed, 159 insertions(+), 35 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 00619877994d7ed8a91251d60e6ef5da15133035..1823dfb37079e609fd787aa7c4994b194b16c757 100644 GIT binary patch delta 723 zcmaD;_pF|8IWI340}zC6-=5~Cxsgwsk)IvNWd`EU(^wf8+Ics7FnTjG^LFrVp2oz? z%*CeZ#Fap%Nr8wOo@Iu(SfxHP0EsUk0;q&l>MI{u=LLR+ zrK(G`R)k$tGP$B;a)HluhtCBT&xhelq(7Z5s^S*vaWz3zc+{nB0zG*Hzwx_?2%{Z{lLJ?YI%cSw7|E7*_IWI340}xEzx;<^N#zsDE#>w4$oNR8a3=Hi&n>`r48JT%Hcs5UCVrFLM zDY2ZKz$`WS4lD2EJ?5N~&#)$|>Vh;wK$ZfGRl`xkl_ENasfK$gV~SXb^JD`y8IDZG zrHnO9<6aGg6bY}PL|?M;VR+*T3O5sBqmSfR^o~PG6X;(3pelOc4vxZm9cd4X7O4^-^qp& z{}{tJ8%R!KW{jGASSElmcCv`90%PN5Q`ukajNX$CRL?NBZ~m*goRcwn@=|>pM&Hek z^w}91w{CuAAjrb_Z1NwIMT}1;&o{jYG_24pnDNWzqh@}rOurd6i`uX-v4KLZ$Y-;> doh+l0B_q(AaH5Qnn=y>>0|OE%F*(k@0{|F(q&xrs diff --git a/helper.py b/helper.py index 45ee750..ffd8a99 100644 --- a/helper.py +++ b/helper.py @@ -12,10 +12,125 @@ import os import io from PIL import Image +from mpl_toolkits.mplot3d import Axes3D + +def plot_loss_landscape(model, loss_fns, dataloader, num_points=20, alpha=1.0): + # Store original parameters + original_params = [p.clone() for p in model.parameters()] + + # Calculate two random directions + direction1 = [torch.randn_like(p) for p in model.parameters()] + direction2 = [torch.randn_like(p) for p in model.parameters()] + + # Normalize directions + norm1 = torch.sqrt(sum(torch.sum(d**2) for d in direction1)) + norm2 = torch.sqrt(sum(torch.sum(d**2) for d in direction2)) + direction1 = [d / norm1 for d in direction1] + direction2 = [d / norm2 for d in direction2] + + # Create grid + x = np.linspace(-alpha, alpha, num_points) + y = np.linspace(-alpha, alpha, num_points) + X, Y = np.meshgrid(x, y) + + # Calculate loss for each point and each loss function + Z = {f'loss_{i}': np.zeros_like(X) for i in range(len(loss_fns))} + Z['total_loss'] = np.zeros_like(X) + + for i in range(num_points): + for j in range(num_points): + # Update model parameters + for p, d1, d2 in zip(model.parameters(), direction1, direction2): + p.data = p.data + X[i,j] * d1 + Y[i,j] * d2 + + # Calculate loss for each loss function + total_loss = 0 + num_batches = 0 + for batch in dataloader: + inputs, targets = batch + outputs = model(inputs) + for k, loss_fn in enumerate(loss_fns): + loss = loss_fn(outputs, targets) + Z[f'loss_{k}'][i,j] += loss.item() + total_loss += loss.item() + num_batches += 1 + + # Average the losses + for k in range(len(loss_fns)): + Z[f'loss_{k}'][i,j] /= num_batches + Z['total_loss'][i,j] = total_loss / num_batches + + # Reset model parameters + for p, orig_p in zip(model.parameters(), original_params): + p.data = orig_p.clone() + + # Plot the loss landscapes + figs = [] + for loss_key in Z.keys(): + fig = plt.figure(figsize=(10, 8)) + ax = fig.add_subplot(111, projection='3d') + surf = ax.plot_surface(X, Y, Z[loss_key], cmap='viridis') + ax.set_xlabel('Direction 1') + 
ax.set_ylabel('Direction 2') + ax.set_zlabel('Loss') + ax.set_title(f'Loss Landscape - {loss_key}') + fig.colorbar(surf) + figs.append(fig) + + # Save the plots to buffers + bufs = [] + for fig in figs: + buf = io.BytesIO() + plt.savefig(buf, format='png') + buf.seek(0) + bufs.append(buf) + plt.close(fig) + + return bufs + +def log_loss_landscape(model, loss_fns, dataloader, step): + # Generate the loss landscape plots + bufs = plot_loss_landscape(model, loss_fns, dataloader) + + # Log the plots to wandb + log_dict = { + f"loss_landscape_{i}": wandb.Image(buf, caption=f"Loss Landscape - Loss {i}") + for i, buf in enumerate(bufs[:-1]) + } + log_dict["loss_landscape_total"] = wandb.Image(bufs[-1], caption="Loss Landscape - Total Loss") + log_dict["step"] = step + + wandb.log(log_dict) + + +# Usage example: +# model = YourModel() +# loss_fns = [nn.MSELoss(), nn.CrossEntropyLoss(), YourCustomLoss()] +# dataloader = YourDataLoader() +# step = current_training_step +# log_loss_landscape(model, loss_fns, dataloader, step) -def log_grad_flow(named_parameters,_global_step): - # global _global_step - # _global_step += 1 + +import torch +import torch.nn as nn +import torch.optim as optim +from collections import defaultdict +import numpy as np +import wandb +import os +from torchvision.utils import save_image +import torch.nn.functional as F +import torch +import matplotlib.pyplot as plt +import os +import io +from PIL import Image + +# Global variable to store the persistent table +gradient_flow_table = None + +def log_grad_flow(named_parameters, _global_step): + global gradient_flow_table grads = [] layers = [] @@ -51,9 +166,12 @@ def log_grad_flow(named_parameters,_global_step): plt.close() - # Create a wandb.Table with the data - data = [[layer, grad, norm_grad] for layer, grad, norm_grad in zip(layers, grads, normalized_grads)] - table = wandb.Table(data=data, columns=["layer", "average_gradient", "normalized_gradient"]) + # Update or create the wandb.Table + if gradient_flow_table is None: + gradient_flow_table = wandb.Table(columns=["step"] + layers) + + # Add new row to the table + gradient_flow_table.add_data(_global_step, *normalized_grads) # Calculate statistics stats = { @@ -70,7 +188,7 @@ def log_grad_flow(named_parameters,_global_step): # Log everything wandb.log({ "gradient_flow_plot": img, - "gradient_flow_data": table, + "gradient_flow_data": gradient_flow_table, **stats, }, step=_global_step) @@ -84,14 +202,14 @@ def check_gradient_issues(grads, layers): for layer, grad in zip(layers, grads): if grad > mean_grad + 3 * std_grad: - issues.append(f"Potential exploding gradient in {layer}: {grad:.2e}") + issues.append(f"🔥 Potential exploding gradient in {layer}: {grad:.2e}") elif grad < mean_grad - 3 * std_grad: - issues.append(f"Potential vanishing gradient in {layer}: {grad:.2e}") + issues.append(f"🥶 Potential vanishing gradient in {layer}: {grad:.2e}") if issues: return "
".join(issues) else: - return "No significant gradient issues detected" + return "✅ No significant gradient issues detected" def count_model_params(model, trainable_only=False, verbose=False): """ diff --git a/train.py b/train.py index 684b94a..bf30af9 100644 --- a/train.py +++ b/train.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.optim as optim -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader,random_split from torchvision import transforms from accelerate import Accelerator from tqdm.auto import tqdm @@ -13,7 +13,7 @@ from VideoDataset import VideoDataset from EMODataset import EMODataset,gpu_padded_collate from torchvision.utils import save_image -from helper import log_grad_flow,count_model_params,normalize,visualize_latent_token, add_gradient_hooks, sample_recon +from helper import log_loss_landscape,log_grad_flow,count_model_params,normalize,visualize_latent_token, add_gradient_hooks, sample_recon from torch.optim import AdamW from omegaconf import OmegaConf import lpips @@ -53,7 +53,7 @@ def get_noise_magnitude(epoch, max_epochs, initial_magnitude=0.1, final_magnitud -def train(config, model, discriminator, train_dataloader, accelerator): +def train(config, model, discriminator, train_dataloader, val_loader, accelerator): # optimizer_g = torch.optim.Adam(model.parameters(), lr=config.training.learning_rate_g, betas=(config.optimizer.beta1, config.optimizer.beta2)) optimizer_g = AdamW(model.parameters(), lr=config.training.learning_rate_g, weight_decay=0.01) optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=config.training.initial_learning_rate_d, betas=(config.optimizer.beta1, config.optimizer.beta2)) @@ -309,9 +309,12 @@ def train(config, model, discriminator, train_dataloader, accelerator): }) # Log gradient flow for generator and discriminator + criterion = [perceptual_loss_fn,pixel_loss_fn] if accelerator.is_main_process and batch_idx % config.logging.log_every == 0: log_grad_flow(model.named_parameters(),global_step), log_grad_flow(discriminator.named_parameters(),global_step) + log_loss_landscape(model, criterion, val_loader, global_step) + progress_bar.close() @@ -364,30 +367,33 @@ def main(): ]) - # dataset = EMODataset( - # use_gpu=True, - # remove_background=False, - # width=256, - # height=256, - # sample_rate=24, - # img_scale=(1.0, 1.0), - # video_dir=config.dataset.root_dir, - # json_file=config.dataset.json_file, - # transform=transform, - # apply_crop_warping=False - # ) - - - dataset = VideoDataset("/media/oem/12TB/Downloads/CelebV-HQ/celebvhq/35666/images", - transform=transform, - frame_skip=0, - num_frames=240) - dataloader = DataLoader( - dataset, + full_dataset = VideoDataset("/media/oem/12TB/Downloads/CelebV-HQ/celebvhq/35666/images", + transform=transform, + frame_skip=0, + num_frames=240) + + # Split the dataset into training and validation sets + train_size = int(0.8 * len(full_dataset)) + val_size = len(full_dataset) - train_size + train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size]) + + + + + train_dataloader = DataLoader( + train_dataset, batch_size=config.training.batch_size, num_workers=1, shuffle=True, - # persistent_workers=True, + pin_memory=True, + collate_fn=gpu_padded_collate + ) + + val_loader = DataLoader( + val_dataset, + batch_size=config.training.batch_size, + num_workers=1, + shuffle=False, pin_memory=True, collate_fn=gpu_padded_collate ) @@ -410,7 +416,7 @@ def main(): accelerator.print(f"{layer_type:<20} {count:,}") - train(config, model, 
discriminator, dataloader, accelerator) + train(config, model, discriminator, train_dataloader, val_loader, accelerator) if __name__ == "__main__": main() \ No newline at end of file From 2961d7739a75d14676d40cb59ce6cd489490b7b9 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 05:18:10 +1000 Subject: [PATCH 022/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 16358 -> 23231 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 1823dfb37079e609fd787aa7c4994b194b16c757..a58105d446b5f368f9f272215508f9f8f4956b82 100644 GIT binary patch delta 8173 zcmbt2TW}l4k-K;&};8JpN zx$f0HASptUA9Yy{nlsZqGu=HsJ>B!$)Zf3uTEAnlR5I{&{^$pX4{SPbt<#)g&#$=X zW_W$fHB@`3mSuPs&z)fp)$tW)m_zRKtL1(+RL@sJ!^4{ZHt@B)1qd2>E3~|P9dCnH z6Yu67@ZZN*L$9BA0&M1M0JiY#J|qYSp-Q+ojd9L}nmgib8GCX@m{%CzbeoSAb)AO2J4Fj)TY_LYz=5@)NHkggP?hGRt z`s8(*M)|*VR?S4qEtDQm#zjoKZwp z=f9(iB1k$DBk3gVWf0lX!@88`_ze~ z`AccSNsCxhniur|8@N0r86;!5X8M??);VRiW6+i=@4(p>an|7C9Fyz#%A_5B4oL$mwMofpsTy^bbS7&wjKoOJa_N#T(G6Ow z|2j|~poF!_Iv}i*TvAQB>XU9zmr1JqTjtVI9x2bGn$e|w?yN;Vy{cJWS8X;(c4@xZ zJZ1q#?2`v22XCfz$7FJ;T1y)k1y&l9wO@HL$5nmk306=knWTnWtn-em)t7Z(w(u)` z$tKC0b>dtu{#PP&MToP?W2La@_@0izsk+QUmVr9R!WpIyzQrx&!6Ox)HTj zYP=ma`Ag7|Px63rqRB>a{elDxbJQdI9j?x#U-U2OY*?sEe!lLiyIe;gfIr)aD?RUX z=pi$8s3KM2tEMgOwP5nRI(3ej$b_S{E71?FPvQLya_Z+8Jp;V@4}{t zqR8L2I^_NhRjCrU1pNc>sTUI@d_=7zz*0OIiyn&*IFZaeJXFC)}+xontZ4p5nK=@TekYBde-p)P28}v`)wic(L;I5o>0s3Yi zfCm^B_FZChPkE9oGLvF?W(=<9ayvW7Tx3~hvH>srE1Da)`ww2x5Ogo9Ru~zAx)tjz zzf{4WRrCx-;*rr|vTbbH7hA~!U400qxw#g=zr!D>x_ nd6t%Pv4V$c!qs-|BKtF zvFo;ScZJ^}=2w@J)B_I$&nUrDwNxzPgO0B%kOybLV_(Lh?1Y0uO_hRxma3050iCLk zC&*BTfcL7;H~#;yRAl;QBY;oQRNG4>%VzwU$oarx2hpKr$Yx`A_=_0n{1X^K9!MJk zhy0WE&IWpfcR>3E_97ESUsIn1bCXzb!ty1sy|cHuVp%ik0jCA-hq{*lxS#m2S~(UW z32-)Wrsox>vQd?Ntc=soDtq{7xa4`h?2IoErnVuo!YTl$ISy5&&=m9aKPb0|WCLaHyCotZF=_i5!DbH9*)F$`ISm<3a%z05@=KPLkS)x;;5}0jtgofG8#*S_y|!ef!pA4 zA{rM3RUe8C9|@^d2@)NQ#$nf}a*%#O(H%?bohm!3vL{vcag{wpR-=^cQG&6!s^L3S zO^WX%>#$|QB?dxb_(%lSpcT~+jSr8&l14E^1|wKWjEJ;{J*FDxIaPI7Q4PpX5cG={ zGA{9qLWh4U8NH6$2z|U|EuCMYtS*fadG_(r?F(9n`C+)pzm9Z##P>>(`1?4k~u!*S4<6gQ^V=L)OhdsanM1m21w+OFQq+6WdkO#^D7agSjM3t! 
zAfqodFF#+EF-#c?^-WW4S;LEKrq^VQ*P2(&q;@OK>+{X)i;UL3mV)zpGQF7-1yA$T z#_YzK`xH-i-qW4wD>!R1$EHG=W0_+Y9?W)RAI^5=oXer)oy%vs-@NC$_gvX>b&KNM zQ0i5j%L`6#w)H!KZv|dm_1Y@MxhAtGv!`J5OdVHjzKkhjf_-y&QkS}>Mqk`Iy;Z3R z?NdU|)&eDOYssEcY@MZ+Li-9WYi3PK%a(L+ zIyU*xRPR)1YNulM!5UtFnZmnfgGyt6zOf%ySy4@G4DdN&1W?wZv;c{E#_vj!AvAa4z%wFQgqTyP>d zwQpvBI;dDW^OnxErch-+mzqd{B!1A|=xa~Ro=m4wN>y*ZsyD4GaFypQ6PC=7!nNkP z)*RRR`8|gft})Lw=D0?rF;AFtHC;1x0H7#bcb@Ccaoq*Zgs9AZg{#kV^*OG-z?sjP zCQO-~3Rjos>T+08bznABwyNqPqu19`kk($a+9r>ljZMa&%HiqSnbqgGWeT?}&%y1v z3}@#ixQzWQAk(Hhr^ZngnU=HGNoyMZJ~ma2KTwLIz%wm=hWnw4!O~3?z(s~81@l;V z17Iah`J+Tq2QN;Hh$9xLbUedrp4GVJVxShERJ3|gFBiLNx^(=KOrO_Y)s-h{lAhQ1 zz$B6uY6h&qxVJVs$cQ?B5cYwzuG0PVn-19CEqA0U9zpMW@FfM7ilJ^TnlC`|{HrUUt=14S(zom!ZKZupJ=g>TGrhkW}n=JH`TIR$I0&KOj0 zWuZYDlfktqMur1=db5#TKu3}YPz4a$BnATp@)&kJj^Gf2aVq}^5Q^1*8DXft?1iIYegGUe7-9Gp`v~*b49lD~iCg5S zwe=ctD?J#$(AzCr7f9h@r0IEWomE`gTOSXEOYnqysNCgWpUohazPDII;lL~;?#h9_!q;#nmh6(l^ukrIW!s^ zBq6>m4Iy^cqLH{5j0(a?LGHbwiO* zoZN?P{V;(*Sk(f{1l=duL#ZIJCih}hd>)P^P6Wl!0L1b7{q!G)nE_2Ev_sK2!8}C1QQ2xCr*Wii zw>T6FnB{s?XQ~u#+b}Z@#UOd&OPZ!$OVvh)235Tvh9JB)l(a$8FoS@4UJ?z@N4*w2 zx5!lxIU+v*qBBVSLrB{`1@{HPXw`Us0l=<;%?FN5zkcpR)0^xzhy*xK&e(j-+cJG@ zcEjw7cMmF^J9FMX#oL$nLX77D3VmH(=gT(abiKLu9S9eS{12ALwAZ{Ha+CE5)-iX= zdWx+W!GOK^%7xeEKUS?=XubY@x!Jb7VgwVef1Q5TCqH0Yxf*ptrUAfni^!8cY|^dO z9fh1GU?j8H^A>{d%dgt%RzS)eIvFA4CFl~20Ol1zeHXckgb02ne`MQie3UIRI{l-v z&)%9k@Z+{z{Q)TNb(j7FRLlqS=7Tv}3Ko0j{w$l_{i<1UuKvKXI%ip3FgwnDZQ^S{ z`9*_2{n%(hb#5(G)toyuacXLVQq`QVYR+-ZA6NC{s(SEz%>C5RfCWqVgwRd z%uwM7pd$}hsJoX(>t6EWGob1SCr02L(d=xHaTQ0G8aLFm=VE#gglQWLSAa<6jx@;f1u%{bQrfpdlXl zmn|;#y8PRgKl9-HW46&Cr2C1uAd-0)~;%l zmY!YJ5E&IAGbTyd*tUW_Ew{Hh4fJ~Jm-n=t*3$%5e!s2OO=UL=p`qbeq_iS2fR+vH z-2b$_q1BL%xiv@E#7cGi#KuLWGwZhVqf<7w5Lve66@sosV zA3CD(!}MI26!bCN z5S8141HIj%H^fl_@#E60{Id;@gLuA;FK8P(P>{bwfOiAggn*jepF>Ul)#jDfb5Mcj z?daeUkyNs>b&J3CKauK>2tGxCE?wp3pIn5J!e5s6ZP{r1HloqHm)!QF@=II%j}w9& z0uHi?iDQv?aM5acau0SJ=(~#`VDvigM=(ePjO&@GFhX52gk4~F@KS>_kZurcY|h!! z%`QWI2Q&nI5Wxuonq9kKdr&C&hw=tmNsT# zuT%R#;XdPfz$e&7Vv%j+UZ94;2+wNdquX-Z3(WFdc`534v#hPi+yVD%%xqneSp;h# qUzEjdKVNNNYCSoox~SX1vU|V-m;LM0SlO*bW&y}gZvQ5Iu>C*E*Jx1y delta 2685 zcmaJ?YfK#16`ngYyUSxAylrelcGnQR*hU6oz={KQF!)7c>_8q)>}A*)mNC0C>pQbH z!AsdlmgCq-8|x+xjjSeZRi|m9mUi1XQPnn5TgjF310*V;t6Jhm<%l0`sQjZ=rKmk; z29vZ^dw2G`Gv}Uj?m727?r(lU?#=|?3j|687{j-2jrGT_1viR|!nJL0?e|<1PS&Vi zzzu4-T6B?|+^81ASFKv1mcrL2wNeehyjcwbb&DDTtW(PX>s9iU5UXNyWXG0-h`QH? 
z4`vDe_pW;c;pIT#@0*VjDR4PZ=&QRSZc)oG3ij>?yn>LEa$?SV69hhx*qm5PirE(r z)r51!NkOgn9{I5HCg}cPUH^2+D*=0tAh5IHTJk8QLaR-z9HsHe?W5N!y2|dQrEe@s&4o4oXD6LG!lT=YHJehDf&z4Jeux-&OiLj&5#+G}i zC?>A>VOsG5yid4qOix<+l=eBU3BqzujBt;V%8V=QdQ>UrQpQLE$vbA!v|>|{E&HoR zD}nP~7~m}-DddDxr7ysS(_~tlmU0A6T1raS#8>@Kxdysc$`n(Bds51a8s##Y$%_rp z>RK$rVi*PuzzYqshUmj*I^orN7$7gqmyWy7LagV9)pPQketG!y1GBMtf5Z9`4FPLR zqIEFwGUc2`m89lq*EPfDmAILb0$tlL_edoomq?IZ{WoL_7aqz<}%+m-KX!Vv=!>?Wpot zwqk=3uD2AE!wLdmG}gyHCoVhy8#!kXA4{3%=}BNKL;g>obnfSc)gmDf%J(f5o^Jdl z>0b2;ew{#ge&7D-x!c+Kr^gn0lZ(Acpe&S*FP4ty58UyWzj}0bXnuQq?eUK+_(vA~ zBl$jdtm#>4seUW_QPcAzRPb)QK*$uk5^DzhRjf&zYGQwhHP0M^?ddK6E?Krp`|z^` zKjkx~Zp3_a5GltIo&QsK;6Y$O}v0WdEJA ziZ^_-Y__@4cY>@6vgZW*U2{Wt@Z;h6O@r`!s=71ag8%rU|Mji5>5)D$Hzu&9h+zn}l0tgp)}N zKcSLXV{3c$H_&b@y}WRH|83>719^dlSh`jBJ_$s+ll`eROxoC&t($z@lQvZ?#kRD) zU5{nRD-vcF>cFAKGYVDG7QF}yBZbAb9U?@9>>eRr4fVXhk{wadV4E~C&c4}ExASEr zV`S)QfVGz7*0)>uPA$@_uy{paOC7Iyv3{vmnCO(f;)=-5K2}4nvdfR1j^GYcWs`BL zB}~JzX*OYNY8cv*FQH^?-82~KK3IJTxVXPiuXYnlu)q=aD(e?smV_&WX;6_TSa8=7rx<(wUUznniyCa<2=8 z8$Cak#Fdi5K<^(t(K(=f{`R+Ve9|*jen?M_haa?;ux^CuDb{;#0xMk@p}1=0Ezt77p+7P9r7Hd`AX4Q1AcrqdaXI zIt0!+IP82D{VnkSS}447#K17$ALt{e4OH!QOrjs5=34-~1P;wgPXqaPyyniZT% zjJWl*mWj9^kDKPiR5fCJ%xKx0G80P5>TW^87Xo|ZN!scpfODsb)G_`x2wHe8uU!n> zwi0K~rvtn*lygc#Xf%g!FV-sNUAZU!AYsY5%baTIX4;bw3`>z)cKc5g{6l%Bzj z-p_#arNB-NJ_jD(8hlfVVDxAeLInW#oztnV>K4UxT=^#qrVlj-YjMqELuhUKbAf$# zDAMp9lz9i?U4%yfpxb0eGbt0!)(KQQyF1ibwilPX5cVM)U^|B+5iY@n!9}<}%4~B& zGvZ%0axN#*N`{%kHBqNPYtHFb)~Ta-N`b_- zw_*>`ON>0zRrW4?xS8}gzRj}3&pbba2Gec;E)5SgIIw|j@igUme zPte)gX+=((a+>HriuOVa$)>b!`Zr*~oUpQhs}dn(N$}^L-;z*0?>_8_XFrOr%H8A$ XG}f=-J~2r4tO^f+?LK))M)&>&W}5Q4 From 32bf47130440c33a07b5c0f7876010091305d7cb Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 05:23:39 +1000 Subject: [PATCH 023/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- helper.py | 44 +++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/helper.py b/helper.py index ffd8a99..961b0b8 100644 --- a/helper.py +++ b/helper.py @@ -103,12 +103,6 @@ def log_loss_landscape(model, loss_fns, dataloader, step): wandb.log(log_dict) -# Usage example: -# model = YourModel() -# loss_fns = [nn.MSELoss(), nn.CrossEntropyLoss(), YourCustomLoss()] -# dataloader = YourDataLoader() -# step = current_training_step -# log_loss_landscape(model, loss_fns, dataloader, step) import torch @@ -126,11 +120,11 @@ def log_loss_landscape(model, loss_fns, dataloader, step): import io from PIL import Image -# Global variable to store the persistent table -gradient_flow_table = None +# Global variable to store the current table structure +current_table_columns = None -def log_grad_flow(named_parameters, _global_step): - global gradient_flow_table +def log_grad_flow(named_parameters, global_step): + global current_table_columns grads = [] layers = [] @@ -153,7 +147,7 @@ def log_grad_flow(named_parameters, _global_step): plt.xticks(range(len(grads)), layers, rotation="vertical") plt.xlabel("Layers") plt.ylabel("Normalized Gradient Magnitude") - plt.title(f"Normalized Gradient Flow (Step {_global_step})") + plt.title(f"Normalized Gradient Flow (Step {global_step})") plt.tight_layout() # Save the figure to a bytes buffer @@ -166,12 +160,21 @@ def log_grad_flow(named_parameters, _global_step): plt.close() - # Update or create the wandb.Table - if gradient_flow_table is None: - gradient_flow_table = wandb.Table(columns=["step"] + layers) + # Check if the table 
structure has changed + new_columns = ["step"] + layers + if current_table_columns != new_columns: + # If the structure has changed, delete the existing table and create a new one + if wandb.run is not None: + wandb.run.config.update({"gradient_flow_columns": new_columns}) + + # Delete the existing table + wandb.run.delete_artifact('gradient_flow_table:latest') + + current_table_columns = new_columns - # Add new row to the table - gradient_flow_table.add_data(_global_step, *normalized_grads) + # Create or update the wandb.Table + data = [[global_step] + normalized_grads] + table = wandb.Table(data=data, columns=current_table_columns) # Calculate statistics stats = { @@ -188,12 +191,11 @@ def log_grad_flow(named_parameters, _global_step): # Log everything wandb.log({ "gradient_flow_plot": img, - "gradient_flow_data": gradient_flow_table, + "gradient_flow_data": table, **stats, - }, step=_global_step) - - # Log gradient issues separately - wandb.log({"gradient_issues": wandb.Html(issues)}, step=_global_step) + "gradient_issues": wandb.Html(issues), + "step": global_step + }) def check_gradient_issues(grads, layers): issues = [] From 1baec27cccb3aaf09203aaa06eba3fe3aae6ef6c Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 05:40:03 +1000 Subject: [PATCH 024/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 23231 -> 23517 bytes helper.py | 59 ++++++++++++++++++++++------- 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index a58105d446b5f368f9f272215508f9f8f4956b82..a09ae67cc8bfccb0d014624e2a6bfbfefce1b237 100644 GIT binary patch delta 1542 zcmZ8hZ%i9y7{9meU0d#2%AYz)TPU=Jm4XSlu-G6KI}%EDW05H{nx>=I+EH43@2Zis z4*x|o#LT^#5MhiD4n9m0=1JxU=jI2EAB?$@E|7~DO`Lw%moAa{#TnnXr38F)_sj2| z=lA^S^W62;2H3d)466o16@}07Z*IhXx^DR0%G5Fx^`24s+ZqCD{S$zVm0E)VX@A3gu zS*%)aB9d_3|J_-Aq%jL{X@-=JgItx^RSBl3^UW2WtvcA$=k+TpTn69elVjgEiUSmd zUU67J0PzmDAyDz*EUmVi4^(?57Qf3cD-eSzVTmKC-CFa#z72 zMccl@aQH`MLG;L?>MB4r_Ij140e#Z40X{^b*4sYC9EXu;f=`LzSTg-ucs!lVOs0hC z#=~YYaxTgHk`a*?M2`v%;zG#aeAl83lA7pMmk(5<>#jP$qVHY)V{D4w z@3Pc~LYJ~Dv>--AL1FgVn8j@f)fgNO2eP|cs%~Qb1~KkRL+%k^?q6qRV}JR|Yy1U` zU)K21T3hr&k)i4vFCLwLE??`M3GHZY%j~AsD{H-XS<~F(GV9pkYUY9k&bGxlH#w*5 z?)z#&b`9sbP=O1{Txge}47GE?EzZ5kx#hN}@?4<61!OMptHUen>^QyAGW(2d>&$bV zGTSwCg5(R$kKXYS@bQ( z^46|`wQJjCEwcM6BE3&BwDvCoIPxG!nxhRWa2eSuQ4e4Nj{PalYJ^W zL5!9eos7KBL%2XDr^SHBNe(oTu=NW3AidpjhXD`J+3pYnSq1h)2U=6+0 z{}@=3KI*3d&EZd4CoT9__24!N1pC-hUD7Btq7Q=2wk1pzIw2%dG1x)OrHWSwTSMOk z`H!Kq|M8=%gx(}j+AP^yX$g?b8j`V;I)kJS s=tJz7y=n+kd{M$%QqLF%Kr8Z}{n28v%apmO83F*me+O_BwZAa`AM_EN82|tP delta 1371 zcmZuwUu;uV7(b`I_x68R9D{Y;O6gWzS4MOzE>T!JS01{N08Yx#n0oPcv(UB6xtEwM zt?)47eDTi?4@$a)C=44GQ@qYGF+pF9F+SXyYD~>yOvLAsFyY0-@3eGm2`A_N?(hEY z@B6-UzH{&TD!#LVZP#r!GlJ1~dv)T(P21ltVUb7ZU5om_HH!6?Ig5I@Bd>EPL1wCt z3r?k%UNP;ANINERRzCyk@>5F~p&M}3huEeZXmEtiv^RJbLxk`tei^-wFnZ6Tyh!hP zT=)>R?e*CXHCzy}hwgB<+Lf)6he_w9<~*M>DZ`JwSzT7Yq=V{k)kn^(^fox;@-}$Z zxOE%S&b%&bx`wlOxn+AiMo*5Hb}2(!6L)SVHgDg4@`*_$G=}`UsmZ&7FVQ>R?{t1l z*E%mcF}VPVDov9_l#+?G5>LwVxrFTRq9?nSe9fEkL@M)kToF&F5*7ZH+RZ3RQd|-h zk%oOQ?6OuIX>lqc#ivCAcO^mOiZxC@@P&qXSrHYvBGfaJIjARF0X&aKm^j!fY^ls- zoRy5Ta%2^>S>|3-d!IgrTS7&vW)Ib8i3^vxu*QYy{4>(vhJcJN&EQ>kcjnI)pD%Xb zik13?HFu=sj+EUI&FBIL-&*E8SK2jhNP89{hc+k8c8|Zl?tZSyb4@}OF%Xc&Q8h4T zL^fyD#PMzS5Wsx~VBy;~?1(qdMzto-HQ&{q&toNTu0rQRDZ-A(O$J|6Dd6#3W;-+hJZeX+6=azi82q4K4g4zL`*fx+Cw#%wPw0Pr 
z9=t%^{lD;kg7BNV(Lb-pYjkcPvhOeE$TPSNAnO4jZ@s;ybcRfcspL$8+yxz}4+n|@ zzDo;(Wn7@|g!=J<`elg6ya4}_pne;^Zo=PCY2=`#R+lZ580n3XJ?#rXD|}K;rYA@b zlWP?h87t6#Mh*(UGx`VZ9zDVtAsPKJ?4%!!HuFw+2&#qA6@AZ7VDSs2g*mMt6B~m7 zd$a?q@NcHXscAnZm+9Qm08Y@aj-KzT$()T(q=_5eNc7~fQlQ=` z@KkIhF`1O%gN!p*&EnUSaaQNiD#wod_`NKrnmazsVKe>ygtzNoW_ZBhA%iGD#VDUU zJ)O!ZGAS@^ks4wH_%j-c&GJErpHY8|&0|N+U+YSaxvKQc_@(_zjIs~9ezNS`YKK5r zuL)~QZDoY5g9PcmQyu1)fWk!bMRn|y732N%jaPqe-|f;+bCnBYJc6re+bu#5zxMI} DilA7H diff --git a/helper.py b/helper.py index 961b0b8..6cf84f6 100644 --- a/helper.py +++ b/helper.py @@ -167,8 +167,14 @@ def log_grad_flow(named_parameters, global_step): if wandb.run is not None: wandb.run.config.update({"gradient_flow_columns": new_columns}) - # Delete the existing table - wandb.run.delete_artifact('gradient_flow_table:latest') + # Delete the existing table artifact + api = wandb.Api() + try: + artifact = api.artifact(f"{wandb.run.entity}/{wandb.run.project}/gradient_flow_table:latest") + artifact.delete(delete_aliases=True) + print("Deleted existing gradient_flow_table artifact.") + except wandb.errors.CommError: + print("No existing gradient_flow_table artifact found.") current_table_columns = new_columns @@ -189,14 +195,22 @@ def log_grad_flow(named_parameters, global_step): issues = check_gradient_issues(grads, layers) # Log everything - wandb.log({ + log_dict = { "gradient_flow_plot": img, - "gradient_flow_data": table, **stats, "gradient_issues": wandb.Html(issues), "step": global_step - }) + } + # Log the table as an artifact + artifact = wandb.Artifact('gradient_flow_table', type='table') + artifact.add(table, 'gradient_flow_data') + wandb.log_artifact(artifact) + + # Log other metrics + wandb.log(log_dict) + + def check_gradient_issues(grads, layers): issues = [] mean_grad = np.mean(grads) @@ -215,7 +229,7 @@ def check_gradient_issues(grads, layers): def count_model_params(model, trainable_only=False, verbose=False): """ - Count the number of parameters in a PyTorch model. + Count the number of parameters in a PyTorch model, distinguishing between system native and custom modules. Args: model (nn.Module): The PyTorch model to analyze. 
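The schema handling just above exists because the set of logged layers is not fixed: parameters that receive no gradient in a given step are skipped, so the table columns can differ from one call to the next. A small standalone illustration of that effect, not part of the patch itself:

import torch
import torch.nn as nn

net = nn.ModuleDict({"used": nn.Linear(4, 4), "unused": nn.Linear(4, 4)})
loss = net["used"](torch.randn(2, 4)).sum()
loss.backward()

# Same filter as log_grad_flow: trainable, non-bias, and actually received a gradient.
logged = [n for n, p in net.named_parameters()
          if p.requires_grad and "bias" not in n and p.grad is not None]
print(logged)  # ['used.weight'] only; 'unused.weight' still has p.grad is None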
@@ -230,6 +244,9 @@ def count_model_params(model, trainable_only=False, verbose=False): trainable_params = 0 param_counts = defaultdict(int) + # List of PyTorch native modules + native_modules = set([name for name, obj in nn.__dict__.items() if isinstance(obj, type)]) + for name, module in model.named_modules(): for param_name, param in module.named_parameters(): if param.requires_grad: @@ -238,23 +255,39 @@ def count_model_params(model, trainable_only=False, verbose=False): # Count parameters for each layer type layer_type = module.__class__.__name__ + if layer_type in native_modules: + layer_type = f"Native_{layer_type}" + else: + layer_type = f"Custom_{layer_type}" param_counts[layer_type] += param.numel() if verbose: - print(f"{'Layer Type':<20} {'Parameter Count':<15} {'% of Total':<10}") - print("-" * 45) + print(f"{'Layer Type':<30} {'Parameter Count':<15} {'% of Total':<10}") + print("-" * 55) + + native_total = 0 + custom_total = 0 + for layer_type, count in sorted(param_counts.items(), key=lambda x: x[1], reverse=True): percentage = count / total_params * 100 - print(f"{layer_type:<20} {count:<15,d} {percentage:.2f}%") - print("-" * 45) - print(f"{'Total':<20} {total_params:<15,d} 100.00%") - print(f"{'Trainable':<20} {trainable_params:<15,d} {trainable_params/total_params*100:.2f}%") + print(f"{layer_type:<30} {count:<15,d} {percentage:.2f}%") + + if layer_type.startswith("Native_"): + native_total += count + else: + custom_total += count + + print("-" * 55) + print(f"{'Native Modules Total':<30} {native_total:<15,d} {native_total/total_params*100:.2f}%") + print(f"{'Custom Modules Total':<30} {custom_total:<15,d} {custom_total/total_params*100:.2f}%") + print("-" * 55) + print(f"{'Total':<30} {total_params:<15,d} 100.00%") + print(f"{'Trainable':<30} {trainable_params:<15,d} {trainable_params/total_params*100:.2f}%") if trainable_only: return trainable_params / 1e6, dict(param_counts) else: return total_params / 1e6, dict(param_counts) - def normalize(tensor): mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(tensor.device) std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(tensor.device) From 264bece03e5e1bda781aca7138f85a043d30b463 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 05:45:55 +1000 Subject: [PATCH 025/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 23517 -> 25438 bytes helper.py | 48 +++++++++++++---------------- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index a09ae67cc8bfccb0d014624e2a6bfbfefce1b237..219d84555d8b9f187a6a7ac45804e66f99e9a3c4 100644 GIT binary patch delta 3806 zcmbUjTW}lI_3lg7TkD5xIr3wZ*p3xRwi7#~NlYMdz64zjdPH(d zKYHix*>lh9o^$TqbN1Ss)bD;yIbU=*tO$nV-8VzzvCs zl_h+lql9%Jxo(La&r&-FQH(vd7fmw<(J@N1gqLa57^RsKU}KC*+SOpm7(x@BC0w(s z10@`xi_I_45End!k#sy>tKNY>6TH+O>{#CE*i+(wjc#1q=5eZ(Bp)Hoh|A7$NvFF1 zhO9M-6={@NtHt3Xe6viN}^|*u@(hsI-;nN*qHEN^o6jjMn=OU zYEPo~&iI#!x{bZcvchQmXiV(EudNI>=ylarr=oEc-`wHZsMl9B3==wMyue#lJw7EZ zRTjf)*w=>-bX?iUY?+Ga=8|{13Br!3$Mu>iIX)qcsJf{TEihy`E-SiaC_XuP8^Q60 z&ULGGcDp(m)y-Rq;pt3R6m@$vJ{Bt3lpld)`|9wC&d@qwPetTcfgs;Tz*zt~r>J37 z(M=IWIVvggql7UFK&b;D2m=11v%~fQ;13hxD@?kl%SO-Cd!F>2_CM*L9m><3V+Xb1 zZF^?$yMwQ+ySO3e-<0W&!hBnLE0Wv1r?3E+^-kuzlk?^U-t&~}3D*U6-j(GCa{NGs zAGpT5GmW=r`QaQtoH6G0Yp};)uHKhN6l_olukp)4dN|8(&+*$crI}}l-1Y0%^Cq-B 
zfbZ#^p2?e$rS<8NJj3y(MXM{%!TWO)YFU1EV`?<(4d%ST1>T#t06C9J;NLG#eIkvRz450S|D|gi~AJkF+)2n^;H1#ar?DyK<1rhnD z1iX($zaXRlRiN~egCIlZ3GXK*@Ed&I{}1N{7NMt*NF^C~Se)9ns-06gkSPw4K7~|M zX$9ql#qYEP@U1kD|J>p|<2(b-_#|)w+Y1OO~ z=8_?~Nnqs_ZvFXcl1?C>~LZ`lK+aVe5Z{86wyDmqRf(RtDJ0t2NSp2C0xZ8s8ROFA{D z#v~wMkxzhD(p9;iW-IupImKGyqvjIp;9Y;twA2r%tj?V1{|Ky#+fa1FUUb7=>2j4zGvzkCBy;Yf2}jX{<4a9^nT7xDVq?MnjK&xn zG%furSmKVS=nh+S2kzpHL#T?V!X1s08#GpICX#A<=`ax-csb|z!(NLOd??D`ANRI1 z%fg@H-}Y|U5D>&d!#NsJ#!25fBB{qDDJCe#6;+xPVqrD%bx8=vL}BEpqQ)nM$+&nl zDk*7q-wcI+?`rJa3Y$KK#*l^%w-^_QO44fCL-2%FX@tKRZu;sE;E&@pJ=eMy$>eMD z(YiIFL}F07Vg? zdd-j#W+-9Z9X>9}!aity31+Y#pI;N8*5mA&AYNQ!#s6A!f_72(#M(`PL_oerW?+pyv7Mnb?_5}hy z#-?>E%r%y-JOY5%uDg9^(bas_)p5nukvf=l_2yi?dBixn7WkGFm*qR>?eq4<3jAZU zV|M&;*PLtCb>_`&SZgJm)xS@U?#Jf1O+mz4+p zAMNysb#|KiC7xZsEkMC1BtHwump>#YQ280?mRMY#gn!3`By%9PGJSph4^1>j;qPz$ zHFXI;d`my|T>7DR!Z!oL>2qHd{>?V={Aj?4l z{AAe!n9hwv!;@1!N>1U~J$`ByzrJUt!?1fZuoY<4aN~@h0=;HD9)Bn_8k08z=P`WX zuB`tuvGWW7-62b3P^~06M5GK6C>Zwz>)jae*iz1RPL)^%)STj~CGr5gfP#!3}&V{8#WFi61YC~teOqqA$5``*Hs zuTB&LCL_@t;|I6FKFBbUC14?Zkw1PhCMISB5|&2Dj2ff(p(dgU0`cD0b!?pE{(9~? zzjN-n_uX@Q?^|^D3UZvY+bs-{mTOmHKWx3|__;!G3k-9@k@>gcRb+6SbYyl`U1be> z)h3uWd$?me^9pS8wIJQZPJ2PD^c<9mXAr}DNk(1}qnX9E8tO-yFiK4QC~Z-j3tom9 zSXa2P$fc>}>@%NmK(x-ZJXkQ|*Cr zugfHrO!8@)+67}ZRmct(t4mu=npIje$!fN{SN+;Q~2#TsaY@jUj zLG3A2Il8tVi&9)ps?k^?btI~ay$Ly-5LH=GVWwVe53pJHYRXU|r9NwuL{%hA=Q1;{ zD8sUX-zCKa{4m&X)8&-pgsjR@5v%c-*r&o#{|k0}j1aP1r#hGjIAOsbt;aj5h1cuI z@`@^|N>*6Pc9OUhCjm$_)>*F-GT4R)5o0KiBZeyshoFt!>WM zw&)8&M{wlO1A*Z!$GaB1)eGL*94|P997921tgf7qXb((GW%XouIyPU?e7mA~!Bvs7 zF0&~0FGk>ve^VflD*?Mn8JXUuV%9%Jw5&l-OB|^#0zjT)S7u18bhv8%WmuWqmBVn<%v6Syi4U<%sAsP3P<`?^A#ZH3ZrLf=nD^V86 zP-}+Oa`vrl$&%pYQTWA)ZbUcqvXEyKR5pN^!v?&wmJZDBX;FmoSt9~xD z$FKu}-Rmm#!{OsfJQ>65saaq04Ao{K)*TVei}`hAvLzGP zd)cslj!*&uE}={}G2>ziK|1viEz9>M#KEBe!V_THA4L5S+CNsO=dC4fh5iK;c4j=B zO1%}0C2<|`#Nf{U`C$G_5>D2R<^H&$$~a0{IwfGh2k1IWRo*&K&-+Q_aOUxWHV(O9 zM^8=NBT9Ho;R%H<0$GzX+)JNq18WFSKzfasR)oVad#DO^X1+i4JKEBT5P2?F Y;6B!YUd%De;A=$W3ay5H@JFxuKbgM%TL1t6 diff --git a/helper.py b/helper.py index 6cf84f6..98e2234 100644 --- a/helper.py +++ b/helper.py @@ -105,21 +105,6 @@ def log_loss_landscape(model, loss_fns, dataloader, step): -import torch -import torch.nn as nn -import torch.optim as optim -from collections import defaultdict -import numpy as np -import wandb -import os -from torchvision.utils import save_image -import torch.nn.functional as F -import torch -import matplotlib.pyplot as plt -import os -import io -from PIL import Image - # Global variable to store the current table structure current_table_columns = None @@ -165,7 +150,7 @@ def log_grad_flow(named_parameters, global_step): if current_table_columns != new_columns: # If the structure has changed, delete the existing table and create a new one if wandb.run is not None: - wandb.run.config.update({"gradient_flow_columns": new_columns}) + wandb.run.config.update({"gradient_flow_columns": new_columns}, allow_val_change=True) # Delete the existing table artifact api = wandb.Api() @@ -210,7 +195,7 @@ def log_grad_flow(named_parameters, global_step): # Log other metrics wandb.log(log_dict) - + def check_gradient_issues(grads, layers): issues = [] mean_grad = np.mean(grads) @@ -227,6 +212,9 @@ def check_gradient_issues(grads, layers): else: return "✅ No significant gradient issues detected" + + + def count_model_params(model, trainable_only=False, verbose=False): """ Count the number of parameters in a PyTorch model, distinguishing between system native and custom modules. 
@@ -265,20 +253,27 @@ def count_model_params(model, trainable_only=False, verbose=False): print(f"{'Layer Type':<30} {'Parameter Count':<15} {'% of Total':<10}") print("-" * 55) - native_total = 0 - custom_total = 0 + native_counts = {k: v for k, v in param_counts.items() if k.startswith("Native_")} + custom_counts = {k: v for k, v in param_counts.items() if k.startswith("Custom_")} + + native_total = sum(native_counts.values()) + custom_total = sum(custom_counts.values()) - for layer_type, count in sorted(param_counts.items(), key=lambda x: x[1], reverse=True): + # Print native modules + for i, (layer_type, count) in enumerate(sorted(native_counts.items(), key=lambda x: x[1], reverse=True), 1): percentage = count / total_params * 100 - print(f"{layer_type:<30} {count:<15,d} {percentage:.2f}%") - - if layer_type.startswith("Native_"): - native_total += count - else: - custom_total += count + print(f"Native {i}. {layer_type[7:]:<23} {count:<15,d} {percentage:.2f}%") print("-" * 55) print(f"{'Native Modules Total':<30} {native_total:<15,d} {native_total/total_params*100:.2f}%") + print("-" * 55) + + # Print custom modules + for i, (layer_type, count) in enumerate(sorted(custom_counts.items(), key=lambda x: x[1], reverse=True), 1): + percentage = count / total_params * 100 + print(f"Custom {i}. {layer_type[7:]:<23} {count:<15,d} {percentage:.2f}%") + + print("-" * 55) print(f"{'Custom Modules Total':<30} {custom_total:<15,d} {custom_total/total_params*100:.2f}%") print("-" * 55) print(f"{'Total':<30} {total_params:<15,d} 100.00%") @@ -288,6 +283,7 @@ def count_model_params(model, trainable_only=False, verbose=False): return trainable_params / 1e6, dict(param_counts) else: return total_params / 1e6, dict(param_counts) + def normalize(tensor): mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(tensor.device) std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(tensor.device) From a7fc004a2e75c47df37d7b4ef3c4ca0592718a97 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 05:49:23 +1000 Subject: [PATCH 026/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 25438 -> 26558 bytes helper.py | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 219d84555d8b9f187a6a7ac45804e66f99e9a3c4..49334aa7f7aaf7591feba885ef0dca25b3f7e40b 100644 GIT binary patch delta 2708 zcmcgudu&tJ8NcWHMdC;7*pN7m9ou<;6F2YH2{_R3YML-glkzAcHIDBM7>u2rdr3#+ zD6L96{^%dTH)(7wtD{hwtU-lYXhNGdx>hS%*Vaa^g1mvIY2DU!)w*W;r_#{AbM0j5 zO4on39i996d!OHT&UerCou8ARyhlv07>%VIju+qf;izuY8B-G zPLGpR!UoSn#17ARhDa4Gc&3R1PI&7{HN50)BTmSAof?-=Gw+5)ZxiugYLR&1+1jgQ z3O3X&-2a%5x2gN<+6Rr7j?Fz8Czj>QZK9# zDxT&4WfCf1Pz$=GMX)Qs<$}Em8OQmOiv*X$r!WZjeAn5IrrK&jtN59U{-%{{3Q8!P z?r3v-S&_!2BAx$(B3=KbB5R68mKBREyJwL-7I)E4rTBTyt6^7%v({VGs*BO;dqlfB z?YdfaSaL0F?DVYQUA(U@{aj~+C*OlVJ@5$mGprB1AI9B_2_laTmGN=V1pI03fPvmt}92co1DzP65=scm@8J45t zqI-gGc7Mn&m7No+V0#D5bk(z!YUl^3QM^0idP;o#1fGj-H)lgw(0RPDR7{CeL zmgAC~IOc{MT|rU~n(j{R>&%#fP2KK>3ygIl$!qbxJuXUgyCO&lDvLs1&C2#!VXAwW z+=AQPO?!-S1z{+XI3Aa!yfLC=UlI*PtDKO-vAm&ZR3z$Wi8V@GzI05a5xmR8heU}T z9z1Qe!od0}y$QjyMv}Iv(_`!11V0>SHuOv>9Wm=)Xbj$L3}zcQ<{CF*2V^Vv=PLJ~ zGT!>n<_{{W&h;!fT5dX8vW~W#qb*y}o~vj-Ww`xihQDjMwXIjRmnRnj{Wk;s*{Q&e zTwq6L@F1H<;+fV&wl$G!O=P}Kvkg+NLHdN_OZJii%^JL?db4HT)1~LO-Lkt<)mi&} zmsMGN$Ep6?21}-VH+FCNZk(yA+1}aBvzsrt3Nq5m;xc=^f0@mg%6DfCdvb<78D+wF 
zZ~*QGS55sp&i^w^7pr0=&M(4D-wl@V`X{%pp%o#jfnqeor@;1nq#1r&bI2N4{ zsSR78CjE^~m$k$OzU?`34eoC5AQ#f@4{3dge3UjawX)*NjJpP-gF)Rb#%c-74GyRSm<`N*cCZXtHGL=yrF}zJRgHf| zAD@=4X097hO@oXyF|h|ZY~xt?*ce`V)39o96FCW;d#CD@#5>R|vCnAnp3CbFClW_P zqjCB*v^)a~d$a!6S)LS2Vd&ZPKlW`=lWyo8 z_SV&-f;KR+mXU{9u2s_Uk+E1pmgqGW|1x}YI6%G+KOH`)iJ;SW(tS@HCgg8$CDdqE zqLkg4WWF`@HixG&It9zGh8{t3H6<^BKw delta 1642 zcmZuxZERCj7`~^sx82rz`^CPxueNJ9+I4F|K1ArI%gTTim6*u@2esS1V}-4q=N1f0 z2gDHF4@Bl`j0OlrP%v%=azn(Jm<1CgAkdP9%_bUtD2j$;2|xUdc+Opim~e8>bD#6R z@AIB>-*e9W@&@|;5;C3^gd&DC;o@g~_LebYx#j>G^PC%Ndi&hBuUDW$%FC5zEy}Zs;JG-6IR&FLxbXq@L_wOBK52dUPF957ZZAt(G+dmbHv z3D-2d?{=V0<#TrtYdA*)cr}HiaDCz0qN{}NC0(94>F{U@vwV1p4`=wW5@?t}7NNDK zqmjAN=n7les}>ZtD))TFnuQ+_L5`J^A`a0^FVLZOMhVQt`HI37@c12$G@-Jd zfMc#4 z60PLD0HU1>kO79CD1}=Qa|PC=_>^&2Aaed-TS~}t;ov&BzSQR`&L`pXN%%)6%qOf` zW_Ok3Wmn$k?SE51rK^4mO0R&;_VlzXY9BHm~Op zS5LfNrT5v>h3GmMx`OSKRGFsWp^4T|M(-moA&}!@ay%N(vC;URxuO$8EZ&D3XoNcH zDatOun_WTuG^Kwqkm(8|4^(Xma8HqxUCM?{=h=oIh|{Z6v*JcNQxgRrmG}wEab%YV zhr9@{VZJeYU^@s>q+8&dW5{~pI5<4*s@v?Xm z_0O;6udQybmDt}^+}ql;r~&$WYTc#8fXgV9Q&>j?xdM4t&(J_3DdP)Ne+oYBc@ljL zcX|$Roh0a(lIZP6=qH#F8_cRoEx`zlHG@@pXX7}fS1G9fCVI>;-3v@hj%g*-1!B4y zOjD{?M!gy8FYV8Q;9x@BH6XR%DDjeqkzdm&e@h0Wt!0n^y1q-ASsfH?_m(<>X!$Jj Le;7s(-Y)+OD Date: Sun, 11 Aug 2024 05:50:41 +1000 Subject: [PATCH 027/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 26558 -> 26573 bytes train.py | 13 ++----------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 49334aa7f7aaf7591feba885ef0dca25b3f7e40b..6a61062b4e172217a4443781062b532dd88ed6ba 100644 GIT binary patch delta 209 zcmdmYp7HE?M&9MTyj%=G;BaJnn!rY0MjOsj2_W~;jHwEf1#BjBmGS_2ALjSAC~RJ8 z6UySOeSt3hna?w2nE(~N`_Wj5bVFB9p~!CNovBZ{A`P$};( Date: Sun, 11 Aug 2024 05:51:27 +1000 Subject: [PATCH 028/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- helper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/helper.py b/helper.py index 98eebc7..65f5ebe 100644 --- a/helper.py +++ b/helper.py @@ -271,10 +271,10 @@ def count_model_params(model, trainable_only=False, verbose=False): # Print custom modules for i, (layer_type, count) in enumerate(sorted(custom_counts.items(), key=lambda x: x[1], reverse=True), 1): percentage = count / total_params * 100 - print(f"🍄 Custom {i}. {layer_type[7:]:<23} {count:<15,d} {percentage:.2f}%") + print(f"Custom {i}. 
{layer_type[7:]:<23} {count:<15,d} {percentage:.2f}%") print("-" * 55) - print(f"{'Custom Modules Total':<30} {custom_total:<15,d} {custom_total/total_params*100:.2f}%") + print(f"{'🍄 Custom Modules Total':<30} {custom_total:<15,d} {custom_total/total_params*100:.2f}%") print("-" * 55) print(f"{'Total':<30} {total_params:<15,d} 100.00%") print(f"{'Trainable':<30} {trainable_params:<15,d} {trainable_params/total_params*100:.2f}%") From ef56439c2e6dd9c082065a08facde183e3d84202 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 05:53:01 +1000 Subject: [PATCH 029/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 26573 -> 26573 bytes helper.py | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 6a61062b4e172217a4443781062b532dd88ed6ba..f410e4963786990844c8bd2cb6fe60ba48a70ec8 100644 GIT binary patch delta 111 zcmX?mp7HE?M$YBDyj%=GuyZ5lLL1>KcIVRKlKfnSQc0lLhxxrN3Y+)Z9A=q3!J*PZ u=K^021Z}W^Gf(hb2u-}`k$52~^`b{ATo}l>$dP-6BliMF?q)-$IjR5NbWF&GfzmtnLx%xp6Dw)(HFqzDF<&y zMkjv{e+U0v4&Dob>KFLZA!tPooO!_JLO{$#=a>tz2^XCc;KCR9xY93jWL)9MxWJLI J+0bc@DgYZ0O_~4z diff --git a/helper.py b/helper.py index 65f5ebe..912bf85 100644 --- a/helper.py +++ b/helper.py @@ -262,7 +262,7 @@ def count_model_params(model, trainable_only=False, verbose=False): # Print native modules for i, (layer_type, count) in enumerate(sorted(native_counts.items(), key=lambda x: x[1], reverse=True), 1): percentage = count / total_params * 100 - print(f"Native {i}. {layer_type[7:]:<23} {count:<15,d} {percentage:.2f}%") + print(f" {i}. {layer_type[7:]:<23} {count:<15,d} {percentage:.2f}%") print("-" * 55) print(f"{'☕ Native Modules Total':<30} {native_total:<15,d} {native_total/total_params*100:.2f}%") @@ -271,7 +271,7 @@ def count_model_params(model, trainable_only=False, verbose=False): # Print custom modules for i, (layer_type, count) in enumerate(sorted(custom_counts.items(), key=lambda x: x[1], reverse=True), 1): percentage = count / total_params * 100 - print(f"Custom {i}. {layer_type[7:]:<23} {count:<15,d} {percentage:.2f}%") + print(f" {i}. 
{layer_type[7:]:<23} {count:<15,d} {percentage:.2f}%") print("-" * 55) print(f"{'🍄 Custom Modules Total':<30} {custom_total:<15,d} {custom_total/total_params*100:.2f}%") From ae1c736bd326cb73db37bbb5c29c38325810c9af Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 05:54:08 +1000 Subject: [PATCH 030/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 26573 -> 26566 bytes helper.py | 11 ++++++----- train.py | 4 +--- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index f410e4963786990844c8bd2cb6fe60ba48a70ec8..84ea5b815074b341a0fb12c45faa5cdb0bc4604b 100644 GIT binary patch delta 207 zcmX?mp7GduM&9MTyj%=GAbxav8s|pd+tw^B3JMC7-&^ltW>!$xyvgPsGkXZ*2L>k2 zkjc{>6di+R2wo5h0-_mND+E^vtq_`%2jK(B4W@AB0Vy~W$hgQ8e1#|Y0vJvH>Cny| z3^p+sY+~RH#tTA$Kr|z21>*{)6-;yDAbcRXK^4wCzy)Ul85eniuJ8n10He(>9iONI E0Q4wLX8-^I delta 214 zcmX?hp7HE?M&9MTyj%=Gu=B|FG=Yu0x2@UP{Sr$u%Tg!5vEIYZ?p#`2lApVIwaq Date: Sun, 11 Aug 2024 06:00:32 +1000 Subject: [PATCH 031/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 26566 -> 26566 bytes train.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 84ea5b815074b341a0fb12c45faa5cdb0bc4604b..12101932762edfd4730c79e3858972ea7be82221 100644 GIT binary patch delta 623 zcmX?hp7GduM&9MTyj%=G5P5Wa8rMeN8Ag-On`kNXr?AatNnx+yOA&z3HOy7q3=AoP zRjfc#2q-2zS;o|G@=s%l$$O2|#6?;dQbgB?En{Y2SPjGwkRm?W&qUK1tfy9}L=q+r zG(n<9U^YXFWDR$UREl&f6GAFQ25g@o)LtQ|UAs)|m1UuB;)lxd<8+gqxEzV@ktf+b zih>GI_b`Lqqd1pkvYe@`vJzC31?ox`OjmARW!l3!ncGHIq>9zX&{QYoWePB8s+jeR z(ta_jPS&u|nVe+p&&WM_leK|p6^jB8R59r(ykrNe@~dLDF*2UK#zqdPPFMOlP^btP zF$_g7fJ8APkWeTR0P$Z=F0j!Q76ytGy#n!tfJ7BD(5TI)Z0<2J%1&Nszh?3$hxW;x z4waKXIXW{xRnwpB%@7nGS1puo_v^D!iAC1>;nS^=`%Sfy#xT5xS1mW delta 604 zcmX?hp7GduM&9MTyj%=GAbxav8s|pd8Ag+pjqSwwTNqLV)(9?RW?)zi#1N1oH2I8) zrn7JrD+5EVP>Cc=oPi-lq()#iLyBk(cZyhwcqxSDE&(PENP3W@MlI*xG>UWy)j$8wt58 zRvSZ8os=qOJ)^W=jH;y)!0>%EW2%B*Vo7FM>SPNWYw4#zp(0?|Fcdul62*)_LZL_i z#D6}y&qh;N04P%Q0>l>r5>?C!3JQ}oY~(hdvbo2^C^LDb{hG-u9M(=wbts>F$}w}Y zt79 Date: Sun, 11 Aug 2024 06:01:46 +1000 Subject: [PATCH 032/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model.py b/model.py index c09ab3e..e8d31b1 100644 --- a/model.py +++ b/model.py @@ -11,7 +11,7 @@ from stylegan import EqualConv2d,EqualLinear from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights from resblock import UpConvResBlock,DownConvResBlock,FeatResBlock,StyledConv,ResBlock - +import math # from common import DownConvResBlock,UpConvResBlock import colored_traceback.auto # makes terminal show color coded output when crash From 1e7e91b75abe0d8cb9e3149bca2b0de2d1d97dd7 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 06:21:06 +1000 Subject: [PATCH 033/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model.py b/model.py index e8d31b1..1703b23 100644 --- a/model.py +++ b/model.py @@ -157,7 +157,7 @@ def forward(self, x): class 
LatentTokenEncoder(nn.Module): - def __init__(self, initial_channels=64, output_channels=[128, 256, 512, 512], dm=32): + def __init__(self, initial_channels=64, output_channels=[128, 256, 512, 512,512, 512], dm=32): super(LatentTokenEncoder, self).__init__() self.conv1 = nn.Conv2d(3, initial_channels, kernel_size=3, stride=1, padding=1) @@ -402,7 +402,7 @@ def __init__(self,use_resnet_feature=False,use_mlgffn=False, latent_dim=32, base self.encoder_dims = [64, 128, 256, 512] self.latent_token_encoder = LatentTokenEncoder( initial_channels=64, - output_channels=self.encoder_dims, + output_channels=[128, 256, 512, 512,512, 512], dm=32 ) self.motion_dims = [128, 256, 512, 512] From ecc0b548d02b7b040acd9c42d77d111446fc2c9e Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 06:33:18 +1000 Subject: [PATCH 034/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- helper.py | 42 +++++++++++++++++----------------- model.py | 67 +++++++++++++++++++++++++++++++++++++++++++++++-------- train.py | 5 +++-- 3 files changed, 82 insertions(+), 32 deletions(-) diff --git a/helper.py b/helper.py index 95e050c..70ae027 100644 --- a/helper.py +++ b/helper.py @@ -146,26 +146,26 @@ def log_grad_flow(named_parameters, global_step): plt.close() # Check if the table structure has changed - new_columns = ["step"] + layers - if current_table_columns != new_columns: - # If the structure has changed, delete the existing table and create a new one - if wandb.run is not None: - wandb.run.config.update({"gradient_flow_columns": new_columns}, allow_val_change=True) + # new_columns = ["step"] + layers + # if current_table_columns != new_columns: + # # If the structure has changed, delete the existing table and create a new one + # if wandb.run is not None: + # wandb.run.config.update({"gradient_flow_columns": new_columns}, allow_val_change=True) - # Delete the existing table artifact - api = wandb.Api() - try: - artifact = api.artifact(f"{wandb.run.entity}/{wandb.run.project}/gradient_flow_table:latest") - artifact.delete(delete_aliases=True) - print("Deleted existing gradient_flow_table artifact.") - except wandb.errors.CommError: - print("No existing gradient_flow_table artifact found.") + # # Delete the existing table artifact + # api = wandb.Api() + # try: + # artifact = api.artifact(f"{wandb.run.entity}/{wandb.run.project}/gradient_flow_table:latest") + # artifact.delete(delete_aliases=True) + # print("Deleted existing gradient_flow_table artifact.") + # except wandb.errors.CommError: + # print("No existing gradient_flow_table artifact found.") - current_table_columns = new_columns + # current_table_columns = new_columns - # Create or update the wandb.Table - data = [[global_step] + normalized_grads] - table = wandb.Table(data=data, columns=current_table_columns) + # # Create or update the wandb.Table + # data = [[global_step] + normalized_grads] + # table = wandb.Table(data=data, columns=current_table_columns) # Calculate statistics stats = { @@ -188,9 +188,9 @@ def log_grad_flow(named_parameters, global_step): } # Log the table as an artifact - artifact = wandb.Artifact('gradient_flow_table', type='table') - artifact.add(table, 'gradient_flow_data') - wandb.log_artifact(artifact) + # artifact = wandb.Artifact('gradient_flow_table', type='table') + # artifact.add(table, 'gradient_flow_data') + # wandb.log_artifact(artifact) # Log other metrics wandb.log(log_dict) @@ -266,7 +266,7 @@ def count_model_params(model, trainable_only=False, 
verbose=False): # Print native modules for i, (layer_type, count) in enumerate(sorted(native_counts.items(), key=lambda x: x[1], reverse=True), 1): percentage = count / total_params * 100 - print(f" {i}. {layer_type[7:]:<23} {count:<15,d} {percentage:.2f}%") + print(f" {i}. ⅀ {layer_type[7:]:<23} {count:<15,d} {percentage:.2f}%") print("-" * 55) print(f"{'🍄 Custom Modules Total':<30} {custom_total:<15,d} {custom_total/total_params*100:.2f}%") diff --git a/model.py b/model.py index 1703b23..a17d94b 100644 --- a/model.py +++ b/model.py @@ -12,7 +12,7 @@ from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights from resblock import UpConvResBlock,DownConvResBlock,FeatResBlock,StyledConv,ResBlock import math - +import random # from common import DownConvResBlock,UpConvResBlock import colored_traceback.auto # makes terminal show color coded output when crash @@ -617,15 +617,64 @@ def forward(self, x): debug_print(f"PatchDiscriminator final output shapes: {output1.shape}, {output2.shape}") return [output1, output2] -# Helper function to initialize weights -def init_weights(m): - classname = m.__class__.__name__ - if classname.find('Conv') != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find('BatchNorm') != -1: - nn.init.normal_(m.weight.data, 1.0, 0.02) - nn.init.constant_(m.bias.data, 0) +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_instance_norm=True): + super(ConvBlock, self).__init__() + layers = [ + spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)), + nn.LeakyReLU(0.2, inplace=True) + ] + if use_instance_norm: + layers.insert(1, nn.InstanceNorm2d(out_channels)) + self.block = nn.Sequential(*layers) + + def forward(self, x): + return self.block(x) + +class PatchDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3): + super(PatchDiscriminator, self).__init__() + sequence = [ConvBlock(input_nc, ndf, use_instance_norm=False)] + + nf_mult = 1 + for n in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ConvBlock(ndf * nf_mult_prev, ndf * nf_mult)] + + sequence += [ + ConvBlock(ndf * nf_mult, ndf * nf_mult, stride=1), + spectral_norm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1)) + ] + + self.model = nn.Sequential(*sequence) + + def forward(self, x): + return self.model(x) + +class MultiScalePatchDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, num_D=3): + super(MultiScalePatchDiscriminator, self).__init__() + self.num_D = num_D + self.n_layers = n_layers + + for i in range(num_D): + subnetD = PatchDiscriminator(input_nc, ndf, n_layers) + setattr(self, f'scale_{i}', subnetD) + + def singleD_forward(self, model, input): + return model(input) + + def forward(self, input): + result = [] + input_downsampled = input + for i in range(self.num_D): + model = getattr(self, f'scale_{i}') + result.append(self.singleD_forward(model, input_downsampled)) + if i != self.num_D - 1: + input_downsampled = F.avg_pool2d(input_downsampled, kernel_size=3, stride=2, padding=1, count_include_pad=False) + return result \ No newline at end of file diff --git a/train.py b/train.py index ce6238e..13057b8 100644 --- a/train.py +++ b/train.py @@ -9,7 +9,7 @@ import yaml import os import torch.nn.functional as F -from model import IMFModel, debug_print,PatchDiscriminator +from model import IMFModel, debug_print,PatchDiscriminator,MultiScalePatchDiscriminator from VideoDataset import 
VideoDataset from EMODataset import EMODataset,gpu_padded_collate from torchvision.utils import save_image @@ -357,7 +357,8 @@ def main(): ) add_gradient_hooks(model) - discriminator = PatchDiscriminator(ndf=config.discriminator.ndf) + # discriminator = PatchDiscriminator(ndf=config.discriminator.ndf) Original + discriminator = MultiScalePatchDiscriminator() add_gradient_hooks(discriminator) transform = transforms.Compose([ From b54f33e1ffd4ce9b92cbdf8ec79477c8ca3a77c6 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 06:34:06 +1000 Subject: [PATCH 035/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 26566 -> 25146 bytes train.py | 3 ++- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 12101932762edfd4730c79e3858972ea7be82221..085ea9b6d76709e3bfbfcbb79554d9ab9adc9a56 100644 GIT binary patch delta 1108 zcmY*YZ)jUp6o2O>d9P{mMw>dZZMr2{LYmZNNZUnSs#YS>wfkTvOB`EEllMAP)1=<_ z@}VWgKTND&{BjVBjo={LIx=+I)ro%D98)RC+WH4Z^n>_Gn2JNk#DaKFLWdsi`@P>i zzjN+A_q_YQScgxqgLYQaS|wpiH$I-;e^DFnR$3gA^tziH-J@W0FS$AQR%GQ3w>2n9 zxaogDw&FPN4{34AYTAsKgdNs@T|L6pskT=fi%x3@Z~J>jT6IskB?o2aX|VQgr%KYy z?rmOfJp+F;%G}?(1~7pCL{CB&UhTUdlieP^)u+hecSJ
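For context on the MultiScalePatchDiscriminator added above: its forward pass returns one patch map per scale (num_D of them), so GAN losses are summed over the returned list. The snippet below is only an illustrative usage sketch with a hinge-style loss; it assumes the repository's model.py (with the new class) is importable and is not the loss code used in train.py, which is not shown here.

import torch
import torch.nn.functional as F
from model import MultiScalePatchDiscriminator  # class added in the patch above

disc = MultiScalePatchDiscriminator(input_nc=3, ndf=64, n_layers=3, num_D=3)
real = torch.randn(2, 3, 256, 256)
fake = torch.randn(2, 3, 256, 256)

real_outs = disc(real)   # list of 3 patch maps, one per input scale
fake_outs = disc(fake)

# Hinge-style objectives aggregated over scales (illustrative only)
d_loss = sum(F.relu(1 - r).mean() + F.relu(1 + f).mean()
             for r, f in zip(real_outs, fake_outs))
g_loss = sum((-f).mean() for f in fake_outs)
print(len(real_outs), d_loss.item(), g_loss.item())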

gNPs;;l!Ygf)yExYb%xgeT_v>9v>9n@DxvAl*fq0Sh)! z@q!$~wZt*y9C_cyz9Ap1;MmZsa0SiFt-HP)O9yimGFuO7We1=)0`zICTI8iPMId|QSOF`i6D?$W;0Z+2OLX%ir- zYSdeI7K&_kDxX_0*p|erNv_CS5N~|#H#pKDxqE>`m2ODQ=|4>qFMO(p-y}v-2wafv*}N4 z(VIe#C~PN%1%%G_5cJVBC|KQ*%}>qFMSv}1{8-J`Sgv0PrYu6XPoh3wEWR)~U0?$! zAOG&bX0wG}*O+l4XIciEqzE=35+gR2rq#C#vrPmc&U+`Gwu_rImk9*jRDedXgJ6W# zX)#NgxqQ(wS(S1tcsUb?Q}|ovBpl>sHV^PU{-{UW8X1i~EK;0;T?P)jkBKKFg! z@8|pDec#L7rx(FDSAqFGvpIvpW6Rf{^k#f$?r<<~=qc(Yt1LR&fY!R#D$CBpnx_P7 zABbv)F)?P8s)@twh#Gfji3G1n6?atl)$d5j&t1H?zTcT*pl!wn%4|_%00{QeStm3U zbqYVT15i?BO2$zOP?0mo5X&%9(TwpzblZ{j^nIf9DX&A6kbNIc(OsOJBlHP?GvjA= z#7t4scri&ag3Ed!-H6LFr-&3K-VoKZsnL^F?rv=Th2Kv7HjPNG1_Kddl^I;L=L~lt68u+{N+${|p#` zqnM{=_d&l95JQr$HyGaO>j?*k211ceH&8i$kZ7Lu2Yo%;{h>bbUr;sds7mKc{_dc- zJm{Cih=hGI1u-Z}qR$@;_#mRM%T@bamj69fT{0Igl9qN4imEN0Blso1dlBj? zK7qJxIh#~1915vSPdL;Y=u?@YL2MUQZNp$dHKeLl8Qivjv|H5=!f?OXBdN^n_OSuN zFpQ|i$HM~y%|uppoupGp{6u7}thz3N5r)dj8!JEXWZdQ&ZJW|N z$4%Gt9!2l@)x_;vrkMCC*1B(1f_47Dx~{Pcz!Nk2qadlon2szI|fD)j)5`5 zB$xZ9QH5g2$M4Jqd)pOfMVzfrOqHWeQ#p8hZ3jB9dmdBRhB(`hU>io8 zrw!;>=?e`U#bxi?qZBMYSEg7Rz-6?joCg!=Xn8hsfhaD>pO&9ukkz}|$_vaPTF^(e0*Bv4 zM`)yY@@v3 z2$ob`PtZRw=mu~C9enZ>pA5GK8zSV_&d;`2)o%}ncldfkupJ4Tu2+0XyvH%9W+?XI zZ>9)+gaO;Jqgx9%6ML#KadQb+C%>|}g$B#ewXMY^rI-XCBv3|R6^Ubr40R6%!%_sE zC%(r~ZdWaMAFb7_kL{T5Af@g?KJ}NMIFEdHlECMwI XS_zQK;5q_5Xx+AaFerz%T><|AWzT=q diff --git a/train.py b/train.py index 13057b8..4abf146 100644 --- a/train.py +++ b/train.py @@ -358,7 +358,8 @@ def main(): add_gradient_hooks(model) # discriminator = PatchDiscriminator(ndf=config.discriminator.ndf) Original - discriminator = MultiScalePatchDiscriminator() + discriminator = MultiScalePatchDiscriminator(input_nc=3, ndf=64, n_layers=3, num_D=3) + add_gradient_hooks(discriminator) transform = transforms.Compose([ From c30545adfe2ea0c793cfc8c7a56252c5c988c50a Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 06:38:01 +1000 Subject: [PATCH 036/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index 2be640e..746e4e0 100644 --- a/config.yaml +++ b/config.yaml @@ -15,7 +15,7 @@ training: num_epochs: 1000 save_steps: 250 learning_rate_g: 1.0e-4 # Reduced learning rate for generator - initial_learning_rate_d: 3e-6 # Set a lower initial learning rate for discriminator + initial_learning_rate_d: 1.0e-4 # Set a lower initial learning rate for discriminator # learning_rate_g: 5.0e-4 # Increased learning rate for generator # learning_rate_d: 5.0e-4 # Increased learning rate for discriminator ema_decay: 0.999 From 146c1e4b06c389a12b2e6146175dcd72e084e68d Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 06:47:59 +1000 Subject: [PATCH 037/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- helper.py | 6 +-- model.py | 109 +++++++++++++++++++++++++++--------------------------- train.py | 6 +-- 3 files changed, 61 insertions(+), 60 deletions(-) diff --git a/helper.py b/helper.py index 70ae027..df5bfc0 100644 --- a/helper.py +++ b/helper.py @@ -250,8 +250,7 @@ def count_model_params(model, trainable_only=False, verbose=False): param_counts[layer_type] += param.numel() if verbose: - print(f"{'Layer Type':<30} {'Parameter Count':<15} {'% of Total':<10}") - print("-" * 55) + native_counts = {k: v for k, v in param_counts.items() if k.startswith("Native_")} custom_counts = {k: v for k, v in 
param_counts.items() if k.startswith("Custom_")} @@ -276,7 +275,8 @@ def count_model_params(model, trainable_only=False, verbose=False): percentage = count / total_params * 100 print(f" {i}. {layer_type[7:]:<23} {count:<15,d} {percentage:.2f}%") - + print(f"{'Layer Type':<30} {'Parameter Count':<15} {'% of Total':<10}") + print("-" * 55) print(f"{'Total':<30} {total_params:<15,d} 100.00%") print(f"{'Trainable':<30} {trainable_params:<15,d} {trainable_params/total_params*100:.2f}%") diff --git a/model.py b/model.py index a17d94b..b0cfbdb 100644 --- a/model.py +++ b/model.py @@ -300,6 +300,7 @@ def __init__(self): self.final_conv = nn.Sequential( nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(3), nn.Sigmoid() ) @@ -560,62 +561,62 @@ def forward(self, input): Output: The forward method returns a list containing the outputs from both scales. Weight Initialization: A helper function init_weights is provided to initialize the weights of the network, which can be applied using the apply method. ''' -class PatchDiscriminator(nn.Module): - def __init__(self, input_nc=3, ndf=64): - super(PatchDiscriminator, self).__init__() +# class PatchDiscriminator(nn.Module): +# def __init__(self, input_nc=3, ndf=64): +# super(PatchDiscriminator, self).__init__() - self.scale1 = nn.Sequential( - spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1)), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm(nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)), - nn.InstanceNorm2d(ndf * 2), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)), - nn.InstanceNorm2d(ndf * 4), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)), - nn.InstanceNorm2d(ndf * 8), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm(nn.Conv2d(ndf * 8, 1, kernel_size=1, stride=1, padding=0)) - ) - - self.scale2 = nn.Sequential( - spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1)), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm(nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)), - nn.InstanceNorm2d(ndf * 2), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)), - nn.InstanceNorm2d(ndf * 4), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)), - nn.InstanceNorm2d(ndf * 8), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm(nn.Conv2d(ndf * 8, 1, kernel_size=1, stride=1, padding=0)) - ) - - - def forward(self, x): - debug_print(f"PatchDiscriminator input shape: {x.shape}") - - # Scale 1 - output1 = x - for i, layer in enumerate(self.scale1): - output1 = layer(output1) - debug_print(f"Scale 1 - Layer {i} output shape: {output1.shape}") - - # Scale 2 - x_downsampled = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) - debug_print(f"Scale 2 - Downsampled input shape: {x_downsampled.shape}") - - output2 = x_downsampled - for i, layer in enumerate(self.scale2): - output2 = layer(output2) - debug_print(f"Scale 2 - Layer {i} output shape: {output2.shape}") - - debug_print(f"PatchDiscriminator final output shapes: {output1.shape}, {output2.shape}") - return [output1, output2] +# self.scale1 = nn.Sequential( +# spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1)), +# nn.LeakyReLU(0.2, inplace=True), +# spectral_norm(nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)), 
+# nn.InstanceNorm2d(ndf * 2), +# nn.LeakyReLU(0.2, inplace=True), +# spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)), +# nn.InstanceNorm2d(ndf * 4), +# nn.LeakyReLU(0.2, inplace=True), +# spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)), +# nn.InstanceNorm2d(ndf * 8), +# nn.LeakyReLU(0.2, inplace=True), +# spectral_norm(nn.Conv2d(ndf * 8, 1, kernel_size=1, stride=1, padding=0)) +# ) + +# self.scale2 = nn.Sequential( +# spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1)), +# nn.LeakyReLU(0.2, inplace=True), +# spectral_norm(nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)), +# nn.InstanceNorm2d(ndf * 2), +# nn.LeakyReLU(0.2, inplace=True), +# spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)), +# nn.InstanceNorm2d(ndf * 4), +# nn.LeakyReLU(0.2, inplace=True), +# spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)), +# nn.InstanceNorm2d(ndf * 8), +# nn.LeakyReLU(0.2, inplace=True), +# spectral_norm(nn.Conv2d(ndf * 8, 1, kernel_size=1, stride=1, padding=0)) +# ) + + +# def forward(self, x): +# debug_print(f"PatchDiscriminator input shape: {x.shape}") + +# # Scale 1 +# output1 = x +# for i, layer in enumerate(self.scale1): +# output1 = layer(output1) +# debug_print(f"Scale 1 - Layer {i} output shape: {output1.shape}") + +# # Scale 2 +# x_downsampled = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) +# debug_print(f"Scale 2 - Downsampled input shape: {x_downsampled.shape}") + +# output2 = x_downsampled +# for i, layer in enumerate(self.scale2): +# output2 = layer(output2) +# debug_print(f"Scale 2 - Layer {i} output shape: {output2.shape}") + +# debug_print(f"PatchDiscriminator final output shapes: {output1.shape}, {output2.shape}") +# return [output1, output2] class ConvBlock(nn.Module): diff --git a/train.py b/train.py index 4abf146..b248477 100644 --- a/train.py +++ b/train.py @@ -9,7 +9,7 @@ import yaml import os import torch.nn.functional as F -from model import IMFModel, debug_print,PatchDiscriminator,MultiScalePatchDiscriminator +from model import IMFModel, debug_print,MultiScalePatchDiscriminator from VideoDataset import VideoDataset from EMODataset import EMODataset,gpu_padded_collate from torchvision.utils import save_image @@ -59,8 +59,8 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=config.training.initial_learning_rate_d, betas=(config.optimizer.beta1, config.optimizer.beta2)) # dynamic learning rate - scheduler_g = ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5, verbose=True) - scheduler_d = ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5, verbose=True) + scheduler_g = ReduceLROnPlateau(optimizer_g, mode='min', factor=0.25, patience=10, verbose=True) + scheduler_d = ReduceLROnPlateau(optimizer_d, mode='min', factor=0.25, patience=10, verbose=True) # Make EMA conditional based on config From 922f4ca0de18a5f4008927b4e48a576af820170d Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 06:58:26 +1000 Subject: [PATCH 038/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 25146 -> 25147 bytes model.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 
085ea9b6d76709e3bfbfcbb79554d9ab9adc9a56..f5a507a6fc4e5adfcca9109dfa5ddcb28f9ca98d 100644 GIT binary patch delta 926 zcmdmWgmL!~M&9MTyj%=GAar4Sn)OEB`Pxh=tdn%C8QCW3Sn~6hNP^@bpoVAi16GyE z1x$PbwR|WNyfBISTK*DQn0f|=6!v9|3=FG*7y=j>P-Hn$IA^n@aMkdoa6{-C<|=NW zE}kk@Aju091x)4jjt6# zaf)z_z-)#Tks9t4(G;;(CWNn3#KBGzggQ|O>MTxlXZ2X1n1}8#B#*lR{U!nRIX~1` ze$stDd5@mIxFm_*lmdHGnrhyZf%qduHbo8;Oeyjy3bR?}vVsDq9%r=U2^(P0C_;mV z85}f9b6LoMq{1H2aLHV_nOMFGus#&++tLiyvnqnk!!M**==PZ zp!bR{!j&^CC@54h=_yok`6O1R7Ab^O7NkzDHCCUjW8NqT$_=;J$`W%*Q;Umc138=b rnQJjJx@~@Gp~}eifsH{jnsKw2l{XXP)5-i+7R+}!csI|oS)d94nAxn+ delta 833 zcmdmegmKpqM&9MTyj%=G&~a{in%PF)`PvF4k{~_=q_C#2&1OkqPvJ=61k+q8+>;CR zMJLa2l|18i*l)ks*b5@;Ys6Mm`{E$tt4POm&6*mJzieMEhkQ4%n3D0GNxC*46 zk!Nz2p5|nGT{Ur$7KRkjHDb${fi4DO2uKm1d_hk$8?3ul2*p;38iCmiDUvnZDN-rY ztxSj@Ns$5DEC_Xh5Y$#q^swl$Krs*57qU=a@I%euC&Cy0;&QlsP>;xh#6CH(>m7(%1P+zfN`f8IwCbLKtmrr74YLP-nWkG5cvyHLAEk=dO ztmf)WRpyfwO${d-n6gd2z$G@h-gp@k1LI_MlZ9+YfzefTeDXOHTWv;Qaw|Fs;tK(Z zQVAgOXvS0pzr>QvvQ!1%{FKt1)MAB@{F20+$>pZ~jNFsi&2EbegLGYl>8fH@P*A91 z(o>kMW8Nq@8zd|XByO>lCFYc-7H{5XuEogcy7{GrDkJ0F$q6q>bUX!%K diff --git a/model.py b/model.py index b0cfbdb..6fc1625 100644 --- a/model.py +++ b/model.py @@ -300,7 +300,7 @@ def __init__(self): self.final_conv = nn.Sequential( nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1), - nn.BatchNorm2d(3), + # nn.BatchNorm2d(3), nn.Sigmoid() ) From d74984c9ee1139c574c51eaea5af8487d50937c8 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 07:23:08 +1000 Subject: [PATCH 039/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- framedecoder.py | 99 +++++++++++++++++++++++++++++++++++++++++++++++++ model.py | 4 +- train.py | 19 ++++++++-- 3 files changed, 117 insertions(+), 5 deletions(-) create mode 100644 framedecoder.py diff --git a/framedecoder.py b/framedecoder.py new file mode 100644 index 0000000..679ded6 --- /dev/null +++ b/framedecoder.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class FourierFeatures(nn.Module): + def __init__(self, input_dim, mapping_size=256, scale=10): + super().__init__() + self.input_dim = input_dim + self.mapping_size = mapping_size + self.B = nn.Parameter(torch.randn((input_dim, mapping_size)) * scale, requires_grad=False) + + def forward(self, x): + x = x.matmul(self.B) + return torch.sin(x) + +class ModulatedFC(nn.Module): + def __init__(self, in_features, out_features, style_dim): + super().__init__() + self.fc = nn.Linear(in_features, out_features) + self.modulation = nn.Linear(style_dim, in_features) + + def forward(self, x, style): + style = self.modulation(style).unsqueeze(1) + x = self.fc(x * style) + return x + +class CIPSFrameDecoder(nn.Module): + def __init__(self, feature_dims, ngf=64, max_resolution=256, style_dim=512, num_layers=8): + super(CIPSFrameDecoder, self).__init__() + + self.feature_dims = feature_dims + self.ngf = ngf + self.max_resolution = max_resolution + self.style_dim = style_dim + self.num_layers = num_layers + + self.feature_projection = nn.ModuleList([ + nn.Linear(dim, style_dim) for dim in feature_dims + ]) + + self.fourier_features = FourierFeatures(2, 256) + self.coord_embeddings = nn.Parameter(torch.randn(1, 512, max_resolution, max_resolution)) + + self.layers = nn.ModuleList() + self.to_rgb = nn.ModuleList() + + current_dim = 512 + 256 # Fourier features + coordinate embeddings + + for i in range(num_layers): + self.layers.append(ModulatedFC(current_dim, ngf * 8, style_dim)) + if i % 2 == 0 
or i == num_layers - 1: + self.to_rgb.append(ModulatedFC(ngf * 8, 3, style_dim)) + current_dim = ngf * 8 + + def get_coord_grid(self, batch_size, resolution): + x = torch.linspace(-1, 1, resolution) + y = torch.linspace(-1, 1, resolution) + x, y = torch.meshgrid(x, y, indexing='ij') + coords = torch.stack((x, y), dim=-1).unsqueeze(0).repeat(batch_size, 1, 1, 1) + return coords.to(next(self.parameters()).device) + + def forward(self, features): + batch_size = features[0].shape[0] + target_resolution = self.max_resolution + + # Project input features to style vectors + styles = [proj(feat.view(batch_size, -1)) for proj, feat in zip(self.feature_projection, features)] + w = torch.cat(styles, dim=-1) + + # Generate coordinate grid + coords = self.get_coord_grid(batch_size, target_resolution) + coords_flat = coords.view(batch_size, -1, 2) + + # Get Fourier features and coordinate embeddings + fourier_features = self.fourier_features(coords_flat) + coord_embeddings = F.grid_sample( + self.coord_embeddings.expand(batch_size, -1, -1, -1), + coords, + mode='bilinear', + align_corners=True + ).permute(0, 2, 3, 1).reshape(batch_size, -1, 512) + + # Concatenate Fourier features and coordinate embeddings + features = torch.cat([fourier_features, coord_embeddings], dim=-1) + + rgb = 0 + for i, (layer, to_rgb) in enumerate(zip(self.layers, self.to_rgb)): + features = layer(features, w) + features = F.leaky_relu(features, 0.2) + + if i % 2 == 0 or i == self.num_layers - 1: + rgb = rgb + to_rgb(features, w) + + output = torch.sigmoid(rgb).view(batch_size, target_resolution, target_resolution, 3).permute(0, 3, 1, 2) + + # Ensure output is in [-1, 1] range + output = (output * 2) - 1 + + return output diff --git a/model.py b/model.py index 6fc1625..56fc609 100644 --- a/model.py +++ b/model.py @@ -15,7 +15,7 @@ import random # from common import DownConvResBlock,UpConvResBlock import colored_traceback.auto # makes terminal show color coded output when crash - +from framedecoder import CIPSFrameDecoder DEBUG = False def debug_print(*args, **kwargs): if DEBUG: @@ -427,7 +427,7 @@ def __init__(self,use_resnet_feature=False,use_mlgffn=False, latent_dim=32, base self.implicit_motion_alignment.append(model) - self.frame_decoder = FrameDecoder() + self.frame_decoder = CIPSFrameDecoder() self.noise_level = noise_level self.style_mix_prob = style_mix_prob diff --git a/train.py b/train.py index b248477..14087ba 100644 --- a/train.py +++ b/train.py @@ -54,9 +54,22 @@ def get_noise_magnitude(epoch, max_epochs, initial_magnitude=0.1, final_magnitud def train(config, model, discriminator, train_dataloader, val_loader, accelerator): - # optimizer_g = torch.optim.Adam(model.parameters(), lr=config.training.learning_rate_g, betas=(config.optimizer.beta1, config.optimizer.beta2)) - optimizer_g = AdamW(model.parameters(), lr=config.training.learning_rate_g, weight_decay=0.01) - optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=config.training.initial_learning_rate_d, betas=(config.optimizer.beta1, config.optimizer.beta2)) + + # Generator optimizer + optimizer_g = AdamW( + model.parameters(), + lr=config.training.learning_rate_g, + betas=(config.optimizer.beta1, config.optimizer.beta2), + weight_decay=config.training.weight_decay + ) + + # Discriminator optimizer + optimizer_d = AdamW( + discriminator.parameters(), + lr=config.training.learning_rate_d, + betas=(config.optimizer.beta1, config.optimizer.beta2), + weight_decay=config.training.weight_decay + ) # dynamic learning rate scheduler_g = 
ReduceLROnPlateau(optimizer_g, mode='min', factor=0.25, patience=10, verbose=True) From 66f63c52229406cb49400248a85648c2f05166d3 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 07:23:13 +1000 Subject: [PATCH 040/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 0 -> 7561 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 __pycache__/framedecoder.cpython-311.pyc diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..766903d4762d9e62cef620724a675ff61daf6a37 GIT binary patch literal 7561 zcmb_hU2GdycAgn>$RRlrIg+hkN|t5WF}<>WoY=8rJ1bj?;=lFAiL>3xNg0YWk}2^= znjsU5OSVQ4G?bwtN}z?gLDFS`tgFCTKm`g^Mf>2F0_od~I)#V<1Q;mtpf?51qG?{* zb7nY04kfw!&`a^~&b?>ux#ymH&iU@u-!(LN2`K-(`Hxc_euDTW%~zg?hp{L6ZbUyiV}chXn5ss2P_z)FTH$Jch4PCfqug6S(N- za4|ZU3%K0R;qvh0TSSQ2#-VRUD-!e{Q{J;<@c0pEAP^HIPfSoeT_C1N-tnPl!okBV zCY%N09yvk7I4;dUp2~5M2y+XP7v+qgw&dm< zS$K;lzos66gvV7WM3^z(;|WmTadSLGY7L?=pNWC);b}3#hddfBW#$A?W5VHBGL{a9 zHBT%#mq~~DSVChHk-528ayl%lbYv^h!{x-X?TN9r^M)8jTR$GKBONJ=13G3uN31`>SBOb&@_f$&!AR_YF*7_Bzbf8hBBS7vGWwznt zD-W(L-(LBJ!Uk0~SRl9j2bLC!++uDiCo`R53-Ey;oUVu~{gn$hNfHbMo5ULqCsKSS zj&(L1p3g+$MvrbXF$nxxL&fY-As!Ehzaqs0&?@1=jOg*)zF1RVSKBYMt#-?O||huj~?WQ3`7rzbT!e#q4|D>8b{IsKR#-U z7pRB6E7qX${}C_t+atN%K8~{Wvd=pX*jb>gEiPC)5aWW_2Cs))B8DNuco=##*Ns?G zh=>|B71g}DsUzuFDjA}Uo#%@s!&8+gss)O zX}L4*acLCDzV5u$bYLY^yuKP#n$DN$AgfK0op!lp3Grpqi$}aJA@qK zarn_ZnWQwI5rnKDzF;(?my4a)N9(I9p+>Nn#!{k}Wt(f4vf}6Hy3%>=s%-oK^3O#lJXv$g(ePwpfi${H zrOa3uqTmGZ1K}%2j@6+M`v!!Oba03Rq+88#07p=~!>$u>V%pG)O{RrV(zr<~Wd?JU zUU(-@gEo#wpxa|hzg@GyEuTK|#!=s|knU3pg;-KtV@w0$mP#Jlr( zH)L*0LJ9NLu9E&qwJGPyQM?=7d(ml!!S=Di=5C(Kf+cxoR^LC8C+E)5yfX_{Fj~ZZ zkZsNOkC0uF%zOS%*%i*bYf03Rp!a4Jz1b|d!J6Uw>=huGzNR5MzS{~=rG~2tXmaFY zT^jm?T51Epf@lFZg?J1|$e~dKLz)BA0Ipaia^2S>*N@%{@tBm3rV?|nW>408sFe@~ zF2qv+Af;EKVzM8_u}%SjkOP^2*jpv#6DR-r^3v7iXz`rF9#z?+I}TgF9$mI^`3+EW z5^FC4iJ`xXFX1EhZliC8T8xZ1jIHP@8e7S5nAs%Fk({2=8WWLwVZhR8CEQNKgWe~$!e zM|@B_;=^$c{edA453P}FFDgARt35AcTh~@k@0aZF-~Rf}I=4Rhq+7lfR!+XBo_tT~ ziKsmho#oq>ww1&usg)E+_ifcaJZ(8t@O{g*EN2z2SLJ#Eka`SKxCF^_>G~$yy}@>W z@vGJI>yGuvKe?Z@D~GPDhpsE^4VAqivo|1{E)OqeisXZAAzQpvI^}rUb^-z_=V3FL!j6$lttBxL&xv<1wZ0 z1~}-(dfxg=u!B8kRh^e7c?uv7@KulOdB@(4z9+BK#%tw5i1be0@pWAkd{%yI=jaly&#scT@)d!Q&^j^K zGCqQbYW}v*qMCW_?SMVpH@E)>iilxGqhd2~_-0~BUbqL&3x}a|(5#1>kk>faG4WVZ znu|mQjY$a7-Dxq#BkGGpXT<>=!X*lGfVed(ozk31;a*zv&Y6dC2nM|Hek>{gkjEUN zc_$<3=w0L7?yZSXI+i3CG?yMtBncMrQZPo zw`aMfo5fSb`O*o6J*l!M;Y7Jd6C5)Cycj8QtG8wUDaC(E^`BaF7RbWz7T2=Ui9N+% ztTxG9zryvaT>mC_W`jGkHm`8!RPNj+H?qNvJn8u6mi)#o`K{Xu_qNKtU2v8C2R~kT zu<*r!BLCaeXDOw7Q0*R6+J{v9ds*?HRsCmW=B(i~)~1DNTVU^V*p?qP;*p*w+o8iK zGiC)YSgNfL8fy>ud8*I|po{vDsXc-LM|Gpu9l`3M891N?cxo?!at__4V7}GkTOJ_c z-dlOGvHp0dZ+ivvvdsN;L1`y>+~pEO!-Mny*evW}mfo=a^{#zYxOg>ka!M z?@1r34qj(v3&4W*7Id_aVGj)eG=HuKMz2nNd6pt_KDgT3CE2#h7Bqw%%G$?=)1;TD zb$Cfu;ALaZ3-ht8#CvOira$KaiR^BPSx_?XS$e~^LKXFY}7DbNy5-zDhV zT^1dD-Qmry9bnp{Ae9tvz`GE-o$tql4?<4yEIevXL~oiCv7GozXpk^tK@6>cCB)I$ z{#}u?ayZ`$)sEsIr-1-`lDS#q@9$^U`1P^nKKP^`E_}LB>R5eKX&q8qhm@vawQ1O5 zG>*+UQSs+U@Z_&ClQBpraD{U_L}^aA!4b5ENIW*3gzzpV5u9E?UJC?IfA$+J|Mt_+ zK}b7yBXB8A}mD`b8$f&!$y~IZw_vT zH8%{D$fO0$4QU96)jR?qWI=?BH1RbYV;tGNaUn9h0AMT)7Ho;W{(<&nF5J#w`_8K8(d#$LLR=X zaHA?W3dye}_;B{qSpb%kYloGVm(-S*3O-;C_H1?zZFCN;bt;_~)y|7D)3((a{QM%o zR+T-x$sXTekFR#Dy}91B{-(l?sO(6YZ7Oq}<&N$h4;=((4diOZ zNEU|`Z!g!*VNu?I|S_+(h1x{6)r9Q#WKD8 
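The plateau schedulers configured above (factor=0.25, patience=10) quarter the learning rate once the monitored metric stops improving for ten consecutive step() calls. A tiny self-contained illustration with a toy parameter, separate from the training loop:

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

params = [torch.nn.Parameter(torch.zeros(1))]
opt = AdamW(params, lr=1e-4, weight_decay=0.01)
sched = ReduceLROnPlateau(opt, mode='min', factor=0.25, patience=10)

for epoch in range(30):
    stagnant_val_loss = 1.0        # metric never improves
    sched.step(stagnant_val_loss)  # LR is cut to 25% after each plateau

print(opt.param_groups[0]['lr'])   # well below the initial 1e-4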
zez`TcoGJdYG_SOV)Yeep%2sRJO0INeHLA3pQd>_It~{-S*L`0Ou3b`kPj7Ucmb*@u zTiT1`#qrYF)mut%5UilI3@$r<06@ReRg@I&h{_$=0@AJ3g z^pNn6fZb$yawYa1wtrvkZ1IqA1-Y9n_pXGHmqXqPN6QtGgr(b?E{iJ*-|2nd@57-h F{l9rnJI??B literal 0 HcmV?d00001 From 8658951285f34fa61de2e7528922b1e828bf721e Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 07:24:18 +1000 Subject: [PATCH 041/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model.py b/model.py index 56fc609..826fe8f 100644 --- a/model.py +++ b/model.py @@ -427,7 +427,7 @@ def __init__(self,use_resnet_feature=False,use_mlgffn=False, latent_dim=32, base self.implicit_motion_alignment.append(model) - self.frame_decoder = CIPSFrameDecoder() + self.frame_decoder = CIPSFrameDecoder(feature_dims=self.motion_dims) self.noise_level = noise_level self.style_mix_prob = style_mix_prob From 0d7bec529e46b36709b5a0b1f1dce13e85c2efba Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 07:25:53 +1000 Subject: [PATCH 042/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/config.yaml b/config.yaml index 746e4e0..78963e9 100644 --- a/config.yaml +++ b/config.yaml @@ -44,7 +44,8 @@ training: scales: [1, 0.5, 0.25, 0.125] enable_xformers_memory_efficient_attention: True - + learning_rate_d: 1.0e-4 + weight_decay: 0.01 # Dataset parameters dataset: # celeb-hq torrent https://github.com/johndpope/MegaPortrait-hack/tree/main/junk @@ -86,4 +87,6 @@ loss: weights: perceptual: [10, 10, 10, 10, 10] equivariance_shift: 10 - equivariance_affine: 10 \ No newline at end of file + equivariance_affine: 10 + + From 467727e3757a5ec3772645a2396281d7937d7f27 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 07:28:57 +1000 Subject: [PATCH 043/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- framedecoder.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/framedecoder.py b/framedecoder.py index 679ded6..12af508 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -60,40 +60,64 @@ def get_coord_grid(self, batch_size, resolution): return coords.to(next(self.parameters()).device) def forward(self, features): + print(f"Input features shapes: {[f.shape for f in features]}") + batch_size = features[0].shape[0] target_resolution = self.max_resolution + print(f"Batch size: {batch_size}, Target resolution: {target_resolution}") # Project input features to style vectors - styles = [proj(feat.view(batch_size, -1)) for proj, feat in zip(self.feature_projection, features)] + styles = [] + for proj, feat in zip(self.feature_projection, features): + print(f"Feature shape before view: {feat.shape}") + feat_flat = feat.reshape(batch_size, -1) + print(f"Feature shape after reshape: {feat_flat.shape}") + style = proj(feat_flat) + print(f"Style shape after projection: {style.shape}") + styles.append(style) + w = torch.cat(styles, dim=-1) + print(f"Combined style vector shape: {w.shape}") # Generate coordinate grid coords = self.get_coord_grid(batch_size, target_resolution) + print(f"Coordinate grid shape: {coords.shape}") coords_flat = coords.view(batch_size, -1, 2) + print(f"Flattened coordinate grid shape: {coords_flat.shape}") # Get Fourier features and 
coordinate embeddings fourier_features = self.fourier_features(coords_flat) + print(f"Fourier features shape: {fourier_features.shape}") + coord_embeddings = F.grid_sample( self.coord_embeddings.expand(batch_size, -1, -1, -1), coords, mode='bilinear', align_corners=True - ).permute(0, 2, 3, 1).reshape(batch_size, -1, 512) + ) + print(f"Coordinate embeddings shape after grid_sample: {coord_embeddings.shape}") + coord_embeddings = coord_embeddings.permute(0, 2, 3, 1).reshape(batch_size, -1, 512) + print(f"Coordinate embeddings shape after reshape: {coord_embeddings.shape}") # Concatenate Fourier features and coordinate embeddings features = torch.cat([fourier_features, coord_embeddings], dim=-1) + print(f"Combined features shape: {features.shape}") rgb = 0 for i, (layer, to_rgb) in enumerate(zip(self.layers, self.to_rgb)): features = layer(features, w) + print(f"Features shape after layer {i}: {features.shape}") features = F.leaky_relu(features, 0.2) if i % 2 == 0 or i == self.num_layers - 1: - rgb = rgb + to_rgb(features, w) + rgb_out = to_rgb(features, w) + print(f"RGB output shape at layer {i}: {rgb_out.shape}") + rgb = rgb + rgb_out output = torch.sigmoid(rgb).view(batch_size, target_resolution, target_resolution, 3).permute(0, 3, 1, 2) + print(f"Final output shape: {output.shape}") # Ensure output is in [-1, 1] range output = (output * 2) - 1 - return output + return output \ No newline at end of file From 23d07d871b6020c7f9aa949ae7a1befedb6e379e Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 07:30:45 +1000 Subject: [PATCH 044/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 7561 -> 9283 bytes framedecoder.py | 6 ++---- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 766903d4762d9e62cef620724a675ff61daf6a37..20212d830e4a4bc3197d51d8cec0de377873a22f 100644 GIT binary patch delta 3291 zcmah~TTC0-8J@9?eQd^qvGLtx2X+(1C0TZhH?Tm;-CVTWG=xMXVM`5T920`s@vv>I zv7Ab6<*g!>bfSjcRkg0vQlwDHts)gw`%tMbZJ*ApMh$g^x_#K^vfA$Js{fg>3FCz1 z@cYcU{O9|>|8frR_x)vp{jJmKprHNYUw>Ts!t;BU^W6Bmm7*R{B2AX{#?!p%N0da5 zQd_X_u(nKnME%mV1$qw;b!pJz&A~&B<-?t$t9g?|Et}*vm#@HEB-2L}-yoR*%u)}Q zqVg8ec8Es7F-vVG(OV`k4L`k&Qt#ZEgIT^IZ#9s&O4cVd>L?Rz;B1xu?eZ5)yv+!i zw;6ztWIGl}!!c1(gJd&0kw*p$yM&Iz(DDFeEG4V_AD6d)j6gJX_lIIr)miu~quN-q z@U|UC9d*0GY@=j9&MZLzt5nHe!_&X=e_ggQlVq;Lbr^6N$?*bQN6A6F8HeEV&Mh-! z41RjcHyW^5sqqC^RH>0*EjAh>lV=P#O%n5b9PvqgQp?U78>lU|HZC<0JM)Ay{B1$a zJ$zDk4&KE#@3@Vtd8dJcM{+)&gXAo^jyw zk0}&Ja;5MU6d2^l>U#>-qqU!FbyS4){^;E}lo4>UHcl5{F9c|Kj zSH2mpj7y>=dtH?o2=PIw3EqfM9jm8a3qxL84mr?#xQbq=uH|W|iFfAUo#V|nwFe~2 zN=-o4TRX&ib8sE?J)r|gHa~-;`)JhoNR_195IO_N9_^;Fk~;@>^6cH87}Dgc8JAdN zed{;K4|h>A^DV}@)2pZeRRlEz7tBzqd5?i%tK>0od;0D6KvWtaDw_9skL0;TMNfP? 
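The coordinate pathway used by CIPSFrameDecoder (get_coord_grid followed by FourierFeatures) can be exercised on its own. The standalone sketch below reproduces the same shapes with mapping_size=256 and scale=10, outside the decoder class:

import torch

B, res, mapping_size = 2, 64, 256
x = torch.linspace(-1, 1, res)
y = torch.linspace(-1, 1, res)
gx, gy = torch.meshgrid(x, y, indexing='ij')
coords = torch.stack((gx, gy), dim=-1).unsqueeze(0).repeat(B, 1, 1, 1)  # (B, res, res, 2)

# Fixed random projection plus sin(), as in FourierFeatures(2, 256) with scale=10
Bmat = torch.randn(2, mapping_size) * 10
fourier = torch.sin(coords.view(B, -1, 2).matmul(Bmat))                 # (B, res*res, 256)
print(fourier.shape)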
z=bv0%TNk;dWI|jQl36Z$Be9mu_Hv80*ZnO*o!o1Z?JpzHJ)lzXro3Z^P}9pW*V{eQ zUB@L#=!RL8t~6*kQIi01p(^xJmZXB?WYOPK7=U3m1l=79pB~Yrd~v&Z$A35#NWcVg zCdUjX#Kjw2Hl0iMayjoAZYCk5k|GCWGb`(2I$0s$d2Ksa$pyuFgFp4T_mgv@esQMVvm&_G5qw-(Z?NuJAU+V8=3 zM$iudhJPfJ5%~0KLQHZgAq`qOxX?!qx5ZZyqL^Hb!STgsE9L_6%({>UV#a**a`M;y z_Cl`nh=Js-g(MHgQq`PQJtB7Jvx!@4E8vyT4~B)ar0XISI$nJKNkoE&&Q{43s+Nw# z+h}_waT}=TOuby{!w*RN+dwWfJwD83)WMeXKTtx0YH09)Dp>7p2NZPgPZdWu zUfE}wvAe%)dE|f(h{Bv#ne&)A4~~D#e8`lpmEB4prUqigNlka89DUS@-6J3hGpaJ9 zm>Jz`32cmR2JcxEsDoF?fSpy)RTW*u=;~g$<8ip_({NXLsQm73+wMgr zJfMaL9*4(24UhlTfoHDZ+3QMpP7TiiltyL-_rIm~Ph)llL_wETbQz<|Ust$XwiiS}eJbk1sP8kjsc7DB@A%ya7%g5Pom%_0Re*>%?ovKk3@BL9Pn1MH)pdrL?w+kr1V!hYE_Kkd58gpr#wA zDA8K!0SUVblIbgosVcEtOr-SSa^dfYD#6WLt(3@8xp71a0tDQcx2{cXnz7$#=YQs# z`5{+3>HgB?;tWPHhGidmrdY?M=6WinwR0FjaRR zOU#G1GTYN31Rf^x^Jws;Cg^ti`iiM+?+d8d4^T{!>7KdCfG@l*Rm=(^?J*%xb1Te{ ziB(wDN}|>=q2`uM65E9Kci}ybw2(9(m{BZE=yLbJPOat9o(n=n0|F{Tua2Hxb<(WQ z8(E>JrWq2z>ypAs7Rh?gHmFryG$bHliIwaUH)M2Er~pUu zPTYVFR~VQBD-02>f~jVbh$j%4`Gna@Kr$@91`$czF<)O^`XFhRXF=?kSLCb3jyb=y zP?j%&VwGN$D2g}8`@z|{4AmS5_V*3c=1blG{bEb{CR>lgue`l}Yh4ZQ%xV6V?oVmn zwC+v!9c2L`jjHl_`gIZ**$c(8@{3e_j8I04Ux!`2vaqrwp9Xcuw6VM* zpQTPuVS*`LC^(U`H(jy+3UPQ-6M_B_s`8X0L4`epuAQQ*GP$fdzTr!~NMFWPUX3 zx%u|OvX(9C*&-M;zNGV|s^yWB-+Hy4S7U8wLUSf`XQGN}w6Mp=AM$bat#8WqMH=NjkK*_{hocnSvBgeme}^b8q`eRU+k q3J2(LMt+H9@fC7(avCdSe)5Y<4r2&-I Date: Sun, 11 Aug 2024 07:35:48 +1000 Subject: [PATCH 045/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 9283 -> 9112 bytes model.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 20212d830e4a4bc3197d51d8cec0de377873a22f..2bb2b63b927cff0a8ea8f25b33deea81db10a210 100644 GIT binary patch delta 988 zcmZ`%J#5oJ6uyhoIu5pD$9CKh<&U(L3IkA3g;ZEb)S{?Nr4j=I3YSs>B2gM78lCBe zK^tM@B6;@zIPwbtW7G< z6(vY;(C=^O?z~pg1}(#3;}e~RZDVZMM|dCaZ%{&B;Oq=XtnTayVImiyX{KRIix{Fe zl9J_f@Q&IUKMSl&4Ii@L?6L$rEJ1jM9HY7oH!TB(tRsR{m*A&n4f$Ar$qfOqIJ?-O z2dnTPW{{)=nwBjFJ2}7?>%?Hy?m-AC%Dp`&MWwf|tFFi?R|KzyA7yZM2A5cfY3sW8 zt0Q~F!(4X71-TxAv}pD-fyvAO9vvR$E~~XGbuT)C=*xBz-qVP*kIUB(u`a^E~@CAVfi%&+D*Mv{+(BqmtFB#LpHn~-7A zJ~J{LgpPRf1he4rtGbZ@S7We%Ase9u#d&Lr4hEHp^(2TD8Pa?>XmEyTgbl4>1fS*9 z0Si3R@ZE;q1rn`;BwP1Lm%zkp2!bj58BI}YMdUU(GYAvW3nA$eR??Kskc%*xQRYV_ zosTo=88H$d$I|2f#vZmkVAN>d)d!SGhFJ1q4{$f?-7y+u66Q#=Z49l1{nd~a4kKU4 z;@gFyIMUO?B(LQNQ{YW3)fr?;3DL~p)Py6}h)uzQj(Wn97W23m$2fQwId}+@i>bv% zaMa0E9h5p#t%_(W2IIkC>Q0*~Q_qw3B)p98&}NuTXyMyK`8(aMg|kRq();oQ#ao3u zY&J*XeWLaCkHj4HZ}5fu==Fh-JUnRWgHJ6P8V0X^%)j9uMD;6!H=##Ahj)+k2)xic z=ok2?k6rfq(^~|$hi4{p<^7vNsHPlWbT0?+hjUfPt_j(ikX?_p%nDD1%Ed)xHP*Ej z>zX`W*X>KR?f}DANNo9zaC3I=TccuWQ***L3dk@4#9X^ e7ry=Fwz=M`l%m1^_i|zS?si@FYY&ZrRi58AHRw|S diff --git a/model.py b/model.py index 826fe8f..a662fca 100644 --- a/model.py +++ b/model.py @@ -427,7 +427,7 @@ def __init__(self,use_resnet_feature=False,use_mlgffn=False, latent_dim=32, base self.implicit_motion_alignment.append(model) - self.frame_decoder = CIPSFrameDecoder(feature_dims=self.motion_dims) + self.frame_decoder = FrameDecoder() #CIPSFrameDecoder(feature_dims=self.motion_dims) self.noise_level = noise_level self.style_mix_prob = style_mix_prob From 2a4e77b3a2cfd2314083268e3b5810cfc07e2135 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 07:47:24 +1000 Subject: [PATCH 046/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 
8bit --- framedecoder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/framedecoder.py b/framedecoder.py index db800af..e57fa01 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -69,7 +69,8 @@ def forward(self, features): # Project input features to style vectors styles = [] for proj, feat in zip(self.feature_projection, features): - print(f"Feature shape before projection: {feat.shape}") + # Ensure features are flattened correctly before projection + feat = feat.view(feat.size(0), -1) # Flatten to (batch_size, channels * height * width) style = proj(feat) print(f"Style shape after projection: {style.shape}") styles.append(style) From 9e79c19c743f0082d2ba207382f716a698c7db42 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 07:49:27 +1000 Subject: [PATCH 047/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 9112 -> 9123 bytes config.yaml | 1 + framedecoder.py | 2 +- model.py | 6 +++--- train.py | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 2bb2b63b927cff0a8ea8f25b33deea81db10a210..112d514f51c31e8e84a5e43f9cd31c8b16cbdf7e 100644 GIT binary patch delta 845 zcmZ`%O=uHQ5Pq*O+3j|-={EbHHrpm?4{EJHc&T_Z5K#|;UW!2xUn{L5x($lqMZF01 zAQY!y{Q*;ySZEMR3toEgss};v)JxlgUIeM&#hbHfx7uQ7nPDDpzHh#nx6{Q}MeU8I zwGp(hEAK|u=C!Orr{I9`Q5`O|5K@g-B_Y9z?l7q)K~^?%oYg5IbLbn!r8*YQp>J4K zf`ev}%CM$e@W@Qj5m-@dNyNC+XoV#+!xYg9H@t3`w_Gr+9@Zw5@JyAkPrijqT&**t zMBk)rD{r+UgePFq%2T!3;DhBeO{j?X(`7FL5HhAkxgcXn*z5Ve2$}J#$B-L+!-T?< zyuG2r1v|q!L?Tqe5;Z-X)#Z$(9g%!axC^?QJ@xOn)>b9u-i`;MBFWW? zgqvemo6j_!6iH;XBMdh{xQFEQCikMdU+akOaFq`Ha6MSnZMj7-!FURm?A}(~4Q__3 zaK_8fPPp#n4Yd--CPoGSr)o;1&XPS{@Y?$w4^AT+yiaErzNH^f>cX|I&YtDitJ0~3XXaetO|AB)@TWcL6^a|JpI_i{(*eOS!R9KT2@0{$N- RgQ+_|n|HralemxvGxji5yEB;G_kRZx&(p)S1ZC;M29m8FHGoZA;iU&o)yHl2;al{5f2+;Qb#}-T#i`5fjea$ zZ0Y0hQ_lzr58?Htrs1X%QkPttle+BMZN&yJ+Ti$fyaHMw-SjCuesjhPZSY{n%kRo~0Bk@-xAbn?>jviIm#E zht#94-9a)pJ+L2&!Ea++jC17)VF^xQJD!;yXh(C8aLSU-G1p^Jt}toUPY4sPARr+t z3Rkdg-Oa$N+woI@lmic6<0nS@lps|xH;>DMgL7tKc(}j?79FIkeL%}mqyxn8T>Hh52gJoU<+OjiuVn-p0srN?*X2 z^xCLb%+c<%^W!f&c4u^U>}7gV?V_GEAHMfA6-s#{d8T diff --git a/config.yaml b/config.yaml index 78963e9..c8550b5 100644 --- a/config.yaml +++ b/config.yaml @@ -5,6 +5,7 @@ model: num_layers: 4 use_resnet_feature: False use_mlgffn: False + use_cips_generator: False # Training parameters training: initial_video_repeat: 5 diff --git a/framedecoder.py b/framedecoder.py index e57fa01..8762d97 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -25,7 +25,7 @@ def forward(self, x, style): return x class CIPSFrameDecoder(nn.Module): - def __init__(self, feature_dims, ngf=64, max_resolution=256, style_dim=512, num_layers=8): + def __init__(self, feature_dims=[128, 256, 512, 512] , ngf=64, max_resolution=256, style_dim=512, num_layers=8): super(CIPSFrameDecoder, self).__init__() self.feature_dims = feature_dims diff --git a/model.py b/model.py index a662fca..682aba2 100644 --- a/model.py +++ b/model.py @@ -397,7 +397,7 @@ def forward(self, token, condition): For each scale, aligns the reference features to the current frame using the ImplicitMotionAlignment module. 
''' class IMFModel(nn.Module): - def __init__(self,use_resnet_feature=False,use_mlgffn=False, latent_dim=32, base_channels=64, num_layers=4, noise_level=0.1, style_mix_prob=0.5): + def __init__(self,use_resnet_feature=False,use_mlgffn=False,use_cips_generator=False, latent_dim=32, base_channels=64, num_layers=4, noise_level=0.1, style_mix_prob=0.5): super().__init__() self.encoder_dims = [64, 128, 256, 512] @@ -426,8 +426,8 @@ def __init__(self,use_resnet_feature=False,use_mlgffn=False, latent_dim=32, base ) self.implicit_motion_alignment.append(model) - - self.frame_decoder = FrameDecoder() #CIPSFrameDecoder(feature_dims=self.motion_dims) + FrameDecode = CIPSFrameDecoder if use_cips_generator else FrameDecoder + self.frame_decoder = FrameDecode() #CIPSFrameDecoder(feature_dims=self.motion_dims) self.noise_level = noise_level self.style_mix_prob = style_mix_prob diff --git a/train.py b/train.py index 14087ba..11577bf 100644 --- a/train.py +++ b/train.py @@ -366,7 +366,7 @@ def main(): num_layers=config.model.num_layers, use_resnet_feature=config.model.use_resnet_feature, use_mlgffn=config.model.use_mlgffn - + use_cips_generator=config.model.use_cips_generator ) add_gradient_hooks(model) From 871cfcc40e6fa84c7228bd2e21f1b2f6cadc6920 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 07:49:34 +1000 Subject: [PATCH 048/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 11577bf..41f0385 100644 --- a/train.py +++ b/train.py @@ -365,7 +365,7 @@ def main(): base_channels=config.model.base_channels, num_layers=config.model.num_layers, use_resnet_feature=config.model.use_resnet_feature, - use_mlgffn=config.model.use_mlgffn + use_mlgffn=config.model.use_mlgffn, use_cips_generator=config.model.use_cips_generator ) add_gradient_hooks(model) From dd7b4e7e9e25d26a8c1378f3c4ae47147453c05c Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 07:51:04 +1000 Subject: [PATCH 049/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 9123 -> 9157 bytes config.yaml | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 112d514f51c31e8e84a5e43f9cd31c8b16cbdf7e..01e950f134cf2167a74bdfa6ba8c739f727305ce 100644 GIT binary patch delta 483 zcmZ4Ne$<_JIWI340}xm~*q-)%Bkx|0iL#tVRzQJvhUpBc3{i|J3{gz!3@MC@7*m*1 zm{V9%SkqWL7}6M1*jhNFm{ZtW7@}BGI3|DRP*B!ndD#HeS;Pt?UNSHNS;ioi156Ca z-K@gN!@_7f*^u|Ps2xb$9z=juXoFaen-%%$m>I1$FA%KcP_+OGX)@koEzU?RNY!L2 zvIO#PF{Tv*jb>0#C=vzfu%2uu)xzjI`LL7lU1dsFxpPuCB54#0i?tM zNZews$}A{y2C-PmGE>WMu@q-kr51UDMokJy5ztmv z@CR7f2qGyg#7-N3!ikEgh@GWac_X+B-_G~g8JN4-(^>6I(`tx&-k-0GmrLzNJ$m$Z z^}50VIU$@RDI-aVsfWa%OeuSl9#&N*Ntsm(#T(Q<<8BfZ7e0wpne*1kG%A+0hn6wa z#}OtFdJy83nfiSX9j@Fq71ZuQR|-={u{ZS`4@)rQa)%nWVrI11Iana*TQ7H$R~a@FAab{@D5D`K)E z(zl9Y?M5~W%uX4WI}GQ*G|C=GufqaXilrAgK%#Qm(J1?5=*JB14k8Fqanb3gYvR7M tw+S15m&pJ9l0-1nCL*CdYKrqvz>mv=Rs`JBj~dHlS>EE~P!cbpw_m_Se*6Fc diff --git a/config.yaml b/config.yaml index c8550b5..27babb5 100644 --- a/config.yaml +++ b/config.yaml @@ -5,7 +5,7 @@ model: num_layers: 4 use_resnet_feature: False use_mlgffn: False - use_cips_generator: False + use_cips_generator: True # Training parameters training: initial_video_repeat: 5 From 3ed7160bf6e33ea52abc1e17bb80ec83ab08c899 Mon 
Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 07:52:12 +1000 Subject: [PATCH 050/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- framedecoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/framedecoder.py b/framedecoder.py index 8762d97..3b2579e 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -70,7 +70,7 @@ def forward(self, features): styles = [] for proj, feat in zip(self.feature_projection, features): # Ensure features are flattened correctly before projection - feat = feat.view(feat.size(0), -1) # Flatten to (batch_size, channels * height * width) + feat = feat.reshape(feat.size(0), -1) # Flatten to (batch_size, channels * height * width) style = proj(feat) print(f"Style shape after projection: {style.shape}") styles.append(style) From 56ea059073fff4254ddc27e0b14c96042531fe4e Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 07:55:26 +1000 Subject: [PATCH 051/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 9157 -> 9157 bytes framedecoder.py | 20 ++++---------------- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 01e950f134cf2167a74bdfa6ba8c739f727305ce..982f1be43a2f0a16e4fc9af96c394912a7721716 100644 GIT binary patch delta 216 zcmX@=e$<_JIWI340}yaM+@AJpBkv?pM()Xq>>`u5h`KZKfLQ`!j>^&?B@mFs2V>PR zWbwn;3k1MwJH+A{1;MO8V&aoo1jROMiU%+8qmxC%8mMN)jSu(;^v36iqR8tg@> z#TkhOskc~)GpkaIJb}g)fy}(coSaxv6boXplx3!tPd+1~vzbx0gHhaOlFb5^<(!K+ gS12yzS;Dg->7uaV6=B283*@+27_&F8QIcc?0P=)3*#H0l delta 213 zcmX@=e$<_JIWI340}xm~*q-)%Bkv?pM%KxS>>`u5h`KX!gINM%j>?iCB@mFs3uDzV zWbwh+3;4lmJH+A{1;DI7V&aplB}F%DhzBw+x;;s+A9n3EGrieiDR$ya4`*x8Fxi!%}nHvgCDVidQWWH}>b je&npk1%)$X=ESTpx+tu7MObh10y%CL#;nb2lq4Ae4Z%3? 
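The framedecoder.py patches above alternate between tensor.view and tensor.reshape when flattening each feature level before the style projection. As a general PyTorch point (not stated in the commits themselves), .view requires contiguous memory while .reshape falls back to a copy when needed, which makes reshape the safer flatten for tensors that may have been permuted upstream:

import torch

feat = torch.randn(2, 512, 8, 8).permute(0, 2, 3, 1)  # non-contiguous layout
flat = feat.reshape(feat.size(0), -1)                  # works, copies if needed
print(flat.shape)                                      # torch.Size([2, 32768])

try:
    feat.view(feat.size(0), -1)                        # needs contiguity
except RuntimeError as err:
    print("view failed:", err)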
diff --git a/framedecoder.py b/framedecoder.py index 3b2579e..7f8a373 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -34,8 +34,9 @@ def __init__(self, feature_dims=[128, 256, 512, 512] , ngf=64, max_resolution=25 self.style_dim = style_dim self.num_layers = num_layers + # Adjust the input dimension for each feature self.feature_projection = nn.ModuleList([ - nn.Linear(dim, style_dim) for dim in feature_dims + nn.Linear(dim[0] * dim[1] * dim[2], style_dim) for dim in feature_dims ]) self.fourier_features = FourierFeatures(2, 256) @@ -61,7 +62,6 @@ def get_coord_grid(self, batch_size, resolution): def forward(self, features): print(f"Input features shapes: {[f.shape for f in features]}") - batch_size = features[0].shape[0] target_resolution = self.max_resolution print(f"Batch size: {batch_size}, Target resolution: {target_resolution}") @@ -69,52 +69,40 @@ def forward(self, features): # Project input features to style vectors styles = [] for proj, feat in zip(self.feature_projection, features): - # Ensure features are flattened correctly before projection - feat = feat.reshape(feat.size(0), -1) # Flatten to (batch_size, channels * height * width) + # Flatten the feature map + feat = feat.reshape(batch_size, -1) style = proj(feat) - print(f"Style shape after projection: {style.shape}") styles.append(style) w = torch.cat(styles, dim=-1) - print(f"Combined style vector shape: {w.shape}") # Generate coordinate grid coords = self.get_coord_grid(batch_size, target_resolution) - print(f"Coordinate grid shape: {coords.shape}") coords_flat = coords.view(batch_size, -1, 2) - print(f"Flattened coordinate grid shape: {coords_flat.shape}") # Get Fourier features and coordinate embeddings fourier_features = self.fourier_features(coords_flat) - print(f"Fourier features shape: {fourier_features.shape}") - coord_embeddings = F.grid_sample( self.coord_embeddings.expand(batch_size, -1, -1, -1), coords, mode='bilinear', align_corners=True ) - print(f"Coordinate embeddings shape after grid_sample: {coord_embeddings.shape}") coord_embeddings = coord_embeddings.permute(0, 2, 3, 1).reshape(batch_size, -1, 512) - print(f"Coordinate embeddings shape after reshape: {coord_embeddings.shape}") # Concatenate Fourier features and coordinate embeddings features = torch.cat([fourier_features, coord_embeddings], dim=-1) - print(f"Combined features shape: {features.shape}") rgb = 0 for i, (layer, to_rgb) in enumerate(zip(self.layers, self.to_rgb)): features = layer(features, w) - print(f"Features shape after layer {i}: {features.shape}") features = F.leaky_relu(features, 0.2) if i % 2 == 0 or i == self.num_layers - 1: rgb_out = to_rgb(features, w) - print(f"RGB output shape at layer {i}: {rgb_out.shape}") rgb = rgb + rgb_out output = torch.sigmoid(rgb).view(batch_size, target_resolution, target_resolution, 3).permute(0, 3, 1, 2) - print(f"Final output shape: {output.shape}") # Ensure output is in [-1, 1] range output = (output * 2) - 1 From 22bcc9ba5b6716c18409a4ee292ab2ef0520619c Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 07:56:45 +1000 Subject: [PATCH 052/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 9157 -> 8017 bytes framedecoder.py | 21 ++++++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 
982f1be43a2f0a16e4fc9af96c394912a7721716..417a84291f3cfc2f3459db7a27548dfe425a3028 100644 GIT binary patch delta 1490 zcmZ`(%}*Og6rb6(*I&Dfjm?5F7#m1al7fpXT^t(|=or$1 z)@qNYq7^Wt&?*O4i6~V$Q1#Hipi<}wN0vaADmBNdJ-BMpL()Uv*l{RndD8poef;LV znYT~-2*uFaZ9B_Ikb*g*r*g)P(JcPGs{N?WMk_u^8!LDKT%kkgsjU@tDMdK#NiwzN~XPuk)(-d>%2y+YBRu0l^nJJi~Im9Igc$UTmKPAGX*| ze|HbzZ)k7z2=1k;)x+{7s8234=m4M#p8AhSfVphb$!UPdmtQd1>1nFW8K+6~)(m-% z<=Cd?`1$BG2{Ow>w8YKuv_>wlBwMo_KNg-L*O_Iw9o0S{eayrIIa`E8)zL(FayqV& z5DN<0htuI(YR&}b>EAW>*ZUz#23X1(*TNHXqeP3(l3o@tC8CkZsi@lIB<$Pe%;Q>k zW)`}Wt29`9y7?Uzw2e(o5j8P79)~NizQsN~iQFKAj9`q1=QKhF=)Kw@6MIS`$cV8+>!b+)8J`>v{`(7J1RaCI=%Pak-j zLU{p|D|?Q{hmOYdwcWO0#?hB`^reJ-rDknnWg>lI^TxJ6qqJw0_LL3CHNHJh>qAfL zwkPB1%6htX%~e0syLD%bc@fpTMxWH3q@BK6`o7O62$se?66l(*j_&z-=o|hEblu;E zuhXCWS1wf>>+|R+=r9`3c8q6!5Aw2KQ58gxjzcI$44*m$SRG1MK~V4i>6Fv?3Ici`*Bqnjolw1WOlY*e)zAIXl>}nJwu< z8=6>y9$U5bVQmc3$^)jRZQ4HcrAhT^#!M<9wXqLAv`;kJ7oU30>@G64q<{YT=bv-F z|9)n6l_Nj(vbXGZD}~AW`!^HMop;zarq7HYo%nj%gvH176)H+e)LCzu7HPbYqsc8@ zqCYe&8z|~FzOr=DaNajaeSuU)FjwHEhSPb==;n1%2Y$=h$8*Unmt$ z(&)ak)L4)yB&O&Wm$Pj`EWs(|Xws6`qypZ&3iQoL1D!vM1akhUj8*e@!0pn9v`Us7 z#FDWQ$Ue!yn|a%^5vPi;ETyeFzl>zv^2=y`7v0X1t->gZw@Wq+5ekUGyc3F=Wb?X2 zO%|d1-d*T{+rO9Pi&l%53vfVuWhu?*0+vWjE+CUBAprqz{7 zS7JY-cpov(qtQW^Ye?4_cYaHvdH+j8r6rcPrSO}=8>h5$pJXK#2P9`$Bi@<9?^SjW zR-+NGkJGg;kjtP%pFu#fOC>4nA#FeZfvz`RbiwOKNw>#qkc^U%cj%{+ok10z0iFOp zEBevt4_HugWJl+A;37=82rb#PLvjpKl_BKu^%uZKib7G}Al-((_YEwZz$vA64T|$| zflIy}nH9LmgeX967A7WzaWR&dY2si%0hM+orpID40?#Ei^f|npfLWs^F6He?Bp@G~ ziHHIhg)xjG-mq~CF)j+$Wl@)y9AFPxZzF&ukBh=_AgNakIjpnuOo>a}eF zgy}JX$HAgGOfG2>cQhH9o{f{htvE2;Ped(Aq5A*ram;MeDgRcnxnvf$vPlKQy8Lpn zQUU1UCwjv6KLh5=@bX8IOqYEYS+i#e{cYH&N;EhHYC%U=LdLbM;|j z)JyiQaUn7_KMIBT9PA(qg~?cSIuYY5ozOv^L4FW$7#RY!2S&)EO`BhGbOH}livxSe zvJpnd;2T6UX$8VwG#=a1DpUU&JHy05w}9=m*OvF8y}B0JjGUccLnZGd-1#P!}#k|VMY&~WLG^n6K1`X?;(3Po6 zQ<S|B7Ht~3XRTU2)Frb$m=zFKL0AUiO*0kZ9tICS2p(JuD#2DY*eU~` zhVPIo537}JGTVV!0i7!7ltJf{VCZ`5I(itaH(Q;!mNO16*SAB84ZOv zGq99zQUJXh+Rc$>^{mlYko(ZxH?x|;gjFUiGhuW+)L!Bwa&aH}HPn}lS(`a?b3|rqF)N@>1$8p0L*=D4^jUPK^yG2BY0oCLlc+bD4&#yM z1>0isY*IT5^t80Qeu|tV^N9pM7Z)J;rGn)|)DVF`V3hWs0+J8;bI>xa!H&T9(Dbew YU86LO{r|UIf|uw2()OP1fY2QO1H58pwg3PC diff --git a/framedecoder.py b/framedecoder.py index 7f8a373..db800af 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -25,7 +25,7 @@ def forward(self, x, style): return x class CIPSFrameDecoder(nn.Module): - def __init__(self, feature_dims=[128, 256, 512, 512] , ngf=64, max_resolution=256, style_dim=512, num_layers=8): + def __init__(self, feature_dims, ngf=64, max_resolution=256, style_dim=512, num_layers=8): super(CIPSFrameDecoder, self).__init__() self.feature_dims = feature_dims @@ -34,9 +34,8 @@ def __init__(self, feature_dims=[128, 256, 512, 512] , ngf=64, max_resolution=25 self.style_dim = style_dim self.num_layers = num_layers - # Adjust the input dimension for each feature self.feature_projection = nn.ModuleList([ - nn.Linear(dim[0] * dim[1] * dim[2], style_dim) for dim in feature_dims + nn.Linear(dim, style_dim) for dim in feature_dims ]) self.fourier_features = FourierFeatures(2, 256) @@ -62,6 +61,7 @@ def get_coord_grid(self, batch_size, resolution): def forward(self, features): print(f"Input features shapes: {[f.shape for f in features]}") + batch_size = features[0].shape[0] target_resolution = self.max_resolution print(f"Batch size: {batch_size}, Target resolution: {target_resolution}") @@ -69,40 +69,51 @@ def forward(self, features): # Project input features to style vectors styles = [] for proj, feat in zip(self.feature_projection, features): - 
# Flatten the feature map - feat = feat.reshape(batch_size, -1) + print(f"Feature shape before projection: {feat.shape}") style = proj(feat) + print(f"Style shape after projection: {style.shape}") styles.append(style) w = torch.cat(styles, dim=-1) + print(f"Combined style vector shape: {w.shape}") # Generate coordinate grid coords = self.get_coord_grid(batch_size, target_resolution) + print(f"Coordinate grid shape: {coords.shape}") coords_flat = coords.view(batch_size, -1, 2) + print(f"Flattened coordinate grid shape: {coords_flat.shape}") # Get Fourier features and coordinate embeddings fourier_features = self.fourier_features(coords_flat) + print(f"Fourier features shape: {fourier_features.shape}") + coord_embeddings = F.grid_sample( self.coord_embeddings.expand(batch_size, -1, -1, -1), coords, mode='bilinear', align_corners=True ) + print(f"Coordinate embeddings shape after grid_sample: {coord_embeddings.shape}") coord_embeddings = coord_embeddings.permute(0, 2, 3, 1).reshape(batch_size, -1, 512) + print(f"Coordinate embeddings shape after reshape: {coord_embeddings.shape}") # Concatenate Fourier features and coordinate embeddings features = torch.cat([fourier_features, coord_embeddings], dim=-1) + print(f"Combined features shape: {features.shape}") rgb = 0 for i, (layer, to_rgb) in enumerate(zip(self.layers, self.to_rgb)): features = layer(features, w) + print(f"Features shape after layer {i}: {features.shape}") features = F.leaky_relu(features, 0.2) if i % 2 == 0 or i == self.num_layers - 1: rgb_out = to_rgb(features, w) + print(f"RGB output shape at layer {i}: {rgb_out.shape}") rgb = rgb + rgb_out output = torch.sigmoid(rgb).view(batch_size, target_resolution, target_resolution, 3).permute(0, 3, 1, 2) + print(f"Final output shape: {output.shape}") # Ensure output is in [-1, 1] range output = (output * 2) - 1 From dec23d18510609c4b10a80b76c67c5355ee4b3bc Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 07:57:12 +1000 Subject: [PATCH 053/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 8017 -> 9112 bytes framedecoder.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 417a84291f3cfc2f3459db7a27548dfe425a3028..6236155ad08e423d8b3c54f9aab1d7667999708f 100644 GIT binary patch delta 2338 zcmah~Z)_7~7=N!_+qc{GwzXaV>^iuS4d^Dwpzcq*v2J5sAsB%H8IwPD+--Dh>vHXk z)T=WFb(m>3Ukp?EP!lqT%%lbqjbi+w3Ha>x40CAtrplPc>j-{N#Ig?ch-2d(S=h z{`o!6^Lw86y*p|5LLc?1%~lFj%qk8HhGiZU2F2 zL3i!943lQ`nd3Iye;m`vm@?+aNwa3bJW8M+9loswAeaR6BF^I4dL|i*4r~!DD!9g2 zk~N`Pixd6qbP$Kpw}g932@994`m8fhXI(15f;GrO)p~oC)2@S*3DyF{nzbsBtpeUC zmRvMqR`}Lhrc`HS6H3?pma2YdsvMQ;IRL;?LYa#2)ur4;Cjg9+#l8SWZN+|;Qe1@9 zxR8a5B{(v^wI*2Ea<*c*Qa_s^b$sjsS>Quv%N1vkHL@h@V4diDcb%3|)e{-%J-7-U zd`e{-%LH-+SK(jP)tx@pEf89c=p9!-@h1A!wPoYuHSh@3$AI-JsA4u~htoNrtBA9> zBoJ)%6Prm06icS@tH2u1ss}qs2^6keRndlZrty2Ht?`ZMxZA_%>OvP7RO%T}f=#GQ zV<#Eg>{}@tdc$2$sC(}b41!TGvJQQk+ z!o^ZkJjbXO%n&!6z=KIXG0L4z#S;@jCcU9Qb#AN(hz+MWUboKgSDe0pRl24pF+LQZ z;8-T9n!JpCBzVm=cI56!BzQJH5leB*2p`9&s6sQy@FB&zHZm4Vr8vd>^pijhm-a*w zQ+yl~7}i)Pi@p)<-87wk_OSrm_z=h9U?T+#qcKL2jv+=GcTOc#N&L&M4s;ZAY7CVQjrwu`76wC8q=)?8NYkgG|9&&Qv4}}cPCb>#&mp=?^eDk=v8mTV$n#V@4OE!yYXc5b0hk? 
zx`DWh?p43kxkgrs_U;^eV-!EINMw&p_K0N9iqkWFBS<#WEc@)ByR>j<_^8%CwfPYmgo+d z?hv6z#dnGRHrd}PQe9Xi*e%0u5q96BeV02{&>cF!{7C{UB~k-o_<$Tfs&uFU2_BQ- zF%ceHp&LM#}w9+BY@5gtJyUyYV#aB1uEc99BVkzh!MArXeq zfR8alYh4~d^qOxYqm;BaM_{Eq2X36ydJ-9t$%sfs&}Cn^cCDgHboJ$$u)^|{<)%&R zNflAKA}Ury(J#JeMCG(Ad-Te2k=lhtf&m!@L>R!;aCjG*<{PqvR9QP?zFS^3V^e>t z{t|wtau$^Tyn2xN01DOgop=$C|9Tx1ue`x)+LQ-RQ_Ut*TOO<{Dr9oy!P6CbP_w7` z434mBem=oYjd8s4XW{D<)ubrpQ>4;UzdC=z_Hh-qo1aGG8$RhdNf6lof9pB^;<-Q6 LwTDkZsE+>u%l%xy delta 1512 zcmaJ>TTC2P7(QoaXLfdXcGw%Uz_Mj`yR9yzg>9<>rGP*$kS1DCAMn=dVYykM1DZ(A z24b3YO=wbjXrPS`t1*ZvCK#I72jZKiJh|I!3L$B1eIvfCw`lt0e`ci!ivP^{a{l}O zpYzYmp2{|ApF5p)1hn$#>%>p4ue4=lDrh@g50cAh2KAtBOi;JbiwOP{6M=}tLZpP0 zLJ3UdOSW#2STCV&3sKk~N@#D)vM~Xkh6J#I1lS$nyXH+9uuD;!B>Jr&@xuCb0N^^F zLA{8g*Q;5xTJ{P-r64R2#RPA0V-Oc4!gk3=!W?9}1*&a<894>;=LE_lA#UG-wWgLUEtBn6-<}OmCuT>|@sHo18k!g9pI?x(`k_TzR;|ye z^?7ldlcDnzwH$cMW-_zoek20>443gLHh@DheW}?Q$9Rw#TD$bSjP~wh-)Mf}WarYV zc2~glEbaal*Rq7?uFz*;HQp7+vDd&mDJ5kpvmac3uR!GOy(VQ=hOey3c${ZAZix+NE(hjB$({!cS zQ@V8pXu!E?L(zc7DoghqraS;Va~22zt3cF7Y{b6c*e*?~R86~B*y__@x)h=kXTgB2 zx++faK-!K<)Lj~MP_-BA|J5XBANsrSS$5z5)WLt&?R!!fCPh0A22%}Y`i+v{5$70u zAA|fBk>M5)bx}-}9oBB@r0y}WNIEAwM~m2xCHvS?v2GQpNJXNuWj%TbnuFHhSht-? z(nCE&;e$=gM4X4=b+OM%Uwrxt6vOM%NZ5NR|TILkiqn9T`#xBS*VyCWxzIzef2QGDL)t#ylJ!>Syw$0; zHDAfqwprg)$8<;P1iP^()|Qu%M_+YE@42Jt?o7NT>u$}tTT}9yUOqE;Z7^Lq|JGtS zs~^qjM^g?kE)T5+YVHMU76aM9@m%0|Mk#r~E>zS>c@>pcvF|DlFf|n1|4=JRi4P(N zzHD8z+;-e~Y5CP`ReP?gJsauBMLP0GwnYJ0U#No3hnkoYKF#{W0d_OoIz9xAz)xWF zS&P3lhJo@6c>c~mAvc86L?s>`D%pjAtH9{nE_eim^^9J902k!SG$8;=Ca u-ULBG#c%Yn(r{sjy!TE|qcOp5M?Puo!WiQIH(oWSKls_y{kDsR8Tc;`)Lp3n diff --git a/framedecoder.py b/framedecoder.py index db800af..9a549cc 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -25,7 +25,7 @@ def forward(self, x, style): return x class CIPSFrameDecoder(nn.Module): - def __init__(self, feature_dims, ngf=64, max_resolution=256, style_dim=512, num_layers=8): + def __init__(self, feature_dims=[128, 256, 512, 512] , ngf=64, max_resolution=256, style_dim=512, num_layers=8): super(CIPSFrameDecoder, self).__init__() self.feature_dims = feature_dims From a52c33ae43a4f4b9d8dec0acef9541838f0f19be Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 08:02:20 +1000 Subject: [PATCH 054/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 9112 -> 9146 bytes framedecoder.py | 5 +- reference/CIPS.txt | 925 +++++++++++++++++++++++ 3 files changed, 928 insertions(+), 2 deletions(-) create mode 100644 reference/CIPS.txt diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 6236155ad08e423d8b3c54f9aab1d7667999708f..52f9ed477c8244a893597e86d3837ce7e2817cc6 100644 GIT binary patch delta 503 zcmXw$xl02<6vk(=yK$qgVxf`5XcE&zMexEaUU&o%K?DnfA+FAD46Y=;co&*N&`OXI zv=ke`N{Qfw|3gZNod{xSrC2#H;LPyjd*3^XH)r9;uy(0w#YCL;w>$IbMtfF|J&x8M zVH}WCVvq=BBw~vbMlUF{F^S1cF_lANQfA+GL?l+UPa-l?`M2t=56Yhz(FuycC$S0D zhHCKN6Ja*sD3xgkZ_!>W`_Ugj6e=1~4P||0yAtimo>$DM4PC+t(!7#h(o;sz2Hhgg zQ!_0_6R#I4v7=XDr-@&Cs(GjD%07wS6#sCAXov^gC$x`0xSvMm(H<1YOFM}a3}ZNs zC@*-6`Wi%0$r}k{6Gl+dv3yA!5-gV1S5lS%qx`-m=p4hQFy@ObLtoyGfnlw~5dW#E zrL)}WX=|RtWXWPO2`n8mEioC#!6@1i#1~)!m161J20%EwQ<1a9b84Xq^Ex@Gi3 JhTr<%{{SPEf@lB$ delta 439 zcmdnxKEs`NIWI340}x0(+MZUik#{f0#Jij-IzWMThUpBc3{i|J3{gxe932d4j47-w zoKegvY%L5?EGg`hSveJWUpfF)7O?_}mkdmsJvez-7)>UJ^4=D;1&P~%2zwBr4I&&i zyYkgBGg@vwA{fe{Y7P|AWW2>%oRL_Ns>xJj0p#CeOe +1 + dilated patch with increased receptive field is obtained. 
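The dilated patch sampling described in the excerpt above (part of the CIPS paper text added here as reference/CIPS.txt) can be sketched in a few lines of PyTorch: a random window is cut from the full-resolution image and every s-th pixel is kept, so the patch stays small while still covering a large receptive field. The function name and the default patch size and scale below are assumptions for illustration, not code from this repository.

import torch

def sample_dilated_patch(img, patch_size=64, scale=2):
    # img: (B, C, H, W) full-resolution batch; keep every `scale`-th pixel of a
    # random window, so the patch spans patch_size * scale pixels of the original.
    B, C, H, W = img.shape
    span = patch_size * scale
    top = torch.randint(0, H - span + 1, (1,)).item()
    left = torch.randint(0, W - span + 1, (1,)).item()
    return img[:, :, top:top + span:scale, left:left + span:scale]  # -> (B, C, patch_size, patch_size)

# e.g. feed sample_dilated_patch(real_batch) to the discriminator instead of the full image.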
Applying this patch sampling to real images before putting them into the discriminator may be thought of as an example of a differentiable augmentation, the usefulness of which was recently proved by [9, 34].

Tab. 5 reports the quality (FID) for CIPS generators trained on patches of sizes 64×64 and 128×128, while the resolution of full images equals 256×256. Fig. 15 shows the outputs of models trained with the patch-based pipeline. In our experiments, training with smaller patches degrades the overall quality of the resulting samples.

Figure 15 (panels: LSUN-Churches, FFHQ): Samples from CIPS generators learned with memory-constrained patch-based training. Within every grid, the top row contains images from models trained with patches of size 128×128 and the bottom row represents outputs from training on 64×64 patches. While the samples obtained with such memory-constrained training are meaningful, their quality and diversity are worse compared to standard training.

Appendix D: Additional samples
In Fig. 16, we provide additional samples from CIPS generators trained on different datasets. We also demonstrate more samples of cylindrical panoramas in Fig. 17.

Although we do not apply mixing regularization [10] at train time, our model is still capable of layer-wise combination of latent variables at various depths (see Fig. 19). The examples suggest that, similarly to StyleGAN, different layers of CIPS control different aspects of the images.

Appendix E: Nearest neighbors
To assess the generalization ability of the CIPS architecture, we also show samples from the model trained on the FFHQ face dataset alongside the most similar faces from the training set. To mine the most similar faces, we extract faces using the MTCNN model [33] and then compute their embeddings using FaceNet [24] (the public implementation of these models was used). Fig. 18 shows five nearest neighbors (w.r.t. FaceNet descriptors) for each sample. The samples generated by the model are clearly not duplicates of the training images.

Figure 16 (panels: LSUN-Churches, Landscapes, Satellite-Landscapes, Satellite-Buildings): Samples from CIPS generators trained on various datasets. The top row of every grid shows real samples, and the remaining rows contain samples from the models. The samples from CIPS generators are plausible and diverse.

Figure 17: Additional samples of cylindrical panoramas, generated by the CIPS model trained on the Landscapes dataset. The training data contains standard landscape photographs from the Flickr website. No panoramas are provided to the model during training.

Figure 18: Nearest neighbors for generated faces. Within each row, we show a sample from the model on the left. The remaining columns contain real images that are closest to the respective sample in terms of the FaceNet [24] descriptor. The visualization suggests that the CIPS model generalizes well beyond memorization of the training dataset.

Figure 19: Layer-wise style mixing. The two leftmost columns contain source images A and B. In the rightmost three columns, we replace the latent code 𝐰 of A with the latent code 𝐰 of B at layers (left to right): 6-8, 3-5, 1-2.
The visualization suggests that layers 1-2 control the pose and the shape of the head, the middle layers (3-5) control finer geometry such as the shape of eyes, eyebrows and nose, and the final layers (6-8) controls the skin color and the textures. Interestingly, this CIPS model was trained without layerwise mixing, and therefore such decomposition likely arises from the architectural prior. +◄ ar5iv homepage Feeling +lucky? Conversion +report Report +an issue View original +on arXiv► +Copyright Privacy Policy Generated on Sat Mar 2 08:53:15 2024 by LaTeXMLMascot Sammy \ No newline at end of file From 653fbb92c53138d4717cf672af32a3aed0c5ef06 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 08:03:46 +1000 Subject: [PATCH 055/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 9146 -> 9322 bytes framedecoder.py | 12 +++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 52f9ed477c8244a893597e86d3837ce7e2817cc6..041b88c02df87ff0f9b739f00ba88683ba7023f2 100644 GIT binary patch delta 1142 zcmZ`&O-vI}5Z<@DWlR4=+5Rl0P)fuQqb5qC0#<4T5s8MN(Ga1w>{1IA>MLR(s~7cP zf(97{4<5in)yP4_s4+2KJP=RC;7=O$qKS#|fWH?L2NGC$wU`{wQJgB>55 zt=Ftp6Nl3C_Fm8XS!;tuk8&l8e{_Tnex z6qL!WqApDtphA+pI>O4QgB8Wa&~?H<9R7h+4HFvq7r=T%8H!h}6Q zELq$~d!W*t@5yn!w!DqJm*wrLB8D|)izU1)?$}S3vye*_RHIs~$|(`K;>pR>3YHWZ z9*oXF`M+tKbxRZxID3)H181C7WEIRf6}OykgB0v}JH_J#Jj|LL@ZA{-{7Ko7$LK^Y z$>H*m6GXSE#u5Ags?diWI4)_^ZP?cd6RrX0pS)c88r=%tidI5EapYz+;*2;UzS^J? zAV#>WL@IBv%`zkphZAh4YOyp!dXnq{ua`2;v>yB=sbzFGv$rtSh6>)5c){l01iRc9 zh#y|NZy!)Hh6IiE57Gz&u@%b*Be4YC&Ma{#rXA0&Tt+;ih5Km>1F_GN5xT>Jw2Qfo z%q2}3bCkx^P%?ZZ5!YxFyeh4B?qK<5mN&(yWPvzh|Nz1_nxn)tFn=bPRqzfKZG@2smw!P^?-RZ(o delta 1027 zcmZ`%OH30{6rI@C2nery1<&4U?3!Jh$~E7m@bVg;sPTJSH}0IfSSln-kJM3_uY5zosDx#ox(Fg zFfv3=eR&(5TNGM&XA!Y1{AGmC8A$J_cd7!!@MA;AoH48M-TIy z323lDJlIbFmibWC38MU*VKx~}jqwkBK6LcK+1akl;FJ1~eB~#;GS?v(k*%4go}th} z;V^|D$mTbinx>q)cuj*Ea$on&A(Vo&l0~v+IWj78w)rJA$WcDrvj<@Xi3XFx z$B6yX02zk|v<^S=!4Zeky?-Z(6r}4=r7(F|G>S0p2nYXupmm3%Aj$RbhNh5WmiS4s zScx50_m(y*W^%6$L}#*Uw;hHV5_fMAvWPsLoPreQ1glPap+i0Ckn~4pw&e^W9%@|! zeq2Ym0hHIKZ$?rw`5)?uok!6$tyxxvZRCR+;hifBDR-F>cT@W!0m!)PwD>gD*4*=` z39cP@-EGfnQ#gKW6yKn<76N&8GM>V%)a;^AtojBcqc}u0ZagmE!EIDU9eGm>$E9#O zGLlNjxE&-;UIPB((%|xGA;!ym<~R#kIIqJF)XL?Re0|? 
zS$&;4O@qlKmeS#9f|R6na65IgI5vRKQ%#$QOvo5_!oku$^a>_Rr;9j4`6k1`uhQyV zuh+$*5F{(wjd=ZX-+Dw^zl~v~;+VHlD{eA-1Mt2!6o*SbH>!gXp99UnUEkBzeuPNc S|1pCfPJCBmTYE4tUG;C>ckP-0 diff --git a/framedecoder.py b/framedecoder.py index b2a8d57..18cef58 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -18,10 +18,12 @@ def __init__(self, in_features, out_features, style_dim): super().__init__() self.fc = nn.Linear(in_features, out_features) self.modulation = nn.Linear(style_dim, in_features) + self.scale = nn.Linear(style_dim, out_features) def forward(self, x, style): style = self.modulation(style).unsqueeze(1) - x = self.fc(x * style) + scale = self.scale(style).unsqueeze(1) + x = self.fc(x * style) * scale return x class CIPSFrameDecoder(nn.Module): @@ -31,7 +33,7 @@ def __init__(self, feature_dims=[128, 256, 512, 512], ngf=64, max_resolution=256 self.feature_dims = feature_dims self.ngf = ngf self.max_resolution = max_resolution - self.style_dim = style_dim + self.style_dim = style_dim * len(feature_dims) # Adjusted style_dim self.num_layers = num_layers self.feature_projection = nn.ModuleList([ @@ -47,9 +49,9 @@ def __init__(self, feature_dims=[128, 256, 512, 512], ngf=64, max_resolution=256 current_dim = 512 + 256 # Fourier features + coordinate embeddings for i in range(num_layers): - self.layers.append(ModulatedFC(current_dim, ngf * 8, style_dim)) + self.layers.append(ModulatedFC(current_dim, ngf * 8, self.style_dim)) if i % 2 == 0 or i == num_layers - 1: - self.to_rgb.append(ModulatedFC(ngf * 8, 3, style_dim)) + self.to_rgb.append(ModulatedFC(ngf * 8, 3, self.style_dim)) current_dim = ngf * 8 def get_coord_grid(self, batch_size, resolution): @@ -71,7 +73,7 @@ def forward(self, features): for proj, feat in zip(self.feature_projection, features): print(f"Feature shape before projection: {feat.shape}") style = proj(feat) - style = style.view(batch_size, self.style_dim, -1).mean(dim=2) + style = style.view(batch_size, self.style_dim // len(self.feature_dims), -1).mean(dim=2) print(f"Style shape after projection: {style.shape}") styles.append(style) From 0915e9f76449fa6b6dba4a5d9878efd000a1378b Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 08:05:09 +1000 Subject: [PATCH 056/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 9322 -> 9675 bytes framedecoder.py | 20 ++++++++++++++------ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 041b88c02df87ff0f9b739f00ba88683ba7023f2..574565ea2c1113a1f34db14929d583d47971fc2e 100644 GIT binary patch delta 1496 zcmZ`(S!^3s6rE>#W}NYoaWWp;aTa1{!EzcX0VnxLsEM-GqFJm!00ra3lVD>f>9buF zw0M9B{;#$NAmg{%P`9*kHU;kcjOwPEe!Zp z@+wqh=hH{Cijq@`-EeEkS4CdfO#|3;W+_>6rVWI+1-|AzoP${D2q0r~K$KR6)56+a zs@SNBjaNkfCDFebRK<2pY+qr&GKdD80<2zUB5Fo zQDy^cqmKVmuw{H(XVk-7#gjuB3;C2?$)gQ2+c`R4QRvuYf%!7dvjiE-C_|TbR}uo) zgbtIbj59765gHEc)!r9t-HU1{TMUyB;d$XUxFHCZf7La4tozah_*QUL?ghoTWiI1` zZd0A@iDf>|BUi|iTXgXhPZmtJDr1ft-Y34!<)5gb3khj zsICB9w_dE0KCM?9Mzw}f)jXz|$9BwP@I~dpsyjdGRlSEb?_pI8X<}$c3~gJ)7j&o_ zUaz`{Za}*FGqeW15{9_Vs!JGQOOWLpxe_scLr?aPJSADLD4yZVQa=)k*6sG{%I7ll z@a;;%4CD3@WGQ+txD4=t{Y7LgqNE)3?%SVo4w>Ipw%4nQv7*evHTgkQT||=(2h5u6 zK-s4phb;U1U3N?$KeRbrZn<2ARNRXxU2i%P9CJo2|NH_ceCO;yE|6Ud`)by}X;*)# z_shSy5DLPtHCfXLsg8T78sAR2UqaLF0=tOg^BFut$skb$eJ;L;`>AC(7FQm}e(Jf0 zs=JG1N{Pt~`F9HDT4LNrh?&Ea?~l8&Yp$sHerSe-90ME%3mFC>-dV@naN 
zB(u1W5_~#7ms*I&eASp9(t>eTiDt4`N#lp$le&)fhbd{ASy;fabYwP3n$bRZkkWcQ zH;wP57MqMNDHsn!Q~fA+h%o6NAy7fK6yCADyI!x?UOVgc+D$T#&Y`73Afc%31LXoDQV|@G5a0v!)Lbg1AfdQ&g=FY=wXW3p{6+IM*K zz3+YV-p;<6dt<+i1g=Vwk0Cbm^X-|39XA6e>$^e1^Jty15iB7M&a-Qx!5n3XoI?7~S6Tc$5+l_n3 z(iAES)>6eZS4^CFMu*9Q6Ew$K_HxOz+yuiZ?;cJizn0K>pJZApG z`P^)XCjQ@HS`Bv-H~nQO0+)raTn!hz>ievx%1m;EZdsKj5${BH$UzJsSyJ4ISGltoa?(#I z4T-=!QsGRX4>kkW(GZk_V_YX`WEEjEm;+JXl4JGcSR*-B_m4OHIw!yb*_m5geP$g{!%X@KhZ?sR&*dmug#9l1sp_99dz)941>`a?AMN?P^-s9 zEeeVjhx^)q)ZQZzdkEC@+$i(|m6h4TT+bU@AAfA8Lq_%jY&(NUo4rp0JUA&m*1lK5t{dvO#P$fNo=YiSh%M>M zCdMORB@c5s!dO4T_sMT#clzo|uA$@}D8oM}!(`gZa9tT~D5JI7?oKy40o$nzO2D6~ zZ@Nmf4)=E1sF2V0kf8M_aW6$W-0#-NC3)Cg*()IOn4foF>^YZxLis=bNvd!ot-@xS IgUjj6Q>f4%VE_OC diff --git a/framedecoder.py b/framedecoder.py index 18cef58..39b3957 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -17,12 +17,20 @@ class ModulatedFC(nn.Module): def __init__(self, in_features, out_features, style_dim): super().__init__() self.fc = nn.Linear(in_features, out_features) - self.modulation = nn.Linear(style_dim, in_features) - self.scale = nn.Linear(style_dim, out_features) + self.style_mlp = nn.Sequential( + nn.Linear(style_dim, 256), + nn.ReLU(), + nn.Linear(256, in_features) + ) + self.scale_mlp = nn.Sequential( + nn.Linear(style_dim, 256), + nn.ReLU(), + nn.Linear(256, out_features) + ) def forward(self, x, style): - style = self.modulation(style).unsqueeze(1) - scale = self.scale(style).unsqueeze(1) + style = self.style_mlp(style).unsqueeze(1) + scale = self.scale_mlp(style).unsqueeze(1) x = self.fc(x * style) * scale return x @@ -73,11 +81,11 @@ def forward(self, features): for proj, feat in zip(self.feature_projection, features): print(f"Feature shape before projection: {feat.shape}") style = proj(feat) - style = style.view(batch_size, self.style_dim // len(self.feature_dims), -1).mean(dim=2) + style = style.view(batch_size, -1, style.size(2) * style.size(3)).mean(dim=2) print(f"Style shape after projection: {style.shape}") styles.append(style) - w = torch.cat(styles, dim=-1) + w = torch.cat(styles, dim=1) print(f"Combined style vector shape: {w.shape}") # Generate coordinate grid From 51f2a2fa0e7627ad6f688a832014908f38154b34 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 08:05:54 +1000 Subject: [PATCH 057/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 9675 -> 10032 bytes framedecoder.py | 13 ++++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 574565ea2c1113a1f34db14929d583d47971fc2e..8be95056d3b88dc3b608952ed86c48748b07b5e9 100644 GIT binary patch delta 1875 zcmb7EUu;uV7{6z|{eRna>;FIdQ`S-njIqsu2~0N;W*uWXG_nUWuD!R*on5<~wjfO# zG(I?!MF*b*9(<^g43H%Pdk~cbV$?`XkdhkAO?)#kJP;v3;v1f?rMg9;@g(~mhQUeJ&zT-97Aekk&H>;~r2uoks8^L;u@)ifvVfBq-2q*4T-pF+AGiIlPW(Z z?|xM^X0nHq^2BsN@Vg>Xd(qdN3nY(7@A z$4_o63z3Qtxu@rSrUx8@Iey;y(8T!y|2=nmSy)#Q)~%j9XPvWB7b`yg$*7*JiQ+Ph zxm1SLF6U*s9@n;(L6oP>DEna6(8>;Q!QdHcV}{HGgSWV!ZH^PxW7dTuq zZrZxeld`R+V(XdH{;L~)H$?maBo87XYX)wBI#KVa0(2X_UO{WR%RK=CFl}_h5u=uU zW-U&LA8m5NDWhcTW6JFTq`=dneajxxoQ8D4$Cggg2|rqfne4ZYuI;LN87SC}bStXZ zbhoiNqwd-v_rij;n@#)LIv>O=?C`7mF^JL6A?!ftMA!-G#$mE{`K!i_8V|Co4a!uM zs0Jk^$!{vD$+Lv^!=vU6WHW5A$A}Md_E9nj-`cO3v7<2WeUBbKa9ks&VceO}N&14m zqyy3Aa7czlwnOJ|m#U-g2ua4(jyzlSBBMHobk$MPBLgoPmQ2v*YSEYrdYE;#@&&$R zV4g_-nXJ7Y%dS}U?xJ4cC1J_5`Z5*`*}qEQILT5l)Tp?U8SeQ!@QCqS3q~|Tzk8de 
zv0$nP2rQQqF1lNZ4a)9dhkr#*Yh5mIQ;l`GWP?Gs7dCku+9s)a$qqxF7LB7I)Ch%A z6E4dtX(WSWl!Qf|t(gb2rvt8g2HKvyB6C5QY>@P9BRai1NCalRUW=oikL7m43GX|W zm6f?5;_EfL>oJl5vQM_GWcR==Utg24mcJfXD^ACXW2M1Xe~bw5wtvdGodu{Qk-107n5y#zl6u8}@C+j3^yt7>PclFCvUW3c~J zO6xTVW3b#B z-SQ&(htY3J%HkV`C#amv(Vb|}XOzk5v?4|A)k~ln60(@hQ8`2RLv!1PE!|btL-YXt z&GB@aN|}j7k|o7_G=^pkoqU~+qNGiVhh<8KA=B0q!Y$8xcEn!VyNim`^7C>$Ey;9K zB25p8RKn?*J8-S-@>4x$a-Zct&0h+hEiM%=<*wwv%zq#JqIkJj?%G{;j8q&Ww;W?P z9b-3M+jqk;R(6b69OJOu{A(4?cdrs<9GW4|&MD@LMRi(N3D~a{q*d ZH6KRrq5Z+%&h>CU?1HOd9h?eB{sd4ozP|tf delta 1688 zcmZuxT})eb6u*CNKW|%lOK<5n&~Ct*ZVW`&2V(=tHVK;~BkI&JEWNj^NDKVi%3xhI z(L^4W=ytXYjfqV($TEzZJQ(8)^TlK_CX_Tmui1mq#7r}@MW6Pf|8q-~xSi(y?)g5a zJ*VgTGv5wWe{8b}9HZ5rKTmul-Kf67ldEuxzexgcPIu3#VIBy`kbGb3+{UFezfBX7-}k0<|$v4absc zB}Ub%L|he8lX7k%mQ7BksgoID-WVg#!EeUah%fMdq|m=~w&?3B`MTzH=k0e~?n^H$ z%r5OLy7rV@dzea?XP*@?8Ax~srZ(r&ndzLO%qdzn+RT&@9j#$6>W44QbuR}PJLf!J zuBymK`dSi`bS6e+H#1~L7`)DXFG^SJ7wrXoQEV!SO&enCy4ZT%Q54%tV*9)v&B8@t zp`$2-%=NG*+!~0o7Y>aqp+*b3;PvY5vbgMkPrX6DCU4csSqn~P^*Wuhw5VfWialj2NLo~z z?{NL&Tesw`V^v!F-Toka=4)!knQT(lSm8L1^-^1r81II^eQ(&dcI|<;{p}WS#YdLl zJHKMvTGa=K0v!(L|2o9Z-8}0YJ_z&^3GN2QUHe%Ul_WB_R}Bs2tv&#!gKeZ2-V5d| zgG^7m5l%pT-8MK{ciVCh#qx53qkic%zp9_2$#j;+(SYlw@^i^4dKfv~t4U?1ynaui z8TV4vrIc8j4xo%nqw?`smX4r!1VwU0wT#iEJd%l}rY0166qN?$%v3BbtHvpXrgB+D zHBu#WIyR-i?gne;(-_>3e!_$jdu4WnDib-{hYDjRIhLAC$`Pmb&rnSXC6>*xVEPhF zH$*#kqZ65-r|@lwPfk)fGm>BjOlziGoTp=>^f)@`Ct|Y-rTwtp@Pt9TJRV2|J%%1; z%=`*FLN{BUTt4{K&@HLdd!pDqRO%ioMu*m(eZDB2EJ-Ifq~UdGcx^PkCJh%Qxg^Qp z3GWrcYu3i1wXtMvTzN74O;vXt%SLhLv;jdowA$~36=?_|%wX#+YC|PL5QZDikv90T zv8yx1WYwUZMup<8mQR-EJBlK1nl|-^rUTmDRu0lP;aF2{r-eA~bC1G1;Y*PRsQ>Ff KqYv&i_5B0UUyi;2 diff --git a/framedecoder.py b/framedecoder.py index 39b3957..379a4dd 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -10,7 +10,9 @@ def __init__(self, input_dim, mapping_size=256, scale=10): self.B = nn.Parameter(torch.randn((input_dim, mapping_size)) * scale, requires_grad=False) def forward(self, x): + print(f"FourierFeatures input shape: {x.shape}") x = x.matmul(self.B) + print(f"FourierFeatures output shape: {x.shape}") return torch.sin(x) class ModulatedFC(nn.Module): @@ -29,9 +31,12 @@ def __init__(self, in_features, out_features, style_dim): ) def forward(self, x, style): + print(f"ModulatedFC input shapes - x: {x.shape}, style: {style.shape}") style = self.style_mlp(style).unsqueeze(1) scale = self.scale_mlp(style).unsqueeze(1) + print(f"ModulatedFC processed shapes - style: {style.shape}, scale: {scale.shape}") x = self.fc(x * style) * scale + print(f"ModulatedFC output shape: {x.shape}") return x class CIPSFrameDecoder(nn.Module): @@ -63,10 +68,12 @@ def __init__(self, feature_dims=[128, 256, 512, 512], ngf=64, max_resolution=256 current_dim = ngf * 8 def get_coord_grid(self, batch_size, resolution): + print(f"get_coord_grid input - batch_size: {batch_size}, resolution: {resolution}") x = torch.linspace(-1, 1, resolution) y = torch.linspace(-1, 1, resolution) x, y = torch.meshgrid(x, y, indexing='ij') coords = torch.stack((x, y), dim=-1).unsqueeze(0).repeat(batch_size, 1, 1, 1) + print(f"get_coord_grid output shape: {coords.shape}") return coords.to(next(self.parameters()).device) def forward(self, features): @@ -78,11 +85,11 @@ def forward(self, features): # Project input features to style vectors styles = [] - for proj, feat in zip(self.feature_projection, features): - print(f"Feature shape before projection: {feat.shape}") + for i, (proj, feat) in enumerate(zip(self.feature_projection, features)): + print(f"Feature {i} shape before projection: 
{feat.shape}") style = proj(feat) style = style.view(batch_size, -1, style.size(2) * style.size(3)).mean(dim=2) - print(f"Style shape after projection: {style.shape}") + print(f"Style {i} shape after projection: {style.shape}") styles.append(style) w = torch.cat(styles, dim=1) From 4b906ba2079a71b8d4b87abd372586d4a521ff81 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 08:07:20 +1000 Subject: [PATCH 058/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 10032 -> 10960 bytes framedecoder.py | 16 ++++++++-------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 8be95056d3b88dc3b608952ed86c48748b07b5e9..f3073bb42979235712bf5a9c685bd61623075908 100644 GIT binary patch delta 3322 zcmb6bTWl29b?(m0?(F;R^)CCwF=fGoIN(w@F_2*FG=NP4rc~

zNR};ZEgF(^p&&9}s3*`MUFR(PT+&jtvo$%isFRz(Wo3iYDIqIP)DvUn8qI5TOv_bk zRI8-(rG;8GT{)jy(7u|kC%2-@wdy8()@sNLG+(M3GN&-C#ay*mo2PzImU;mgVkKY7 zT4;d%NgTGGft{-YdwY9y;DO8i&DTD7ZRzxlw_2`v+ZDe?9xCpp=PS1^tX_m0TrFj| ztqeDnVKyioBMH`!&Zi_f{#bZ6xGTwjwvZtE+;(rT6@+mcz6am@v<(XzmiVrqT67Du z#Z&B#Gs5mGejpKvtpwe2%l9+mqu7 zc&rCEF@1ivOfTlBngK-RR{*XG+?x}oHz%5&iF@JSY=awIEoHK;Og5EC_AmRvK7Q`( z?b?s7G(BhG`t6zf=gg;;a;~kMYbxj1OODtPes2G*!PNsz&wjXC%1~PwYAQnyyuND+ z?LqC52>MV5wA5i5#`pep5MYar=Lc-E|FQ7DfXND}X1?VQv%fo3-vM}ebQt-_2oBtx zc3u<7Ap5Il582D2-eY7x)4Y>A200c$9wUzq;?~A)FoA3wQg``UkZBqH_(Dmz$(TU9-XzQ;;QW?sI_8YSA)rbs&wbq zvK7nOE}f_n{%!*G$_9sq+?7IhT`t&IE1D4HWd+H6)!@1?LC}R|2X3+}#PjWtmpGw3 zMi(6u^}ac+IyF-+Q*~;N=G9KN?MYAPsx#-OD)~Ali)SDWg|IwdtLDokNV-S3?I$ zoP8Ra+%*Qc%0mF11Uh^B^-fw8wkAb1tVelLynJhRb-w9&8LpNx(pE;A$_V>T_*F8> z9)wrYRgAjMf?k0U%oj8S(;3h^c8&_~+C!6a{h+f@}9g|t2y-e*5cpY7WeC)H{jOZ|xccJpC z-Hd6LiAF@f?&R%OSazcZir!J1jl?2iKz9K$8LMJN?5qOzJ-)!&K6WLZJjj_U5b#|w ztvl498d?o^?>5|B+V$veF5bNy0P&HqU3I9QZi?j~Dz<|QiELye*I)~?Ud}9;eul+d zx_b^jAFMXAEg?qD`5~Z+4Nuk9`6Rf(Jn?yK+;0Yox8Z|m@ahCi3`DjV_u4j=+m@Ny zY!fWcd_j80&A40NX%Jmj?R99^s`xIiS{08jgY6>gCsa=z1|$?RsDxOZ89p^ARpBN-3*~f`Dzpiv9D`>)Ii%w5# z&@8nybnb#SW7ZgY3aIHT2wp|-7J$0-6n6b|J=&qm%~mzqO*Y6zllRF1=1+aHxT%X{ z|1j)YsaDh|%oT=M&o9ug;imMBymrx$u+kY)QOlL+B+4@j!elSk+qcuGQx7!u;>KHMM|=``M*ykb|TW?@L;x|zC&bZ**^ z=W~}eN>8vqq+hi1v80z36TLS4tcqP?+06X$)XpCUZ=G)M8NGMDm7Q#7CtInL?bOM~ zf^5%l@c!hrSKil`q$TyC%ij!)ee?zZxPJ27FLs$vE!TM4HQscMv)^UjjE{aa*bI!q z#ithj)N&ndyN)(pNAK?2{x{1=3dGYi0-T5_j`yZ%QAN5PK_7t4W6BV2BG|>==^yvn zF yJ$@{1LYR delta 2567 zcmaJ@Uu;uV7{8~tx9x3jyZ+hQu3Nj7Zi7zwyDXyoaer_K3tI#vJczU7-i?lKE2mvx z^)#Rl1`RQbhh!v1tPc*!W|)q|M|tpxh}puTq>;!2iCL7eMG~Kke&6jn=m1Z9fA^g4 z`_A|8{Q2(On={>xD|Wk8fabaXUH{RlZycSXTZAD#f4vPbAuy8X^KzvmFkz7JIj`4{ zBmADp$E`-Uk#xt5e8_m%6b0FH0>5XhE|bGY}WQVY%Tp zmW=*dG%Dyu`F+{V`$?5sC6h!Drr;^IoPpnxW61J8#S)p27s2H!{)07mmXSk3%+S4v zQ)5ONgI7Tw9ZsJb)wHZen_x^Ei##{>11KP+Gh<1OdZ3?1&%)4X6|!ppvYrDY>S!{a z(bUe4P5mSEbeyVnDBBBgPPpT&oN#>NxHLNNY~U@*HvW+kXu4D3Ja3_1)KQBNL3wnl zwWt^qND-X;cV*KCqvTo={tt?(qI{P<=WKy#i>;_YLTH)2X+JNL&HRvSBiYE`b9F)g zmaC_6v(5!b-~!W^aA*Kngpm;QQ5gx%yxqMOre)k0BIr*{rg~!9klBvVfv^$bH6HPF zk;d5}&u2t_AEKEMGDP61>PkOR4a_hJilO|8z%8*$%SR8w*5@IR|UwG1rdZxU8@3RhNdIf+Sm6G}TX z6(yno8$NGABH!Y0@EhI=BE$BtZZdZ-{tJy_HVcNtV2l(jiD6YclNcVjO6V>i>sgIi zdD0hAwxF1aYAG$A;g@~8Nr2z=_1G|-=`R7`)Ex0M{)fMv`1z*lsqJviRy|{+Jf=Ou zu`6#uggGf^dNNpZAQG=h#jYMHw|bq%)xqZhC;uhj z;okr^K*Z6?68?`cg20kTLUt&|a5A z-z^n6Cv1!gMj*W~MfNktr-E9|2)`zW&X(l^Zz#RNQ3Z;7eW1&Rkj0ma|{@pe%TQ$po`MKF<7M zo!KRquS7B9jgWE3H*A#hgYGV~UCZ0TQI91Hw`S4voSV6!Tm<>)@Zox>0w>}rEoPuQ z!FYl6r!nSIA=nN5?2G*4dOw#Um1G|eL=NtXbiyGYrCL~@8Sc}dIB4M%9U0V)XA&dB zuY~CvFps{9(28&vAR9V_<%VC21mL2+8Tpa4@QHO-{RJsS6T=y*p&(wm1u>gQ(Y-t# z^)?+qB5v4%cs!n=aU8ymL(*BW4bX(zn~o2qpos0`Uqpk^f~*aX4rvsM7Tu4sT_|f! 
zYVng}y;Mt%!d?DzG}Lzt$*xSC4rm#mr$>^b=y<`TSHyI|h`|(${SYVJs+&7a-@&)- z_{a!V)4lyr7j;MIJ|v5DppW)oPfEteG)lYpk@{`$eehv@(2V}3ABgLtXxBs?3(>BX^So%g8rIU@Qu#DAPcDMbJP diff --git a/framedecoder.py b/framedecoder.py index 379a4dd..172dc73 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -20,22 +20,22 @@ def __init__(self, in_features, out_features, style_dim): super().__init__() self.fc = nn.Linear(in_features, out_features) self.style_mlp = nn.Sequential( - nn.Linear(style_dim, 256), + nn.Linear(style_dim, 512), nn.ReLU(), - nn.Linear(256, in_features) + nn.Linear(512, in_features) ) self.scale_mlp = nn.Sequential( - nn.Linear(style_dim, 256), + nn.Linear(style_dim, 512), nn.ReLU(), - nn.Linear(256, out_features) + nn.Linear(512, out_features) ) def forward(self, x, style): print(f"ModulatedFC input shapes - x: {x.shape}, style: {style.shape}") - style = self.style_mlp(style).unsqueeze(1) - scale = self.scale_mlp(style).unsqueeze(1) + style = self.style_mlp(style) + scale = self.scale_mlp(style) print(f"ModulatedFC processed shapes - style: {style.shape}, scale: {scale.shape}") - x = self.fc(x * style) * scale + x = self.fc(x * style.unsqueeze(1)) * scale.unsqueeze(1) print(f"ModulatedFC output shape: {x.shape}") return x @@ -65,7 +65,7 @@ def __init__(self, feature_dims=[128, 256, 512, 512], ngf=64, max_resolution=256 self.layers.append(ModulatedFC(current_dim, ngf * 8, self.style_dim)) if i % 2 == 0 or i == num_layers - 1: self.to_rgb.append(ModulatedFC(ngf * 8, 3, self.style_dim)) - current_dim = ngf * 8 + current_dim = ngf * 8 def get_coord_grid(self, batch_size, resolution): print(f"get_coord_grid input - batch_size: {batch_size}, resolution: {resolution}") From 3bceaab501b818624baa73a403a06867ad432872 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 08:08:33 +1000 Subject: [PATCH 059/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 10960 -> 10957 bytes framedecoder.py | 12 ++++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index f3073bb42979235712bf5a9c685bd61623075908..b7dcd66135c6d5e947b6d05a0c1ec928b09321bd 100644 GIT binary patch delta 419 zcmcZ*dN!1IIWI340}xDkvOO(fBQFmJ^GgP%%`zM}8QE%?Y8VzUPrk^hGC7-zUzoXu zDT@)L6a;27q_EU5PhQWd$jAazARqij{!@RiuU)qErA~ z>Ezp7X_MW!jagrUbWg72)|Au%GKzFTgeH)<#aWtHTv(c#T9rC^H@7R}?#YZi=7Jrp zHv~kc^G)KLQF>88^@@P%mrZ#6&~#ssq6DsL43H#7hSd3+QAZck?OU4rWH9$-#oUj0Te%1!t=nfyIGD zCId52g(Zl~1SDPpc?_E3MVuf3v&l|EhCmG^LdzKqHp>Weu>j3b7f)j}-rOf{%?0F~ qQar<`zInP*A0wm5W;KT#0UV|`e!Zx delta 364 zcmX>bdLfi|IWI340}$9h*`5}^k(Y;qxrmivvkb>gMy3VKlP_|L2&FJDV`N}h4a5*o z%QQKjOLDRR7e7=5QwsBBK`vQAmKvrkxT@IV4liVI6qQ>#*obb$OK4Umf6+^&rKCNuJw>s}Yoza*f) zf$^e%(G>xs4%Qn2BGdUM@y#f`D4=>pKowPbvL}xL??lgwJaSif?H#e&}4I<$uAv1cCdm7^UcS3JD3>_CkF`XGU`vR6P(Rv2o|3#FC@=sI@w;x5GY(I zw470Yv!pN=3s9wscp9V8=5BFoE+FT);u%Ks%~O>67#WQ>E32GkVKW73H=X=VqY Date: Sun, 11 Aug 2024 08:10:09 +1000 Subject: [PATCH 060/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 10957 -> 11040 bytes framedecoder.py | 10 +++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index b7dcd66135c6d5e947b6d05a0c1ec928b09321bd..122f67a4813c4f1849b8ff965d6f8218530422b5 100644 GIT binary 
patch delta 1044 zcmah{OHUI~6z=WJly=&}U>%E5YAIltMtp!8gCU9!Q1KbW#6)cd=28bs3wLG|NE^56 zLeXO)iHX#N_=p=#7y1X>7)`7pk<5yPF4(DYCp2h;e5b5OCTl0;ZdMbCBpgR{j0LaV@P@L(wA;~3}4;{U2+3_QdKce@?N+rM^xXW z#E(FJORR!1AtCm8G(V`pb@Cz>UHr-ex`OR65p3;s#q%b;z*@P$G!9x7TZZe^99+q= zCoJ~!EwVo?I zQ(e>2_5I~gY%UZls^!pHmqE8%Sz%DO*${3IGY;!TqRKMUBRSL1B9&=y{XW-Wmwdlb*l)UmOo)CaZ~$<{O2Wc8%!lvKrz|+4e#(8r~m)} delta 966 zcmZ{iO-vI(6vsQ=U7+ojR%mGtN}(-;Eg#`4O^ApE1fp2M5V=WR+9@vd17=%+ruB%% zi^L(3sEM2q3#_$)_v4K7yBL%D#m`VE0uw#JCC$r5)?k~4x%cUsAE$En_brRLD7d&}qo=jJDL zQaSB1wNS;>C6wT3`oh%ONm`jjBnHipPN)lMj5Ao`NfpTFK^D~1atxDdW~)2{wTVsW z1HB|RdG@TfrM9q`ebfCP`4#gzkGiR=YRb_A-Rc=i9E0Q_ z1`zyot?H9~5X>upAwWA|gsRn}s5$?!I?vUXez8S5nXc)TSqW(sN+N=nl*G(7BA4l^ zTIalthV%EGn-%sxC`VBhny9NIcD9iK?WylT<8-cm z$vOc-KVSebng3qDjV7TZ&43f>I3{umYg|T|Cs#q_?`#r-15d5vI>dtzH^#6$w;UokmS!K`=&eN|I^=Z?+=L)M zPm;45AI0p*?xWKd$u#Vya570EYA70G5x9qpg3OVbX-$a9%b1WL-Sm#5vHTBjD=!`O za{=<)01tx-Qx-UW3qObe9}P-5p(nu7n8UVcI&x+po gQCc5)Kzn`ZmKwx{<1AL)YtjMehkt|Zbl*4l8-^D71ONa4 diff --git a/framedecoder.py b/framedecoder.py index 0bb720d..63a4671 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -33,14 +33,18 @@ def __init__(self, in_features, out_features, style_dim): def forward(self, x, style): print(f"ModulatedFC input shapes - x: {x.shape}, style: {style.shape}") - batch_size, num_pixels, _ = x.shape + batch_size, num_pixels, in_features = x.shape - style = self.style_mlp(style).view(batch_size, 1, -1) + style = self.style_mlp(style).view(batch_size, in_features, 1) scale = self.scale_mlp(style).view(batch_size, 1, -1) print(f"ModulatedFC processed shapes - style: {style.shape}, scale: {scale.shape}") - x = self.fc(x * style) * scale + x = x.transpose(1, 2) # [batch_size, in_features, num_pixels] + x = x * style + x = self.fc(x.transpose(1, 2)) # [batch_size, num_pixels, out_features] + x = x * scale + print(f"ModulatedFC output shape: {x.shape}") return x class CIPSFrameDecoder(nn.Module): From 68a6944238453595741e000d46c9c19f89b4a27a Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 08:17:41 +1000 Subject: [PATCH 061/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 11040 -> 8541 bytes config.yaml | 2 +- framedecoder.py | 69 ++++++++++++++--------- 3 files changed, 44 insertions(+), 27 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 122f67a4813c4f1849b8ff965d6f8218530422b5..bf60d380a3d8d45ef950806e8d578b62ee28eee7 100644 GIT binary patch delta 2561 zcmaJ@TWl0n7@o5`GdrDjr+evcr`_%LUOI(B%OwT`1PY}VT2KiZG&Bn{ExWzJ*{w*O z3B(7p8WJj}2}okxm{1|~0YkhzsZS=x2fMflW)os#q6rTsC5nl>=>N}jkuAz;`t_Xi z{pX+m{I}EW;k&z(kG)=*gAw@k^WjhHK2dtz0XHo1_xGa!mpGGTSzN9XIFm~f)=(LC zS%fW$)lB9CJQ>t@HqIZaZ3fwM++v0ILf^Z6Glymm=PIK4ifD%X!n{I1iQIZcXy=1M zQMF7Q)k%)!%0ldv7-9!xpJ<7un03MDs;h`8%z+g*&=Pw9$BxP>J1AF*mdErsa*afH zt&sDW#1yCDQx-9~EDzd3>^r4$H!O+iStiI!oaP?92WD{h{)V9m1tE1Je$t>FuvOq+ zPZ+1^YM9bh0GdPtIIm?tmp66NH7MdurO%8RL#K_z$cc1{24OEH0W>TR8p0D1R$rl> zZ9>NKRUvk!Sj|~AMafR()Xu!xnei^Ft#{Op+iJ(PEjLen zu{x*r=GESe7bN|mOUgy%@>I^(n)kJ4ghgNIl6+CVyd`_;T2s!~nfG;O1d#J8Gu~W9 zU54P?e_{LHclXZhe5e0>zq6yv2=6>^PksK3M8jxk6rmCAM4#+%tjDO*X12-Ot;OmE z{}T5!aP%~@3*I+Et#}aO-VU5$+-d8rsFHD$E_T)5N!GF-{QJlNs}CG#>#|vZ1m+#x zic1&3J&e6pP0Mg&Gn)-`!?AAy=Nr*1t;`8YFOqu^`ViJ4^s`iOAK88DtKbJje9KNX zHt7$Cbx0O-WxBYLI{Tj#pGIFSk0hF857TY(x_h24&AZotx7=3g8N9)xGs7;O%nNfM zcjJ++EN|0nSgw?GaSpzBgRt%4Skwet@n;SMZa6|9X$hu~6|o+D=}Jd4BILP*o9s=Z 
zu=YyDl`O+b?Sv%v4AO3ppy=d+E0NSlK}t;M#v6%=k?Vx+LJqCbWcFqC2KHxlTVNB4 z%G5|rj!mT#lM|5kYt;i4L$IB?3UH45gx9oGF?f$4^CROERUbfOYNxhr!%- zP?hd36)O@hooBU?M&dz_sSrq^R1VbfwmU;TQYk~Y^Hv%QeR(A9)`<&92yta-tuhEp zIhw_ui3WL7c5Dn`Sly+$W0kX5w1)i_4KzoJ$~i8n6zAZLN72suP-VBB94phqly#3T z&6ll68n5Zi`ti0Bq`6xMTUlR2%UKh3$SW-I{?7#~Oz#NDKXFNjyQMrOijUj{cc?1f z$F2ePQ>>P}z#hg{z5Mv4v2EG|QmVYPSfGOox_esWAgxC25=tqi0E)4L*+UnnAtnw- z{cOHo?JyP7KMka%oPV+8Zq>_YI8$-LjCIF$GikRw;#$MNlRno8FrjeV?8W=u?t8#JCa z3hwE|NxC1I4ao4R#I#Z1$Bp;|eGzvtGYanG@iaAXIe<&jTPPo)34JIPA3r%}&_Q;; zIaYlD`Gd&IV@7=R%n&ukrl4E4wZzCKwx^|j*C8bNlXNm!;D-%RxzW~kn(Fws{P^S~ z)l);mW5D2d=pb&p>BzByFcv>!P6g*18|CrLDaW-mXCSj=$x$za@KQp?g=(-;?+E zWW>dA^wQ|X(QL=HV>cUf;f?w5#tc5A(S|#buG^8Wn~_{(TRyUFK~f)xT(pk;)%J|A zq;SzS?BJ@^O%Eils%CccLTv3q*N%m{onO<1V1F*ypAYtDBzT=GeeC;HJK_J9_3h7r zC9Hk#5l|kix!nC;N0#LLjbK4#DC4!CJIzDu&<`<$dIW3KBMX~Fg^R{AXJe0PM+cMkf+pxrTn#!|19b=vA Z2t3eF!c5bwb|3fu_?J9J*xinve*j3&9~uAv delta 4995 zcma(VTTC3+b!MM;mtojlcG+F_;Q>G5^%MNW*u*xN1wR4-W7m$8(ls-XCBWj@C2=;h zT31Rd>RPhh8!5C=+WkmfBKRP2q*zrdHB~A#YNgJkBL%I5%185IwI8<9NR=wJJ$Gio zJBtIo+&yz2=bU@*dE9d@OYMb=?60b;ofMSCzkNLXd;bS)+c#Ifu7UGZ z%-oZW#8XpZQV`=pA~PilX)ZZ2IhEnkFDE92RxTT@oK8(;%I0IXqIFVCPGpK!jRJpW z9yoUaj&p-rLe(;Rcs4OV0w1UfYEe;(j9M0%>f7BnyYCFVbx~o$DifZjHQbijYxA$k zYztHc9Z}H{868=|`z&YZ$o-sUyJf@%d^O3v2NGFda=b}086;_0E6FGGF9O$`IcLdP zB}$WT9X{q@O$9R%bAmSxGUZ!*`DkI^u2bM znDYknk|}GYDB}qv8yTxmzkMV1OEhR5^<<;ZrTD3_L`LA_9pxHNbGy0gpmxVDE}eOO z3@qRwP30--;eVlIx0O>(im4$XofdfARtRN@M57)`5bA7wIZ37MhyjqM7$jg9fTHEq zr0|-!oAhGM!m~hw!YV;poVGZoihb`+2wtk{RulWrMR}LuB~6Xc6{O5G26Ou z?LnjBYExZp)7HfeQ1XED6|wmj7N$OW13pmyaN$o^^p=9IsOXA}t}HXI+g&%i?wp*l z-aV}_QI&}nd4nS4VK@}a8(ZR(U*EACA$VsCFY$gXKm%<7C_*==~ zR_v{6j)fVk`w_Js(D`h|6IIuJ&k#Za9-bgAuqWu=|9922ncj!rtl3RB;SX!h(k=Ku zHND&RX|n*5#0kYV(lr5iNW3EE7ae4v6{lUzK=w=5tp>8Km_t`#7n$rP;4}eE1f0eF z?z8mX`#*QTOXrJEb?J8~;>b^PjLfPpQq(BfhE|BDN~Q{5HUA&DRV4;FOU8`zU^_)h;Py|AO=Xg)iB?YB4_rl}^P(v^8l#ISIl&9p zlM^HF(_#miBSOzcDp(sm+#X{z7svnb?zd`wjyL(Xc@7b-&L^vhVf-Cm@6IzYoL&M@ z^4}L0&Xl}pw8Dp$@Q1!P>b0P8Y<75lOlFTkRZyFX+GNzWi1*Z;>^KGTs|*sn=}ics zr>J|d${hUk@>KcKyg7*bYHEk!yMjoOd;f{G<;(`1f#rd{S|#(->~(|!%URVZMP-P) zlnU^);?L`19SfR`Sv6?O*$v!nk`=hyCHwjT1JnK>?)UnTWa5#83U*=p)*GcTKmP=c)Ip z7_Ehgm`QAhi6T9mCZgmxJ<&M@?d&eR5Dw(`92-kPikW^E8p(8KC^bI$>;O!s>jB)L z3^X=jW5+X!1Lln%S(z6H>Sv}3=t_s5`QOI+U z{*DSm<-xKOxQcNCE))K?^CW@hvcVEzVmKp+Mlup!vyqO}_|@bD$)h!bSAlCv)QKQ2 z`#VVBPEJ6s&S}wGlkstwX^2AtpAiV97Thaw03}d*y>UhjwQ~4Ss4<`2TgF=$zbf#s z&Pdv@U$W$2Ix#*;vUx36%<@2hCokrrg_u;)LgXYKBm?wo zqCmb?ntIcN!=P_%Gy39YapY>zs!2iY!`H+6=m?$*N5;t!(hl(VQXbW*TmOUY@Kv^V zsNfoGCsYN+RTP&|eAyS7NxXGM^=&I$Q+=^%=VF*s!+T{E)KaXz8Fei2_Da%P z3e%}Foifw8h$kc4&wlGa4j$PkPu-X9Enm2|6kd><+ZA_*>h74fX#|I6ZS#!G9)hZ% z!zwx~qr;2%gNV0Bo9JsZ*pk+@d z)P+kQeOGTOOk8E+G813IbJ63COD>;W*R{|L0O|w#C)IjOadoS%ZrRnni2ok#%4^H* zFI;@@MVZ|XRY3<-bU;Q2Kn}IRndZAY3bf*Gn6@p~_@=A1s}i2zo4}kzj(`RN;Qps{ zw;(x;*hs);0Ob$TRx(Jy9x{;v@UI)x4Lq~y8r_PW+|f;G7%JMde+>vCSSmbU3DD@0 v6onI9o%2U9Lc&P;FZlc1*sc%_58JiiPUBm#$7KAUPbOP1*HDL>8g~B=`X-o+ diff --git a/config.yaml b/config.yaml index 27babb5..c8550b5 100644 --- a/config.yaml +++ b/config.yaml @@ -5,7 +5,7 @@ model: num_layers: 4 use_resnet_feature: False use_mlgffn: False - use_cips_generator: True + use_cips_generator: False # Training parameters training: initial_video_repeat: 5 diff --git a/framedecoder.py b/framedecoder.py index 63a4671..bd2332c 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -10,9 +10,9 @@ def __init__(self, input_dim, mapping_size=256, scale=10): self.B = nn.Parameter(torch.randn((input_dim, mapping_size)) * scale, requires_grad=False) def forward(self, 
x): - print(f"FourierFeatures input shape: {x.shape}") + #print(f"FourierFeatures input shape: {x.shape}") x = x.matmul(self.B) - print(f"FourierFeatures output shape: {x.shape}") + #print(f"FourierFeatures output shape: {x.shape}") return torch.sin(x) class ModulatedFC(nn.Module): @@ -29,24 +29,41 @@ def __init__(self, in_features, out_features, style_dim): nn.ReLU(), nn.Linear(512, out_features) ) + #print(f"ModulatedFC initialized with in_features={in_features}, out_features={out_features}, style_dim={style_dim}") def forward(self, x, style): - print(f"ModulatedFC input shapes - x: {x.shape}, style: {style.shape}") + #print(f"ModulatedFC forward - Input shapes: x: {x.shape}, style: {style.shape}") batch_size, num_pixels, in_features = x.shape + #print(f"ModulatedFC forward - Extracted shapes: batch_size={batch_size}, num_pixels={num_pixels}, in_features={in_features}") - style = self.style_mlp(style).view(batch_size, in_features, 1) - scale = self.scale_mlp(style).view(batch_size, 1, -1) + #print(f"ModulatedFC forward - Before style_mlp, style shape: {style.shape}") + style_weights = self.style_mlp(style) + #print(f"ModulatedFC forward - After style_mlp, style_weights shape: {style_weights.shape}") + style_weights = style_weights.view(batch_size, 1, in_features) + #print(f"ModulatedFC forward - After view, style_weights shape: {style_weights.shape}") - print(f"ModulatedFC processed shapes - style: {style.shape}, scale: {scale.shape}") + #print(f"ModulatedFC forward - Before scale_mlp, style shape: {style.shape}") + scale = self.scale_mlp(style) + #print(f"ModulatedFC forward - After scale_mlp, scale shape: {scale.shape}") + scale = scale.view(batch_size, 1, -1) + #print(f"ModulatedFC forward - After view, scale shape: {scale.shape}") - x = x.transpose(1, 2) # [batch_size, in_features, num_pixels] - x = x * style - x = self.fc(x.transpose(1, 2)) # [batch_size, num_pixels, out_features] + #print(f"ModulatedFC forward - Before style multiplication, x shape: {x.shape}, style_weights shape: {style_weights.shape}") + x = x * style_weights + #print(f"ModulatedFC forward - After style multiplication, x shape: {x.shape}") + + #print(f"ModulatedFC forward - Before fc layer, x shape: {x.shape}") + x = self.fc(x) # [batch_size, num_pixels, out_features] + #print(f"ModulatedFC forward - After fc layer, x shape: {x.shape}") + + #print(f"ModulatedFC forward - Before scale multiplication, x shape: {x.shape}, scale shape: {scale.shape}") x = x * scale + #print(f"ModulatedFC forward - After scale multiplication, x shape: {x.shape}") - print(f"ModulatedFC output shape: {x.shape}") + #print(f"ModulatedFC forward - Final output shape: {x.shape}") return x + class CIPSFrameDecoder(nn.Module): def __init__(self, feature_dims=[128, 256, 512, 512], ngf=64, max_resolution=256, style_dim=512, num_layers=8): super(CIPSFrameDecoder, self).__init__() @@ -76,42 +93,42 @@ def __init__(self, feature_dims=[128, 256, 512, 512], ngf=64, max_resolution=256 current_dim = ngf * 8 def get_coord_grid(self, batch_size, resolution): - print(f"get_coord_grid input - batch_size: {batch_size}, resolution: {resolution}") + #print(f"get_coord_grid input - batch_size: {batch_size}, resolution: {resolution}") x = torch.linspace(-1, 1, resolution) y = torch.linspace(-1, 1, resolution) x, y = torch.meshgrid(x, y, indexing='ij') coords = torch.stack((x, y), dim=-1).unsqueeze(0).repeat(batch_size, 1, 1, 1) - print(f"get_coord_grid output shape: {coords.shape}") + #print(f"get_coord_grid output shape: {coords.shape}") return 
coords.to(next(self.parameters()).device) def forward(self, features): - print(f"Input features shapes: {[f.shape for f in features]}") + #print(f"Input features shapes: {[f.shape for f in features]}") batch_size = features[0].shape[0] target_resolution = self.max_resolution - print(f"Batch size: {batch_size}, Target resolution: {target_resolution}") + #print(f"Batch size: {batch_size}, Target resolution: {target_resolution}") # Project input features to style vectors styles = [] for i, (proj, feat) in enumerate(zip(self.feature_projection, features)): - print(f"Feature {i} shape before projection: {feat.shape}") + #print(f"Feature {i} shape before projection: {feat.shape}") style = proj(feat) style = style.view(batch_size, -1, style.size(2) * style.size(3)).mean(dim=2) - print(f"Style {i} shape after projection: {style.shape}") + #print(f"Style {i} shape after projection: {style.shape}") styles.append(style) w = torch.cat(styles, dim=1) - print(f"Combined style vector shape: {w.shape}") + #print(f"Combined style vector shape: {w.shape}") # Generate coordinate grid coords = self.get_coord_grid(batch_size, target_resolution) - print(f"Coordinate grid shape: {coords.shape}") + #print(f"Coordinate grid shape: {coords.shape}") coords_flat = coords.view(batch_size, -1, 2) - print(f"Flattened coordinate grid shape: {coords_flat.shape}") + #print(f"Flattened coordinate grid shape: {coords_flat.shape}") # Get Fourier features and coordinate embeddings fourier_features = self.fourier_features(coords_flat) - print(f"Fourier features shape: {fourier_features.shape}") + #print(f"Fourier features shape: {fourier_features.shape}") coord_embeddings = F.grid_sample( self.coord_embeddings.expand(batch_size, -1, -1, -1), @@ -119,27 +136,27 @@ def forward(self, features): mode='bilinear', align_corners=True ) - print(f"Coordinate embeddings shape after grid_sample: {coord_embeddings.shape}") + #print(f"Coordinate embeddings shape after grid_sample: {coord_embeddings.shape}") coord_embeddings = coord_embeddings.permute(0, 2, 3, 1).reshape(batch_size, -1, 512) - print(f"Coordinate embeddings shape after reshape: {coord_embeddings.shape}") + #print(f"Coordinate embeddings shape after reshape: {coord_embeddings.shape}") # Concatenate Fourier features and coordinate embeddings features = torch.cat([fourier_features, coord_embeddings], dim=-1) - print(f"Combined features shape: {features.shape}") + #print(f"Combined features shape: {features.shape}") rgb = 0 for i, (layer, to_rgb) in enumerate(zip(self.layers, self.to_rgb)): features = layer(features, w) - print(f"Features shape after layer {i}: {features.shape}") + #print(f"Features shape after layer {i}: {features.shape}") features = F.leaky_relu(features, 0.2) if i % 2 == 0 or i == self.num_layers - 1: rgb_out = to_rgb(features, w) - print(f"RGB output shape at layer {i}: {rgb_out.shape}") + #print(f"RGB output shape at layer {i}: {rgb_out.shape}") rgb = rgb + rgb_out output = torch.sigmoid(rgb).view(batch_size, target_resolution, target_resolution, 3).permute(0, 3, 1, 2) - print(f"Final output shape: {output.shape}") + #print(f"Final output shape: {output.shape}") # Ensure output is in [-1, 1] range output = (output * 2) - 1 From b61d9e3d1c12987830f169232d544ac6f5bbe4bb Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 08:30:41 +1000 Subject: [PATCH 062/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- framedecoder.py | 276 
++++++++++++++++++++++-------------------------- model.py | 6 +- train.py | 2 +- 4 files changed, 133 insertions(+), 153 deletions(-) diff --git a/config.yaml b/config.yaml index c8550b5..2b97107 100644 --- a/config.yaml +++ b/config.yaml @@ -5,7 +5,7 @@ model: num_layers: 4 use_resnet_feature: False use_mlgffn: False - use_cips_generator: False + use_enhanced_generator: True # Training parameters training: initial_video_repeat: 5 diff --git a/framedecoder.py b/framedecoder.py index bd2332c..f2d7b5b 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -1,164 +1,144 @@ import torch import torch.nn as nn import torch.nn.functional as F +import math -class FourierFeatures(nn.Module): - def __init__(self, input_dim, mapping_size=256, scale=10): - super().__init__() - self.input_dim = input_dim - self.mapping_size = mapping_size - self.B = nn.Parameter(torch.randn((input_dim, mapping_size)) * scale, requires_grad=False) - - def forward(self, x): - #print(f"FourierFeatures input shape: {x.shape}") - x = x.matmul(self.B) - #print(f"FourierFeatures output shape: {x.shape}") - return torch.sin(x) -class ModulatedFC(nn.Module): - def __init__(self, in_features, out_features, style_dim): +class EnhancedFrameDecoder(nn.Module): + def __init__(self, input_dims=[512, 512, 256, 128], output_dim=3, use_attention=True): super().__init__() - self.fc = nn.Linear(in_features, out_features) - self.style_mlp = nn.Sequential( - nn.Linear(style_dim, 512), - nn.ReLU(), - nn.Linear(512, in_features) - ) - self.scale_mlp = nn.Sequential( - nn.Linear(style_dim, 512), - nn.ReLU(), - nn.Linear(512, out_features) - ) - #print(f"ModulatedFC initialized with in_features={in_features}, out_features={out_features}, style_dim={style_dim}") - - def forward(self, x, style): - #print(f"ModulatedFC forward - Input shapes: x: {x.shape}, style: {style.shape}") - - batch_size, num_pixels, in_features = x.shape - #print(f"ModulatedFC forward - Extracted shapes: batch_size={batch_size}, num_pixels={num_pixels}, in_features={in_features}") - #print(f"ModulatedFC forward - Before style_mlp, style shape: {style.shape}") - style_weights = self.style_mlp(style) - #print(f"ModulatedFC forward - After style_mlp, style_weights shape: {style_weights.shape}") - style_weights = style_weights.view(batch_size, 1, in_features) - #print(f"ModulatedFC forward - After view, style_weights shape: {style_weights.shape}") - - #print(f"ModulatedFC forward - Before scale_mlp, style shape: {style.shape}") - scale = self.scale_mlp(style) - #print(f"ModulatedFC forward - After scale_mlp, scale shape: {scale.shape}") - scale = scale.view(batch_size, 1, -1) - #print(f"ModulatedFC forward - After view, scale shape: {scale.shape}") - - #print(f"ModulatedFC forward - Before style multiplication, x shape: {x.shape}, style_weights shape: {style_weights.shape}") - x = x * style_weights - #print(f"ModulatedFC forward - After style multiplication, x shape: {x.shape}") - - #print(f"ModulatedFC forward - Before fc layer, x shape: {x.shape}") - x = self.fc(x) # [batch_size, num_pixels, out_features] - #print(f"ModulatedFC forward - After fc layer, x shape: {x.shape}") - - #print(f"ModulatedFC forward - Before scale multiplication, x shape: {x.shape}, scale shape: {scale.shape}") - x = x * scale - #print(f"ModulatedFC forward - After scale multiplication, x shape: {x.shape}") + self.input_dims = input_dims + self.output_dim = output_dim + self.use_attention = use_attention + + # Dynamic creation of upconv and feat blocks + self.upconv_blocks = nn.ModuleList() + 
self.feat_blocks = nn.ModuleList() + + for i in range(len(input_dims) - 1): + in_dim = input_dims[i] + out_dim = input_dims[i+1] + self.upconv_blocks.append(UpConvResBlock(in_dim, out_dim)) + self.feat_blocks.append(nn.Sequential(*[FeatResBlock(out_dim) for _ in range(3)])) + + # Add attention layers if specified + if use_attention: + self.attention_layers = nn.ModuleList([ + SelfAttention(dim) for dim in input_dims[1:] + ]) + + # Final convolution + self.final_conv = nn.Sequential( + nn.Conv2d(input_dims[-1], output_dim, kernel_size=3, stride=1, padding=1), + nn.Sigmoid() + ) - #print(f"ModulatedFC forward - Final output shape: {x.shape}") - return x + # Learnable positional encoding + self.pos_encoding = PositionalEncoding2D(input_dims[0]) -class CIPSFrameDecoder(nn.Module): - def __init__(self, feature_dims=[128, 256, 512, 512], ngf=64, max_resolution=256, style_dim=512, num_layers=8): - super(CIPSFrameDecoder, self).__init__() - - self.feature_dims = feature_dims - self.ngf = ngf - self.max_resolution = max_resolution - self.style_dim = style_dim * len(feature_dims) # Adjusted style_dim - self.num_layers = num_layers - - self.feature_projection = nn.ModuleList([ - nn.Conv2d(dim, style_dim, kernel_size=1) for dim in feature_dims - ]) - - self.fourier_features = FourierFeatures(2, 256) - self.coord_embeddings = nn.Parameter(torch.randn(1, 512, max_resolution, max_resolution)) - - self.layers = nn.ModuleList() - self.to_rgb = nn.ModuleList() - - current_dim = 512 + 256 # Fourier features + coordinate embeddings - - for i in range(num_layers): - self.layers.append(ModulatedFC(current_dim, ngf * 8, self.style_dim)) - if i % 2 == 0 or i == num_layers - 1: - self.to_rgb.append(ModulatedFC(ngf * 8, 3, self.style_dim)) - current_dim = ngf * 8 - - def get_coord_grid(self, batch_size, resolution): - #print(f"get_coord_grid input - batch_size: {batch_size}, resolution: {resolution}") - x = torch.linspace(-1, 1, resolution) - y = torch.linspace(-1, 1, resolution) - x, y = torch.meshgrid(x, y, indexing='ij') - coords = torch.stack((x, y), dim=-1).unsqueeze(0).repeat(batch_size, 1, 1, 1) - #print(f"get_coord_grid output shape: {coords.shape}") - return coords.to(next(self.parameters()).device) - def forward(self, features): - #print(f"Input features shapes: {[f.shape for f in features]}") - - batch_size = features[0].shape[0] - target_resolution = self.max_resolution - #print(f"Batch size: {batch_size}, Target resolution: {target_resolution}") - - # Project input features to style vectors - styles = [] - for i, (proj, feat) in enumerate(zip(self.feature_projection, features)): - #print(f"Feature {i} shape before projection: {feat.shape}") - style = proj(feat) - style = style.view(batch_size, -1, style.size(2) * style.size(3)).mean(dim=2) - #print(f"Style {i} shape after projection: {style.shape}") - styles.append(style) - - w = torch.cat(styles, dim=1) - #print(f"Combined style vector shape: {w.shape}") + # Reshape and reverse features list + reshaped_features = self.reshape_features(features)[::-1] - # Generate coordinate grid - coords = self.get_coord_grid(batch_size, target_resolution) - #print(f"Coordinate grid shape: {coords.shape}") - coords_flat = coords.view(batch_size, -1, 2) - #print(f"Flattened coordinate grid shape: {coords_flat.shape}") + x = reshaped_features[0] # Start with the smallest feature map + x = self.pos_encoding(x) # Add positional encoding - # Get Fourier features and coordinate embeddings - fourier_features = self.fourier_features(coords_flat) - #print(f"Fourier features 
shape: {fourier_features.shape}") - - coord_embeddings = F.grid_sample( - self.coord_embeddings.expand(batch_size, -1, -1, -1), - coords, - mode='bilinear', - align_corners=True - ) - #print(f"Coordinate embeddings shape after grid_sample: {coord_embeddings.shape}") - coord_embeddings = coord_embeddings.permute(0, 2, 3, 1).reshape(batch_size, -1, 512) - #print(f"Coordinate embeddings shape after reshape: {coord_embeddings.shape}") - - # Concatenate Fourier features and coordinate embeddings - features = torch.cat([fourier_features, coord_embeddings], dim=-1) - #print(f"Combined features shape: {features.shape}") - - rgb = 0 - for i, (layer, to_rgb) in enumerate(zip(self.layers, self.to_rgb)): - features = layer(features, w) - #print(f"Features shape after layer {i}: {features.shape}") - features = F.leaky_relu(features, 0.2) + for i, (upconv, feat_block) in enumerate(zip(self.upconv_blocks, self.feat_blocks)): + x = upconv(x) + feat = feat_block(reshaped_features[i+1]) + + if self.use_attention: + feat = self.attention_layers[i](feat) - if i % 2 == 0 or i == self.num_layers - 1: - rgb_out = to_rgb(features, w) - #print(f"RGB output shape at layer {i}: {rgb_out.shape}") - rgb = rgb + rgb_out + x = torch.cat([x, feat], dim=1) - output = torch.sigmoid(rgb).view(batch_size, target_resolution, target_resolution, 3).permute(0, 3, 1, 2) - #print(f"Final output shape: {output.shape}") + x = self.final_conv(x) + return x + + def reshape_features(self, features): + reshaped = [] + for feat in features: + if len(feat.shape) == 3: # (batch, hw, channels) + b, hw, c = feat.shape + h = w = int(math.sqrt(hw)) + reshaped_feat = feat.permute(0, 2, 1).view(b, c, h, w) + else: # Already in (batch, channels, height, width) format + reshaped_feat = feat + reshaped.append(reshaped_feat) + return reshaped + +class UpConvResBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - # Ensure output is in [-1, 1] range - output = (output * 2) - 1 + def forward(self, x): + return self.upsample(self.conv(x)) + +class FeatResBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(channels, channels, kernel_size=3, padding=1), + nn.BatchNorm2d(channels), + nn.ReLU(inplace=True), + nn.Conv2d(channels, channels, kernel_size=3, padding=1), + nn.BatchNorm2d(channels) + ) - return output \ No newline at end of file + def forward(self, x): + return F.relu(x + self.conv(x)) + +class SelfAttention(nn.Module): + def __init__(self, in_dim): + super().__init__() + self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1) + self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1) + self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.gamma = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + B, C, W, H = x.size() + proj_query = self.query_conv(x).view(B, -1, W*H).permute(0, 2, 1) + proj_key = self.key_conv(x).view(B, -1, W*H) + energy = torch.bmm(proj_query, proj_key) + attention = F.softmax(energy, dim=-1) + proj_value = self.value_conv(x).view(B, -1, W*H) + out = torch.bmm(proj_value, attention.permute(0, 2, 1)) + out = out.view(B, C, W, H) + out = self.gamma 
* out + x + return out + +class PositionalEncoding2D(nn.Module): + def __init__(self, channels): + super().__init__() + self.org_channels = channels + channels = int(math.ceil(channels / 4) * 2) + self.channels = channels + inv_freq = 1. / (10000 ** (torch.arange(0, channels, 2).float() / channels)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, tensor): + tensor = tensor.permute(0, 2, 3, 1) + _, h, w, _ = tensor.shape + pos_x, pos_y = torch.meshgrid(torch.arange(w), torch.arange(h)) + pos_x = pos_x.to(tensor.device).float() + pos_y = pos_y.to(tensor.device).float() + sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) + sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) + emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1).unsqueeze(1) + emb_y = torch.cat((sin_inp_y.sin(), sin_inp_y.cos()), dim=-1) + emb = torch.zeros((h, w, self.channels * 2)).to(tensor.device) + emb[:, :, :self.channels] = emb_x + emb[:, :, self.channels:] = emb_y + emb = emb[None, :, :, :self.org_channels].permute(0, 3, 1, 2) + return tensor.permute(0, 3, 1, 2) + emb \ No newline at end of file diff --git a/model.py b/model.py index 682aba2..d44739a 100644 --- a/model.py +++ b/model.py @@ -15,7 +15,7 @@ import random # from common import DownConvResBlock,UpConvResBlock import colored_traceback.auto # makes terminal show color coded output when crash -from framedecoder import CIPSFrameDecoder +from framedecoder import EnhancedFrameDecoder DEBUG = False def debug_print(*args, **kwargs): if DEBUG: @@ -397,7 +397,7 @@ def forward(self, token, condition): For each scale, aligns the reference features to the current frame using the ImplicitMotionAlignment module. ''' class IMFModel(nn.Module): - def __init__(self,use_resnet_feature=False,use_mlgffn=False,use_cips_generator=False, latent_dim=32, base_channels=64, num_layers=4, noise_level=0.1, style_mix_prob=0.5): + def __init__(self,use_resnet_feature=False,use_mlgffn=False,use_enhanced_generator=False, latent_dim=32, base_channels=64, num_layers=4, noise_level=0.1, style_mix_prob=0.5): super().__init__() self.encoder_dims = [64, 128, 256, 512] @@ -426,7 +426,7 @@ def __init__(self,use_resnet_feature=False,use_mlgffn=False,use_cips_generator=F ) self.implicit_motion_alignment.append(model) - FrameDecode = CIPSFrameDecoder if use_cips_generator else FrameDecoder + FrameDecode = EnhancedFrameDecoder if use_enhanced_generator else FrameDecoder self.frame_decoder = FrameDecode() #CIPSFrameDecoder(feature_dims=self.motion_dims) self.noise_level = noise_level self.style_mix_prob = style_mix_prob diff --git a/train.py b/train.py index 41f0385..0c412c7 100644 --- a/train.py +++ b/train.py @@ -366,7 +366,7 @@ def main(): num_layers=config.model.num_layers, use_resnet_feature=config.model.use_resnet_feature, use_mlgffn=config.model.use_mlgffn, - use_cips_generator=config.model.use_cips_generator + use_enhanced_generator=config.model.use_enhanced_generator ) add_gradient_hooks(model) From 1107fee834fb4c5668b34d769d7c55b70ec10a14 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 08:31:18 +1000 Subject: [PATCH 063/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 8541 -> 11998 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index bf60d380a3d8d45ef950806e8d578b62ee28eee7..047d1f049331f0fa56c2ecdaf6f063cc29150dfc 
100644 GIT binary patch literal 11998 zcmc&)ZEO?Qnx66J*qQhP$4LkYH6#!cARz&IDdj66A5F_gA<#k!^)>b+IQS!+aVSJ~ zL09cwgK}?OlnS@Xs_tFN4GY>$rB<6C`y)U5(-X~ZtdSxit*RCO!a}91{@H!c8PC|` zI4P}ewa3Xj=bU-Z_xnEYd(N4^R#sXm2=h<>ZTN+oqW%jflwiz9p8pR-?ok0MK!>Oy z`ALU$LpoZH>%%%f#KVRmV?M_;WYSUnR6zF;70`3KQL|Fz13g7Of`2&-F@(ziT*fuH zEQHGhT;?^ntb~gJF3TER6@<$QTor3@u>pFJYGI$^Zl3VxdrKZl-=Dt<$v?uZaMTbT zpoVlKFgt&t1G<2ISf8YZ>44!Q%aA@`{D>Md1fcIj#w2x}9x_3_W~r(-a@iO0bAdkO z3v)f3KN{dri{UAyqbP)@{V7EQxrJl6#WQ{bP$gI(f^(m@0{x8sqv;+M19l}mfXE{y zL1l{GsF{?Oho<6mobp2Ba)|37rHqt{>mjAKAFEKfC~DNC#I^MSzq;Zt$y=_u!idwr4LOC}JEk~QkjhNanACz) ztiVpijkiq?V16E{bEIt4 z$wXX_wwIr4kVIJD07&V_y26l3>(sUgD`-=CsHs)$wkRMl-0_RcDA8EPK7VFMTt>(T;@W1E5@M8Qh)-zV@^E9;^S9k2oktYnCp2P zkk9CQ)CjEPMQ4I8)~MS$ZOxzRCf5y852+SfV*5B>>@>%Bhob%~EqY1Uwo9VDk~uma z^9F)p9&3@ubL*$Z_AnO+`r4yhxV@wEOn3Xh@xJz9to;F5``gBD{A4{43i2_3G(2`_ zqN5x{w0XV3NHFI0wjF@>d?Ef2M3f>m&=kK3$lq>J+1(f9zwdqD9hw zz(gS|uI1-q2fQgi6Qx=x&jezj08B>ZAQu|$ip97{EEtWn=n$46Nss3N>y5{Q@FJph zKz_yeHb60cCy-mzGNs#XE0A1s>t~f|zu?#`IyNusHSZ8!-KYNx0amBt3Wp+G$jb*O zILXAvP%ywr<}qI&5R8l<+{czG$;gk7aY$m6xt6TK$XI?kB`Yj!J|R_(^PJbG3|G=c zB9irZG%z0Gj=`G1>oI>ca?N`Qiv}-M41=U72_y1FMmR|y;v$mCH#Wva0+Q|Q*iqnD zgh{dva#zRkRQf`aLv6+z^4;JNFPSiJXFxI!21mltU_i1C2P3|a7dI(Yy&dI)cny4^ z-Ut{M7+q(N#EwOIFPBfZFp`0X%$a9&#oc)|6CSi2nnq+}}Vfmzc`i=1KD+Q@g;_3QV2I)Mc5vCv{Cp!=ycz ze>Aiv4e;Ve6SdJZeerYew3nopHZ)8h{QU6rVVtV{!CCzopWgrJ&8eF+PNAk%aBdf! z+m|VW#Y@XH*_Cs@lA)(ACi`;Lo{Tfo@s(>Tn(WW94w2os$ZlC+w|pI$eP`}nVQa6r zwO3&KM7A%>_7R5qMYeH)ZT!0X{&7`pAo3rB=d$OAgu(a3!S^63?7kp^rZrs@*!M;D z{Ve%gHc{;QWlC?cFEwqwf4oRc+xE>l=hn@wdmJG``V33)3rrxk#SXD8_GFL9Zu-$m z!C-eTv9`&9)M#eiR3tMx=lh<{4xY)fX9f1G$ezV7-@JWn^4K$@euMP|1q4sxk1W;D zFI)&lq&KXioJ}uxwK?VwO8OLVULXj@TJTnU=<%5 zw4+GqCojT`5no*4Rh}z*AJ>C5GoC;g@d;`nT<{@{s^QU6AvqAoU-pf0UJN11M6VPBP$l5p?8@fMgzyq7Qs1Fi~F?KJst3LACsU1Gz;#srM|_cP!L* z%x)9vUl;3NpX-6&rMVulz6*r=3EPnNiL56{gZ?hE4GTCeupW~2L-VfhT(f7yz5U-; zi38__jv=vQNN9dfY<_Q`kuYW9jXdsBuV9PTMg zx+5J(x6C-Rj_rbDyXe@S>`Hb$VI7~GymK<$C9oR-kD%LjN0M2oqF5n_inbCJ9ic$g zXb|&1f#N`p65{$%HKi5B0DXn1Ji1WKa{=m17!sf$tEpd+@)@-?YE#$gggL>)^%+$` zl|JDo-Vg;0+HGG(RjFdmxB*ns9A_SY(mzraKVgaJ;`)RouFDibL!n&UqU{a1!heqI zWc9$Nhp5L>JOM;UK%BQg9Be$+4@n=4#3V!57rP9`@G6RdFTkPjI4FqWT9EqyZ3W&I zE71V7AD3ek;YWvYlD;JAE`K1=eu=&;(H}^a%4?z>SJ=>kvlIbe*6Wq5 zUT>KAKoDoW-mBxjko*Q=X8|3-1kZs~nRhJwTqxxA0(f;oE{}!yh&)qR1`sBA;rTHj z;6W(A+~!xa1lcyp_Ev{#p)9Hca^@i|nPHm{ z^7%Q`4H*)1DHsYyI3I#-2g}1=4Xm-x57rg%22OjVN?#~A5&?G|HYA7#kipvHH$o}G`qps1p>~l zQdc9n`FbUjvu;m`Sr!QQ`89{~x%-=VvbqmnA2a8W>a|>80V4CTDlQ zvTg0TTQLk0I5QQ&8G0Ay{0uq=6uhL>&Ywa z23{@#81FztYoiPPHamSl6>Z#xL1A?ww@LIFDpi9rUBASK!nXk$P$(e`%{KkZ8)8l~&6iKLkoVIjtn`_CcMJae_lp zUNVmO!eL(?Mo0OmTK#!Et#z|TW7l-ZilTzZLd(;_N&jTPp;B(L!NURJGAgkc=UjU&>=n{IoMPKzK zt;Ei(0zu$x!`9k30fMF#qKqn{#l0!FVNbwqSwlwIq_-9Eiq+-~&I@7fw41a<1)M^v ziW;xLD;p>GoA9OrcAH?;3nzo!FnprGXb;5TaL|A+KM=y%*tJI(>k*F$(M$cMa2SOl zDVh1`a4hV*j$-)b2S8fbfv7g}DBf9n^nI&iS)cW-P%yq%J zOLXo^nwF~EQv>PYna$a%okG=4v1(`1@=~2hT9<6D$za-=8CY!EyU?_E)(6)ET|(0# zvFVUtJ1p7`C-pg-^OMVWE~l=hH>FQyXu-Bwv~9-O>n2Cjy9L`u=r75zugilEDr?`d z?L?Sxc>Y${ccZa$nJPnUR-+5=2v#SueJ7?FYF5LqvF|&fw8p+;Yet~+$luAnV^6K{ znM3hqL#zlc3Rexp?!(qYYbOIb0;bTX!siGH@DZSH6~OSRXL-O1zy}Fnhhba=FQPnH zX-xvCIP_!!%n5$NDex-S1p&C&&;ktpC`#vhZq2Ob{+B zhY@$R*v7-*0U9YB9`xV}dNH@b&jmy13QiN;FoAi;SDMD*P}CQb?1&qIOHdAZFO3fm z$kfhs zeN`{iw23usvv5<`DY$lvuHA5&TI<199!^%%?yvXWKLBdJZ|;=Pcvx&ayx7>g(AX<9 z_KS`E-)|EdPYde@1=krU#ITD@{Q^^;9-4s*%ubQ%%qo8?O+S_fBEN@VoU5CDytO2! 
z%+gkZE zd#zl9gJb$l>ebwA6rJ=L;X<(JN=~bZ;E;ehRf47z{$xl+iIP%`rc}W=o6CO+q@97d zIfKpMilF5h)&)$DOPrz%-dC*%aNnZECd(?lV_vFv%&+ZT!m4Rk{h)|AHPoo;&dGQ> zTq1+NR3^arvBWFz)%3?Yg~E!dt$eH-)*pbmid(f@#XPudo{Q8}QkO%SYTzpR#4Mv^ zYo4IhU;_Z@wOU=V+6JIJKe%^LE1)-XLcZxG{v6&@5T|a!9NaOw^H6Y9_V|d4=agN# zqnL+4GZTjE;1T%N1x-Spq>Du*Q-He`grn*|;S3WOjPT=OL}2+foTKN15x4_}{h4Jv z!h>JOfy;OeS4i;t4#e`lpk#utefTIsc%IR5OiloijQCrV>+%CUK&TJCTpEKMHF`rb za^Xvu2S08|dU&km4%U~o3wf{b56~w6pR!lzthu-G?#2wA>B#ut+Ys_$NYa#ZZZd;Zn!6vG?f5EsKlf_|Yx-}qNn>(6=WyR+?lS4>=?1~k zEIOJO9Xl5sI|WC(=x9$Gf9Ca4)%vMpne{WhvwOZhC{%TcRb5F-`Al+4{MwfXzaSsD z%nbbN@qa%4@Z{r@*}4Pso3oBy!O<%^dKVoh7aS+EZ=Vtzr$xusYkCwqSd0ZqwYU zd3wG>u=R?zUTD?s{^Y|uA3_(MnXb%L!QLX;TV{?5_Vz5(E<2fJ9b4=)B1~>z5m@2a znvA=bsWQZ2GwxcZ)*^D7xaOA1ydIHnBibSnA18e$AJEASI-%;Hm;>@1>LK(`kOZ-&`0YTK^)yZADE1clYf^ShVx$(UT&+QVX0OUxxJBm-8iDE}p(j`IuGgdLXzy0{JV9_xu%j{1vP~poVFl8n*EEG&M@|*3VqSRvuAkTUboRFY0oeoLV~X@RH00;ssFP)+?&m`hjg15*CPL9MzZ9AhG)pw9xh@98KtzZ`3wt~!2Cr#$F%adW8hyfN z*KCpy9u>=ASdj$+eFp2!Eiv5{C)`+*aGR!&%3)L0v`%=$An1t{@$Sz-DfZK zb&q1_ybz4>g4i{2TSU=J)ozelS7E9sC+QbJe!4>~Gv(9H8RrsHyU5fkOufp~%S?T~ zt}Shw@#Gu#q-`owx6D<|HhwL9b$#yed|2LlQfWS=HlKQuP@2!nR|7IPqHrTBHv&x@ zY4|*cuBo80qws)IBbxDNkhu?X&B6cXl%&QD`Ix1ZqT_~)LEeTIBxy4r$crBE*)&Ha z5RW9nA_s$r6+kqr6pH#RVkPG3TN2K$rA6Fae3V0DT185Xb#7$ua_wia-?vU=8+JvY~=R31_*4}ocOF4r}D)wIldX3l0p zxhjS2P}vTAJUxAO=B!>G%)YI#ud3{;&z!b;_X`RL+R-1~RDI+B-1dmV?oru2=C)~1 z+CvUw&`=$iH?EW%9XA{uG^pcHpr@-)p0cKFDSL7oHw(uBS`}Sk%8_*B3MaeJ4f#)k z!YA!!p#wLSE1iqcVpG4)q;vazdCP`=o#3V&=#*37_H)JoHM)n)Msp9Gbke!r1~HFy zU=VH83$7+|Z1;9)1g(A&%BP84uUxTdpDi9w-)Jv?#Iy+H7%{>x`2e0W~+Z|0r z{1c&DLRb>_Ay2ukq&J1o*tNJM5*0%2M}@l|9}sj5<)TlXPtP$B&dkUP|xgX->i zJf<+OrQMnST-h2!)iwRr^Q9-3P-;8X+Rk+ON^Qqd?cT-Oy>oBPPcFQs)b^>hed%)G z?MT~KygTl@zi?;Y%1zF-Dc(-i+nKflU%4k!uCPsM8uQ%er|+K5 z9G~u=>DLQ98T<5ElJ#qEZJl23YO)G;iSnp7Qol9j-bETy>(u;TfP%3ZT^iir~_` zC(mBI+=syEln#u1w$(e(sqUFOFo$IIn5k?PMH)0LIXoC%im{@X`jyR$YvUycCrOzt z`qs*C_73YP7y8x{LAi$VC#~QI3VT9L+e-26#>0Tepww$5!2lGoY#bGKJ&mJfBq?nJ z&<}1qNuSz)zRe~AdO~-z>;p8n-ppQIrooJfmN;;Ev-Zlkm)=R|w(3>&u|HUlU z=fJ!EgjuY&NjYYcYAw>qs~TU_&|znF9L-?CT96VCB{*%YLwK1(H8}h_=vTuIA7mR9VQ*dQdD7$U57S! 
ziT%V)Nb2f$PP??S;rfZSuQkWi<>5-Y9mG!%8zN!Ww7?W?_hgaDK zxY>St8BkrvA+_U>Qis6Kldo@wzlx`A?ca13kDVuCeuld60tw`HmqBiK8Fsqs zM}}(cnWyJpS6bdsTi(F3rsbBlZ`kiXczk7nTR8coSswH&ou8_mpDHZ@wIx7Up4QaP zM()LCV=#Kza_PfL&8~FWQ?4d6rEqO3*9Lbrmrl~h^Q?E~+!EWo$Told%el7~tP6oZ zIGVFy)qP-X}772eGL=|q;kJC&Zw4(9gGP5!QHv8hLH>dAYnXUmqn zZHwNvZ{MDi<_90%T%Z>Ye1BBg-LLNMSG?y`?>X6fE??i2qksKs`ds?lnyWB_&0rL) zlMY7z5(H!Rj>RL-gYO1_2zp-(dZnyen};oVZ6;pB>w?d)@z%#SbI(_y9zn3F4W2oY zu+69iVn24X@7fx4@(#02&Kdr=&>#LM5u3or?7skn+jF=-n^GL366*=zE~1Q0=>e2H zGD^*}7&qcy=5jEg&H-^hv?69yvxLUMlro_xFWiFD4B>+h zF%dz!HYSF6M8tvMbMZ};*q1S-xm#`SR_c0G{Jx=7ys1{aDKl^CXU5Q3m^Mle_3`=d0qfuej;D*TY8)!` zR9c#X_-FAvrL1EV=y!=O#tlS)IB?v+P@Ccma7v4pKfp0ba$4+!I#Ljb2AtwFYgLL# zGTZ1q=seVIspEhRO=E4#XEXC5k~N>_T*=xBdjpur3Mbe!0?gC4^;t?XW=@)#No|cG z!JaraT80#0rJR9)i!*?owg6@Y$rj@{DRIRj;v_5xNJ?_Lfwf6lu^FxsZJ?07Kw#;UzHwv zWD221kvKFi9uUzD75$p#AWdj3%egOfSybF>mP zP0+PIAtx00A;hUUfWX@%300hUNi!*lHPvo%l0hWmboNAi0f!QnSj#qM% z>JDGpoo{SW8@pwu4#Go~)oI7_u07vX%?+uYr=DC>eDABi_jNeAM%i7hgiODmK0cky z*fRV|b#3NUw)$TGY=8Q_e0AOZ{xABoQ**s@iGP^-`>BQ2hpES@g~Y#3{max3t$#>; zpHjNdE7cd&>I+NNmlmrp$-nqWuD+yHf2>x2oPKYGt^EAL-3!?cg>6yU7MXm}rJv1= z=HQ;yuCVPY+fHItl{HIT`y$t#8IRUt9-#J2Ky@)ErQ24x}-p z*4VPt(6iXkGvA;zyrVX}BQv!>I;h5`uigO!t+K64?4Ct-&s_cdrG?6cOA32TWsl|A z$~@PQuWw#+*&6}n19^c2+OB!1#(hf-uP-*dK7Slya;*!4O2fNq!@EljrxzPe|KOH~ zK9oPYqBLAp8?NG53{_n>dsuGnl6#KHO~;>za#g=l)vs3drx}Oqe`y62B8Z3h>d4$4gj^Yty+S8}y;`<41mwZ1dsT&bzc zp1v2!Sbu`6#%xnoQn=kJw|j}(v&g~5$Orr6Gat)Wt}5IoD)-3}CoFP;JUS+a#^rEC zj!h`sq{>aE9kAI|b=m#*_U34XYbBs4T}~j6&ndzibV_VSf?*&1l>j-qXko<`B&|S7 zV~*rLjN~L%#DKtMukgvYzbTXN!iYRLB!3!|CnndZIy+pCwj+m}_QPwGK`M{fYu2cj z5i)lXXJ9~N^3t{|8r5&wRfWf{(L|J7uHjvVyOA5wbv=v97Z zj)0|v*MUdHzk)}AGL)nOS+mkKou}A4 Date: Sun, 11 Aug 2024 08:34:37 +1000 Subject: [PATCH 064/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- framedecoder.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/framedecoder.py b/framedecoder.py index f2d7b5b..356c4cf 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -5,7 +5,7 @@ class EnhancedFrameDecoder(nn.Module): - def __init__(self, input_dims=[512, 512, 256, 128], output_dim=3, use_attention=True): + def __init__(self, input_dims=[512, 512, 256, 128, 64], output_dim=3, use_attention=True): super().__init__() self.input_dims = input_dims @@ -128,17 +128,16 @@ def __init__(self, channels): self.register_buffer('inv_freq', inv_freq) def forward(self, tensor): - tensor = tensor.permute(0, 2, 3, 1) - _, h, w, _ = tensor.shape - pos_x, pos_y = torch.meshgrid(torch.arange(w), torch.arange(h)) + _, _, h, w = tensor.shape + pos_x, pos_y = torch.meshgrid(torch.arange(w), torch.arange(h), indexing='ij') pos_x = pos_x.to(tensor.device).float() pos_y = pos_y.to(tensor.device).float() - sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) - sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) - emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1).unsqueeze(1) - emb_y = torch.cat((sin_inp_y.sin(), sin_inp_y.cos()), dim=-1) - emb = torch.zeros((h, w, self.channels * 2)).to(tensor.device) - emb[:, :, :self.channels] = emb_x - emb[:, :, self.channels:] = emb_y - emb = emb[None, :, :, :self.org_channels].permute(0, 3, 1, 2) - return tensor.permute(0, 3, 1, 2) + emb \ No newline at end of file + + sin_inp_x = pos_x.unsqueeze(-1) @ self.inv_freq.unsqueeze(0) + sin_inp_y = pos_y.unsqueeze(-1) @ self.inv_freq.unsqueeze(0) + + emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), 
dim=-1).unsqueeze(0) + emb_y = torch.cat((sin_inp_y.sin(), sin_inp_y.cos()), dim=-1).unsqueeze(0) + + emb = torch.cat((emb_x, emb_y), dim=-1).permute(0, 3, 1, 2) + return tensor + emb[:, :self.org_channels, :, :] \ No newline at end of file From 761aae78c702597e6c7b9ff05d0a062c10995be0 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 08:36:27 +1000 Subject: [PATCH 065/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 11998 -> 11818 bytes framedecoder.py | 37 ++++++++++++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 047d1f049331f0fa56c2ecdaf6f063cc29150dfc..60f4632a1585b4fbe3acdd2bc23ba6bf4b8facf8 100644 GIT binary patch delta 2395 zcma)8UrbY17{8}~E(O}#7D`L`TP(D+2w0o~Iv0>?<`OjyB%0xBZ7*<1|KPbTW28i0 zOw9hcW#>VbEE!ue{Da$EnruF8?rlr9r5HBNeb{`M51J*LiSA)9+xOjDK@rW)rN4X6 zcfRjC-`{ii_1L8m%Xy2XjDzv_-LFQoUs&dd%Xo*=aU3-Q@eW4-=J2a8{~iEbH0 zq%FK*o1%AP8N=g zupe!W?3RsZf7rHyK97ALY0XdAzb9VRNHrij>PGFg07DmvQX-`$6VaGR2iS1swyHW@ zQ*CI%uK+?DTd3Un9QvS?Lk42w(Ns!I$cj0g&8*He!wkUW@p$I6bk?&H#|^OKH%HbR zLrqbDi#jbFyWsn=ua4B_-Jz{4Qmyj2DEpz0`(^0bvn zHDMy;-*-wRw(2@P0&7F`C?F(gLr+mSf^ZxG%?;>Pa~0D=w>3xgu=I$ZV)Z5P)8@|j z$5a`G>{Sd@Ok`CZbR3B5d|h&fJ*uq|tUck7fJA{=Ylbw31l1uYM6xFTUCpbc0aFwRfs;kr;)FR&=a8Ih)C#W)sCCZk1PXcuZ%3u&Dj=j)cB`v&Vz>mqAyG=SW= zAKeBJC|3-!1X^XJ>_d1FA%mc9sXviS(iBXU5;;AZNk*q>78j=h0_I}!S-0mEt@!)% z^PX<BbT5Enh>22yCUD-% z_BQ*xYD=o~H;|}CV}13$)BH7L6&ii}FA%76T#iftuQo=2c~%FHe%6q|Fc!xI9s=sf z=$XB?%5GSMZe4=kqf}<{SsUWyvThMXmsZ6`KzP1{U(Uh>?BdbFnF*iv3$7 z@UzAozwrV}l(3U6RqUMJYRGXDCic0nf&J{amCc}`B_O!*x<_3%J>~*-j5=~93uTP7 zdf9P*i+NNR*Iza+f{9DIqV`Jns_7gzu_k6x#ah@2zjq@sd#LOw`<3R(9^K}l9Sdgm zt>1Yn&c{pQrbTdQ2~(wx+43C8mB+a_cUl+cz;nED)PZx2ia2V(V^141@EYU-r3Vz9 zGy%7hNlL}VX(=^E*Qy?CouWrt=?}oysenMkNKwn&l!TLfj7o9(HqaDZCav&saY}-t z^ehrevnlyxRurLC)C${yM7=Df;A%vrJY0@)no=9k=mCU-C}EsT%aLjI z#p1zT#hzj3gFgOUD5^P*-3_*y3x^IDT*J$*;T6{~t8NSLzvXbBJNVwgh26^z|BAzZ zq4{$AjbQ)vVE?tI;ll2M z(77yht_YnsgzoD?_f>K=bTxX-zAOx^2m@#IH?39k`uTo*3XP!!!-^*a?5Gj_Fc9RJBr5-z|fnr5z3h`f#e@xFORe=}iC;5-t^`zHHz?vsvzUzbD Mdq}^(!J?|~KQ~Y$>Hq)$ delta 2645 zcmaJ@YfMvD9PjDlQfPZip{0U6ih$7aEaH^u93Y6BkNF@Oi@I9d3lwRq=N9L%6+X;_ z8J7_ci7c5*lQ}2i2bP-Xl4axTk|k?RGu=&G_F>7wm&Ij#*!;Br|FobWdM^Fld;YKU zf1Lk0+&=!>3DegmQ=SUO!s5MwrEg3Vq(Hltp8B)p-DgYD`P=r9p11dBh=eZt_Z^SFh z9*?Ytxr2|6iHxCApp2EVxUMWQsJ}wMWY7*|iVIS7d6v*3csUJ9K9%Gl?-|K6 zAbBIBFriJrP;#ZvXi%cwh@@!70%zzZ)T}~4kEj!6bdfOi4JjI*H$oeL;MCJ*<}y~3 z7WR&@o>&qWj2m(NYo>j~j*jy67gwAB`5N}0SYNssmGfvbio6b>=(A4w>6@&>+-l$% z+6tVp8VJnBuA1+XLe`t_YQSh|2LiGcQ#La`X=KFXfhb#nDnlVu9n`|E=T{kDOK!ry z1$VNqpZ%2IFBGD(mW4{XNo^u*`H@s{S9zAybOUlr0QzQ$k`oA>8S(lg+Qp6)ZnbSh z9(TZjUv&sB_Eq5)(#(D>tg!O6wLpdXEEk6vtq;NVKdhrwr%?VRz^pnG`cC*>wI@1Q z)V6yEIyfuPL-3-95sn}{3;ZZf`v8{X#Ca_T_SE)XCcs)ApegZ1QGobX0;KPPSRd^N zgkr24588Pj!GmxTz^UQxDth0LHx!aa`_vzGK_0y6P{Y>7n4c~XwWauQxFr(tazF7{TbyLCH= zD7rC;hGj)J=nV$F**K^hE%REf3E@uZ9PHlyvP){uv+n6aXAuE#~)31^N*J_*8(H!=B z9rvJUK5=TH!TNJkh+>QqRkvZdQt0T4M|E zBtAe&Tf96`#qo=6Ezd8--2i*RuU@ko4B726{t|(xVGK+hg|q2k=IZK#DP1|e#&QA; z!;QNF!<4rt??+3f|8}Dw?>~JI?zcf2@Y5?mQ`C{LBKW0q0XW4!L4qI!Li^>>Amtms z9Z4EF5Q19`4vO5-kbHJjlAyuQp+(*ym*Fehxrz{xLUNc=ED1e;aFG4!+@I}b*6!L| z_6rndLiNaEzgB*=JWk^7xc8PNWvNeF>L&#zx$1nYVsB{A$N1b&tiuE2n4NGv2uMM%lb2RoR*>*^(05(qh{~aoYoN z+kN3XQ%c;E7WYhQ9#uOhHItqEIXel@RMS-3^cz!6NuwjBuY8eRJ>3W*%WTmyRXs!a zO$?@`*nMk^#C0X8*ENDrK?uTOg1F`*n<^978WWk*hd(vVLENixXdhzl_@X-?Q 
zjGs-JooTc4*6x(KIjL`^FF*DhF{cP0!cWGt_1_6dXXbgL7Idt+!M;n>wLDS13NkC8 z9j7U~FZe$%O`#BVu#XyQdV;{q9FD?)uzz$!qHsE8?@<$e@q-sFP$L3&8ISCz`kg_p fu#0|~_^Uxf+U*3qc@bt>KX3mV=}YS{xxN1Z4K_{R diff --git a/framedecoder.py b/framedecoder.py index 356c4cf..e9f0193 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -3,7 +3,6 @@ import torch.nn.functional as F import math - class EnhancedFrameDecoder(nn.Module): def __init__(self, input_dims=[512, 512, 256, 128, 64], output_dim=3, use_attention=True): super().__init__() @@ -38,27 +37,39 @@ def __init__(self, input_dims=[512, 512, 256, 128, 64], output_dim=3, use_attent self.pos_encoding = PositionalEncoding2D(input_dims[0]) def forward(self, features): + print(f"EnhancedFrameDecoder input features shapes: {[f.shape for f in features]}") + # Reshape and reverse features list reshaped_features = self.reshape_features(features)[::-1] + print(f"Reshaped features shapes: {[f.shape for f in reshaped_features]}") x = reshaped_features[0] # Start with the smallest feature map + print(f"Initial x shape: {x.shape}") + x = self.pos_encoding(x) # Add positional encoding + print(f"After positional encoding, x shape: {x.shape}") for i, (upconv, feat_block) in enumerate(zip(self.upconv_blocks, self.feat_blocks)): x = upconv(x) + print(f"After upconv {i}, x shape: {x.shape}") + feat = feat_block(reshaped_features[i+1]) + print(f"Feature {i} shape: {feat.shape}") if self.use_attention: feat = self.attention_layers[i](feat) + print(f"After attention {i}, feat shape: {feat.shape}") x = torch.cat([x, feat], dim=1) + print(f"After concatenation {i}, x shape: {x.shape}") x = self.final_conv(x) + print(f"Final output shape: {x.shape}") return x def reshape_features(self, features): reshaped = [] - for feat in features: + for i, feat in enumerate(features): if len(feat.shape) == 3: # (batch, hw, channels) b, hw, c = feat.shape h = w = int(math.sqrt(hw)) @@ -66,6 +77,7 @@ def reshape_features(self, features): else: # Already in (batch, channels, height, width) format reshaped_feat = feat reshaped.append(reshaped_feat) + print(f"Reshaped feature {i} shape: {reshaped_feat.shape}") return reshaped class UpConvResBlock(nn.Module): @@ -82,7 +94,12 @@ def __init__(self, in_channels, out_channels): self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) def forward(self, x): - return self.upsample(self.conv(x)) + print(f"UpConvResBlock input shape: {x.shape}") + x = self.conv(x) + print(f"UpConvResBlock after conv shape: {x.shape}") + x = self.upsample(x) + print(f"UpConvResBlock output shape: {x.shape}") + return x class FeatResBlock(nn.Module): def __init__(self, channels): @@ -96,7 +113,12 @@ def __init__(self, channels): ) def forward(self, x): - return F.relu(x + self.conv(x)) + print(f"FeatResBlock input shape: {x.shape}") + residual = self.conv(x) + print(f"FeatResBlock residual shape: {residual.shape}") + out = F.relu(x + residual) + print(f"FeatResBlock output shape: {out.shape}") + return out class SelfAttention(nn.Module): def __init__(self, in_dim): @@ -107,6 +129,7 @@ def __init__(self, in_dim): self.gamma = nn.Parameter(torch.zeros(1)) def forward(self, x): + print(f"SelfAttention input shape: {x.shape}") B, C, W, H = x.size() proj_query = self.query_conv(x).view(B, -1, W*H).permute(0, 2, 1) proj_key = self.key_conv(x).view(B, -1, W*H) @@ -116,6 +139,7 @@ def forward(self, x): out = torch.bmm(proj_value, attention.permute(0, 2, 1)) out = out.view(B, C, W, H) out = self.gamma * out + x + print(f"SelfAttention output shape: {out.shape}") return out 
class PositionalEncoding2D(nn.Module): @@ -128,6 +152,7 @@ def __init__(self, channels): self.register_buffer('inv_freq', inv_freq) def forward(self, tensor): + print(f"PositionalEncoding2D input shape: {tensor.shape}") _, _, h, w = tensor.shape pos_x, pos_y = torch.meshgrid(torch.arange(w), torch.arange(h), indexing='ij') pos_x = pos_x.to(tensor.device).float() @@ -140,4 +165,6 @@ def forward(self, tensor): emb_y = torch.cat((sin_inp_y.sin(), sin_inp_y.cos()), dim=-1).unsqueeze(0) emb = torch.cat((emb_x, emb_y), dim=-1).permute(0, 3, 1, 2) - return tensor + emb[:, :self.org_channels, :, :] \ No newline at end of file + out = tensor + emb[:, :self.org_channels, :, :] + print(f"PositionalEncoding2D output shape: {out.shape}") + return out \ No newline at end of file From ea4a32a4582738bff6cbbe44ea9cdd624acc2e9b Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 08:38:16 +1000 Subject: [PATCH 066/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 11818 -> 14504 bytes framedecoder.py | 19 ++++++++----------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 60f4632a1585b4fbe3acdd2bc23ba6bf4b8facf8..a1b180e9e4b5fa56c0fa4c4d979ebd11a58ead52 100644 GIT binary patch delta 5513 zcmbVQdrVu`8Nc`Xb!=a31IFfMuu}pUJQC6X;Ss=u@K9bQ=`&EZE;eAA7;>-aBKLZ( zv~HtSq>>z@qP3a$Bg#lXnp>EfN!8l!rD@Zqd#fuovy?S|blSA_k5rAcseiQl&b>Cq z2GYm5htKDJ=R4m$kKgw=+#FlEVEK7Lfr)}L@t=1l{$Bo;<+nQ8fj`uJOqZ{}t^XxW z8`h`LoAjUoL^SaC`9+|wQ&F0XQA>~xGA~mC-A`Qy#%hL5y+K`Nu7lj_rYue2XG~?- zXR;?WKhl(KO2(NSecGT7^fBK=pTGq5Tl8ph^cV!qG?V#A?rg|%D4G(kR5|etbJ2mE zacK8(>aeJhqOeD6w+Ym=KJx)+8}M7EI{J6`Z>F059F0an`vw&><>}St?+!x=dVvwN zL30eWmQnaUlcP3n$Wg^AstN={&eF8FM^lNo{h)!GHsW@3U1D?BDJp1D23hk5<$pg~ zl35HzO`9`qK_eI@nV@;ib{%$eb#q7Kxw~DMyT%GK95#L1A}{@3z#rEu+uoM5vZB20 z*#;g#bB6L3pN#qNvd*4xhaO{v%-6dp5V!H-EOW3emJ5jw*P2kDi9gR z*L7C>GuDu3>N+CYN<93eTL0*|6?k`Y^Ods`jip3m z`RS1I)!=toOzWpr*s^ElBBRxo{+!vH_L|@0kp+xHeFRKQ;dx7w8vE#b6rs4VpltFETPPN_FMEJMP0||% zwI#6~GTR}t9e0ae3mszBz9gOWyrq$=4obx>a&gOizcTXR(ui1e5bBC{&7`&@wnb)J zM7HIwwOF=#?omv^5RFe3I1_hlC8A?sr4R_zTQ%=Cs4dAhDBA`_+u%QJC5y(T=B410 z7o=>BlC4p;HO~8$PxS(;s26JT!p8U2mc;sH)-Ng*HsZBYubx`+OKg?QR;9PBX?f%t zY(o=N388&5+9#rYc%i_}{s8<1@eCvgVUMZ2qMv&D<`|rEToczuNz6+ZK|76qZE`jR z>E{XVDQ*lnN{hUWfn;14*MqxOf(zkMRMdz+aF)<|?6XwsP#od_UZ{0oXO$f%El!V4 zFbGI6Ue|(Og+H~#O%aFzag)F-Bivk7a>BIDPZLL!E)t)wVbW#h_7m^c5%r(y5n>&mrBkC+1YT9LdGFLC!wzl&im)1 z_yeo3S2_NH<$zc;05$27oBP$K#16{rpr};bYirdHgIcZ|3$`ZM`&HXelS@~4kJ=>N zW?kt$YIm~kDQsuxV|cXqFZ83>YyWdMIURz9&p!c@pZ`EO*$g_lc?JlRGy{vSQb>EO z&2c&>frX8nNL5BhsZVFb*k;v+0RInU{c24ja#zIL&;#tDxfTSI1>&nO) z7-uU~3AM?nO+;LPL{UOlnok=RO^ ztxQX{tpwI4MN1o02_2HrArT$AZMDrK61Tdj8+z_Z_&oO%A!iA35u%`Gl(fcx@Z?N; zRTz_i&Jl73e_8Tyy`G@yK|~FY^*B=AzkbybrP*g-?4#VXK)ndRR$Rt%6VT=+37Nuu zrG5G^H2ELnCVL(Jq|}A~TKYs5bpL|B?+m&`UV6<4%!A>&)qzbD;44L-1%{-hxy~pr z3DiIimv0J*oEw0W!~rnNNod~%mad#W@EhUa9K`GlBr9^d^YEz4!()m*Fp$!6p>uOy zZ5EV@2UE1BG@yVdt0b;VJ?&wfEGtX!z~iGp(kMN+(!2(MItW!lEi!5mQOoUO=j-iP z+LuNycP(~-3nOVk_sY~-BroFLu~;unEQ~CbULKPydt}R=d5t35wsK@`cpFZO=1rWP z+{(BJA!GQfvQl~qmzFnq^txh#yxqzAWO*?|`|*pF$MphP0fCeY|GV-ht?xsy`qn6- zvE*{#_X{ofs>_O*suJRk@Uyi}4HIPWOqByaVsSb&LAbPC2M4)IoTO6XdI7&yt9Y+1 z;$`^lD#sDMV92{mpy-yEHci&96afX9HGQ6}QJ|L15E=L4K&?ZWMs>BKS7Q>4WI8%! 
zI{19Ghe+US)#Wsc->Y^ca-6S0FmJz9nL($1t#hK8#3IArksop!%U{o@Aj${Gc zF+xZXwDsx8>n2wT+Gm~&;ahlTd;aG?x#>G8C$SFczB*YISrFriBd5)klW#DHf zqUQq_IdZaG0*IGQgBo|a`*qw4z`~veCq0a>y31>R2(SE6Fvv87Ta+B$(jr<~B-ARS zRuQ$X|HjSI4GJ0$t(;kVO0*n;Dxr26wTr0zwpBqR`Cb}~AA3sa7xDi*P1P*H;qGMp z=$;K4qUaY7bvlj{1t)-{wBw=hIqm{^_v2$Vje4cIfQ6b4;}ziYgy0LP4BxI4Adjn8=IEGUg4%o&Q9t6&R~I+tsT!&WRp=Y2>f5Hul&dN;L%}GRcAQZG z7_1!ybwNWA1&#R2`fA#OtGuET++&v2ZLNIyo%cxrwMo%|oNmhDfrZghzjKU4|z$*506ee0oy0R6bdzQDs(3Ldd) z|58-0-VeV?kMfKp;SnoZWoMgYX-mhXp5&1Yh)F$ACDbdUUJ>UQ delta 3033 zcma)8ZA=^I9lz(!cfLCx*x+}-V8V+r5J)JHkZx6+WGt1cL=8z-Dyr7OhQPoMk4=`M zJCRM=WTk2A?m-Gkqt(*XtSMQSj!oS@tlPBDQ*mXr$Wql*s??W#AsX7(?f*Z=21<$c zocz1z{%_Co{NEnGJ@?L7^>3=GR1U^}zIlIQ-u+(nKSbhXABmrl`qke{egyXJrI+tU_=3;m2YdT^aJmlwsAI+G<;B`L6Z~GQNo`6CqB$j13E^XbK z>)f=Z`~`E78MY7^V{h47Ri4u&o!2cffmyXi7jSVe=Aa&6a?@5etAz$?%2bYg&2!ue{7SP4>Pvtt-2H|itFLJY+}8rj(SjDtkaf-Ri5B<>zw=tqB(dMt z1ol~#?hOtgueD?%d3#G%T2|K`o~)FedtF&lGLdy=ul4P_Z+8~#%^Mu48eI4GX43_4 z_uEewyuJ4()>zOQVPU%?BU2Y^>h|yJG($DX zK{&_>?)g^?Aq|xB(1mHj&N`c0;9aJTL5K~DAsSM8N1)h6pPPv6cdZTTWnx$jSr>Am znE+oo-9&YfGga()18smT@4tKE)y~**os>h*0fIclxI+inEEW^=^!j6BOeZ_0%#{pA_p zT<5kY>Q;K{gnba~SX=;+ooO_}(MceNg(4X)T4;P`-gFtg0L0zCat%Wrf7?^%QM4K0 z3iqkqaqUtzy4?C^_+~g4xzT;AJMU@F+d~C=XiW~anckvPAWP4o76CX~ibXXN5q`w- z`+C4h9q6RJVKvy|ca=Ii#0(py>QXn7ROP({VY3z6p$F#)sP=F;v9={wT zcg1gYP6cc#Wt#(g+f;YjEL+!AU19%KJ-9)3zQf7>9`vx6JL(qATL`6e8=gof7-qw! z+n_s0+ne)opqm|<4!{ihvn4gP5T_|5KZ$^~M$-sG07c>a>?}>-%$iD0q-Ud-={zzo z1BA75QL*dc<6==r&ZpUDVQ;%xRPK@Z%tR!ejwjMn$waxPW=B&CnVA9%7&D z?~=^vRVMcwQP5ebUjmr@CDyZ~ehle=i8Rfz1HEo zyWZ;vkh}bj)Z)fK7w-pMG!x5HH(j|Xb{fgyt>9uW9K$%h|NBR#k#!b@h2{Cd#eTO|%{ zaF2qLIaCT)K}Gy)uwP1`nPJv;=n3^Lm||y6b+Oq)Cl{qB2wc{?VY%Y5L@U)HaxT diff --git a/framedecoder.py b/framedecoder.py index e9f0193..8cf2196 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -16,7 +16,7 @@ def __init__(self, input_dims=[512, 512, 256, 128, 64], output_dim=3, use_attent self.feat_blocks = nn.ModuleList() for i in range(len(input_dims) - 1): - in_dim = input_dims[i] + in_dim = input_dims[i] * 2 if i > 0 else input_dims[i] # Double the input channels for concatenation out_dim = input_dims[i+1] self.upconv_blocks.append(UpConvResBlock(in_dim, out_dim)) self.feat_blocks.append(nn.Sequential(*[FeatResBlock(out_dim) for _ in range(3)])) @@ -83,21 +83,18 @@ def reshape_features(self, features): class UpConvResBlock(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() - self.conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True) - ) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2d(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) def forward(self, x): print(f"UpConvResBlock input shape: {x.shape}") - x = self.conv(x) - print(f"UpConvResBlock after conv shape: {x.shape}") x = self.upsample(x) + x = self.relu(self.bn1(self.conv1(x))) + x = self.relu(self.bn2(self.conv2(x))) print(f"UpConvResBlock output shape: {x.shape}") return x From c796b23cbfdd646e6079371cc3e2d861c05b3ed3 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 08:39:40 +1000 Subject: [PATCH 067/152] 
=?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 14504 -> 14712 bytes framedecoder.py | 27 ++++++++++++----------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index a1b180e9e4b5fa56c0fa4c4d979ebd11a58ead52..ffd50ffe20f6fde3457d306fa6a9a970dff0b02d 100644 GIT binary patch delta 2756 zcma)8eN0=|75BBD%`^T00uS35f&s&7KFmj=kg83Z&@2^g!V;XMA!%i~eh+Zt54g|H zs0)K>?GklqiF)WlKct~e+F;GpOy0CfRi*qpO;v4)Y1@;~G^vvo^`CT`rYL`GJLlSf zl748{_V3+u&OP_u_q*qu>-E`t-2F3`tJ;L^51;*J((eDI``1i9VKJHBVpnW%xc3zJ z%^!BYz97v@s`-*c1G!_K%}YM&il$ofR+A}R1%EadtkNrTwjOQ`G{O6p4)z9oV(~K@ zd}c{0Rl6SZX4R@PmEDvfC0%4@_?;AT+sE!OoR6Ck?wAYzmd-KehLbjrrxrDaMNOo1 zvIPRXp2ju6%eE9#;EwGPXsr6jNGmA=1Z@ZgOBl99I*ML=56SjuG@gveI0qZNtpPp4sGsJEm0za1{jx+)vmP6uA#_?sIMA&4&pd zMfhqbKQz|_aoM9a?=mlJ)jTtDkemh(47--jr8FMPYKA4BnBhUPbP^CByobhEvl$+r zHZ1X2mWNRhlKCLKSQ}(Ruw2{2>I)y%9x_*6Yft?$8{kgzbs*&H#x?PBYG6`}Wpi8$ z@W*M2YJP;qe1qU1lp2Cvk5Kn8Ldd}fsecT?ka@+Q50P#Z{K`IuILRka&UXr4(#of} ze5A3vmyU!#O+aSDRXM)6mP|&Y8;tKqtxf@RgxAAH;}vup^o;sLsEK3v?lc?rNdD(F zc(nH^__)Id4zB|L#b=)v*$?`_=%|KQy&kxSGU}vxQ&fhdL510&=xu`E2c0k)be$s> zL{n{Ql`7qoe~RIgl}OJdV{wf;Vam7f zxQn-tRlIXM^?DHuYdn*_JZLyZV%hlgSca#DRKxN@da!I5qW+K}jcX^)8WPu%Ir!M; zVT15DUu%P4hq4vJd^^U$+teog7-NKrXcQqF8DXTE%FL*vR*YsZQWp;AL=vmu0o z>|5TZ_Z+wCe%n-xJX`dhFL}>zI=$!7;<}0-F4D+6$KYDvB4cp9%V#mINx&QmKCN)n zAB2yBZu*%Qw#_y&!{_}T_=UfA+YBH1Yar`ycQd5lGL~;XJM$z`2#alA7;fIXD2u4% zB4^S>bCO`1V3OcPf&@ayTn^mh*nc1!DIoKW=w2=V|aN!jfrA zb`{+R)=%Gj79aR6$-|rSFx+mgnmGB|$LaDKDv$3z_3~xHa|K~8c-`uuHHD$xb&*Otn5L_VG&83SZK8LWKIwVuh z2#Ap31j4-~e2MbsuDQxO!S0&A0{7Ugb$3nM0PnOY<+>IEqvx2d2Iq>KVad}_sNXxq zrXPy(Zd#Z^olg;5LNFu>`jAENZ^-fKSURmG^>VsXuI&ah_j1Sh{{Kvqxso!aq1yy_k=}hN~pP5(y(N5S-p}sm3u}UdBX$iR1SOw z2HF>bqKtD|Wv?tW2**(A!^V?VV)I3E%JQ zu$9?`5Bl0|m&t1c{@B;mA?_t1uaNRf>t!I^ZeK>L{w4wyO;E#mWbrTIA4pBREFPi> zb}Btq^jG%vEbdTQCQMG@D!+lF4RbbQ*i`Lu0&f7S7=}HU)@O5?hNO%AJV{E6 zo=D?eozctmzlFMxtDG1%JYhYaPLC>4)N5fkPk2I zZ(}zK@9zKP>v-;`sR`Z?JpbgtW6dJoNrih6!61u272$N6x5NDdooo&Q!;N84RBFgw zoRyeJ{LK;mbnoAnA5Ue}Tv9vCKLB<3N^dR0Qu0N(KE5*XIqCnkYYc+t;Oi1i^*=SY Bpm6{I delta 2510 zcmaJ@Yiv|S6uz^2_x85k-9C4_Pr6G>+b!+3rPY#jL8+jYK0pv6N>bO|yL4e6?c6Pr z3YMq|XfPBGCiqIw9|FP;aW$G4;}8EJn%J8DFgMYtF&g5ZqL?%iiRaAr5kR=R`R>d) zbI#0sbIzGN&(Dq7&)RGy5D|xbr=hb*#&VCLw0nwri=xhFng|BRMrgfcXhD9HVj%w!@*25BauBIJVInh_ z%?O4im*ue$!yJp|cqrCjsqr{!xQgxh}uA1PdYF}j! 
zjV<91kfR2IZm>J*8ycy(j&!BGo!So}7%JBU3vu2-vPaD&*6Wvcan_K*( z>myiCMz-PMSSqUPF)fuE9zM@_CrWiBU3!pL!f(zaXqKtz^ERO*h9S4~lwynef1QKk z)<@v&dM7lu*Fd|g2C}s^u-8=vpV!*p-P$$G1RXBtZBnoDo+3#X6u96D!q=@O5OKS@ ziWj}drR`Fvc!-mL(@c1>6tvun@Hn1#Q8F`{ipDhF4@cddoh95yTCukvwL%1|5scE_ zXf8G~l;!C)aYNpw4eT)FST?gCn5TxdgWFSEEw(J<;6Na7xDdd881MwyCCGUElhw{s zgFjg9$NHxBor!$<*kv|t-8^I6JZs%d(xKCl8Eez5wTU{9oZdKX4bE7Dv({k2Bz0!s z)9B@jX>0$CwSU&ye@~NZwjxK#bqV39e60BIa>?y|s}yc{+&+re80F96L(3KP|0+Sd zcQ+jHwLwp*3;t~^gZ;iTW`=AqXkkdB3>SR?PjJWRL%8d0n%MFWCoZCT1YPsJ2$AUg?{PzC3Vj;3)e+wN2Z5 zJ_vr)h%fxlsF7JUGOb3wx0W9td}Z)7(&L-8`rwK`I7!<&uka`E#di@rNw6HEPf>R_ z!eT^HpF&ZTeF z`pLRO^`howXp-9*wUdxw1m3CJT42O37;)w!^^>fIat~|GZN#jG?^bmTZ9tu2nb3Gv zH!ORi>2!2~7`}$wisZVJ6Xz`#$r33Ck(A>|$j=QoO%4~?Q_%B%0;&T%0@p+JLrYW{ zvYy1LCrGuGU<<)>1Y%w0XVPh|W7jCsvx!_fdVn7!nm2XtM~QAlz#_2p z6pDf1)p6^B7l?f^|4Z`_y9c#SnvLK&q}Bn+mYxE=PC;*2RD88{kg=ueUGf9us8nw&(feiDn854zeW7VeYFZCv&&QyC!-OLZ0N=i^?rR4uCG>CHASYa01YGUy!|#-*?}) z_aZw?t_e;Evi0sBQ8Ktv%9G#jujx2~*on?Yc(t>&wb!_b!||A diff --git a/framedecoder.py b/framedecoder.py index 8cf2196..2fda877 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -4,7 +4,7 @@ import math class EnhancedFrameDecoder(nn.Module): - def __init__(self, input_dims=[512, 512, 256, 128, 64], output_dim=3, use_attention=True): + def __init__(self, input_dims=[512, 512, 256, 128], output_dim=3, use_attention=True): super().__init__() self.input_dims = input_dims @@ -29,7 +29,7 @@ def __init__(self, input_dims=[512, 512, 256, 128, 64], output_dim=3, use_attent # Final convolution self.final_conv = nn.Sequential( - nn.Conv2d(input_dims[-1], output_dim, kernel_size=3, stride=1, padding=1), + nn.Conv2d(input_dims[-1] * 2, output_dim, kernel_size=3, stride=1, padding=1), nn.Sigmoid() ) @@ -49,19 +49,20 @@ def forward(self, features): x = self.pos_encoding(x) # Add positional encoding print(f"After positional encoding, x shape: {x.shape}") - for i, (upconv, feat_block) in enumerate(zip(self.upconv_blocks, self.feat_blocks)): - x = upconv(x) + for i in range(len(self.upconv_blocks)): + x = self.upconv_blocks[i](x) print(f"After upconv {i}, x shape: {x.shape}") - feat = feat_block(reshaped_features[i+1]) - print(f"Feature {i} shape: {feat.shape}") - - if self.use_attention: - feat = self.attention_layers[i](feat) - print(f"After attention {i}, feat shape: {feat.shape}") - - x = torch.cat([x, feat], dim=1) - print(f"After concatenation {i}, x shape: {x.shape}") + if i + 1 < len(reshaped_features): + feat = self.feat_blocks[i](reshaped_features[i+1]) + print(f"Feature {i} shape: {feat.shape}") + + if self.use_attention: + feat = self.attention_layers[i](feat) + print(f"After attention {i}, feat shape: {feat.shape}") + + x = torch.cat([x, feat], dim=1) + print(f"After concatenation {i}, x shape: {x.shape}") x = self.final_conv(x) print(f"Final output shape: {x.shape}") From 4ac52661cf72f00a598b3caa41701ff64e61af93 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 08:44:26 +1000 Subject: [PATCH 068/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 14712 -> 14786 bytes config.yaml | 2 +- model.py | 11 +---------- 3 files changed, 2 insertions(+), 11 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index ffd50ffe20f6fde3457d306fa6a9a970dff0b02d..fcd22133762ccc68b1941fa32b97d12aea205c20 100644 GIT binary patch delta 2719 zcmaKuUrbxq9mns1Yh%DR*TMG1HcJX_llUB}x~|wXca2|B25v 
zS>}-?)1)nXm`Zb+iD{ig`_Lp!E4B5*SQSmx_R>|FW_R}?x>BTV(jL0)foXj}(^T#E zJJ*=fbUj!fpZh!Kckbc)`~A-G7X7ViyKA#G3D5%n{8jRqcWpPxg!qY2CkWI8#!mzS zxQV@bz(QOEZnsA0p!MIUfp^G{&CEV@vgE6~R3|3ti>A^K#WxADu#kCv)CSW6M7KSs zq?0p+f|4zya@jx~Z2`MppGaj?c%#tOMfboLZ3PrZSl-gXes7KtXUWyzCB9A^dJ546 z&`C-+r?OG-r4BZ4X*YA%4wkm;HT@EHy~XlgPw8FD*N8h`aSz-Je$f5lO!@i6vh`fW zdQP*RWB+X&B~CWdq7Hezkxy}_W! z`c!X< z<$hIGl6+FsVGz~DrdoXLmfdf5f*WmP@7d3G;|la3VjS@Q9j8ID@3&kaP3&>YB$1g{ zicF57TNu%Y=tuM-{D>|@5TKh1IU2vH*T-W8Is}G*nGUnNl3yOdQ8B5+3M%T|#T(tj zwxlrWVFQlejXT_T)U{E~H&l^BnjES)!kQzzDcBm15`(v99!Y(xq9*mMk7!cgdaNQ1 zXwm>1*yF#^BCU@6WciL-_U)_4Lz+DFam!G-W$2L{TrFtw-g`uogXO`=jZ+)>Uq7b} z&Q#Eb@@$iz-t2iT?Gg5X&{^s<&a%?TX?q5s6EN8gl z3~SEtD&cRR1(SOgJ{uzs$BbS@n$x7YvNX5FOs&I40++tk`Xe%J0qrQ(5rwoAC{}gR zaa!R?k?2Bsm#NZ44}m?{#>o}vC4r6}?pj&8l6^fyv zvux3|->A1K_OdHTzR4cC{sdcDbqCLP;EXxM6uR1?(Rezhs&OTqjz;ejx*x16rubo_ z*zrW62n@ZN2W;IGw*MaW94zzC2P5T$Mb<5QSV*=@ohHYoP{U!uPRi{Kn?kdxjXGIY zK1K#g8}i$vmn(ugKuusNQXgOeDy1Qny^@Z_723>hc@9YJIL0^LjlDtk2Tz0yGmrNv zC*Lk^50dhxvEvSQ+}lphu$1=!m)~0o`+h@aP&LYH;S%;PA`*xf5lMssfK!hP`wi&t zshO^ZY`){QX{^lqX1ah*XAox*JJLRnkBflMD>;Bh!~zR;o-wL>fW5hIme@-l zb*9LTnjG{z&Y~{>0~Y#)|2}}r3?kx)C4er*bJ?qbdcJer62BPBW|g#R#1amF5NBH( zcqqD#LassL9ZBcnmqXELDw`@qqf6{me-G2TGH{?D>5h~0JAz@lPE5lT-BeU)PSs6Iu}mgbm4`l!tMjyZQW^KI3$P|l z5u|wuWcf?~XLYVN^xs{^bdr4<@Sms&ri*F{UX51}J&uSVzK7uZs(&$)p#^x<8`NB~ zkcnNRD>!%q5U^JDWy#>&*j2EoRsf`Ttr=;~<#p}ea8=(Ql>Qn#LB2v^H@*$=T~OG? zQu_`WD!aMLA-r@;mZ4_}X=Z0by`h@Ix;00aqP6^b2A4SofCN-hX-e^Vn%xcc_wpLz z&#P$s(lXr)^_X!~z0I=Tb~YL=R*%aM!&LO{=sbxl>=Z{cTNoX@Q4`yc)RKWCaEYc8 zl)-Ghu8`Bs3FT@E&X8Bo(72pc^UI0?Vc{D(jK+F3m4*8*ry2oz8EgTYQAB3=5v1lQ zmB2<%Ax@)*=}JzGUV|_8*=VSPM+dLwT4=CS4G%kZpqs3fE+6>gS4AR(OCw(5kx<7D zwe!M26YPbji-ljzp=_3RvyaC5<}h;%+0D=963c0YpU?bg;bp@k%NM2mcLQH=4>n)i+n9Av-W$>Wo0LgY{D=B9; z^nJvhsxUan%vxM>BGdsWB&_JeF+c33rL5OjrRQ5~Sk&mzjl+Q+vCkcqsgI2diS6L^ zz%7#A(w7$cGeUn>=x6Gp4q|8inU2(^qC-R&Q5`gZO@6WUnyMMgRl9HQf+HQ%a?~Bu zt5qYq5e?f=U(H0pqlKxSy&||oLjvCQ_zDRl+ViTAdBnh+LM^*1%=Q@GgY^#zI>(;Y zFX8$Yu3y7?%$3BzLgrrb9h)@YY(&;B6GbN<2n^K6lU|8mvFxjJd7V&9jz^>)}s zE7_xxt#xRD&O=xL@c$b-fsLD%J*1eqtc%DDw%6L*Js-EV0VpOZHWZa;AR#H5WO$Ie zab*@l9YPc8=n`=nJfLWTfdp*8y#ZD|MeJs~%gJu4$bllXX`$ z*_w__nJ-)BJH7m_U`v;FjBf&fdqR7qC^yOozO3L&3%*}0Wm!wz1C`pmn%ogv#}Zkw zd16ymY@O6*L|<0)jcM;%?dgh@6E*<2ue-l!%#AYERaxt*v~|^Q)`j-7rt$Xi(6|@W ztW6nfQ`Xvax{JTP6GWw*a8DloS~3asb6h^ZO|VkFH^t z>~o8Gc`b!`au(0FKkPQLlIbdbg46t5(YUh)ReKOnNiqBV!AL-sgHj~o_g^No1Eez6 zL1(VWwWx}q$Av+FN2998zdQ8{)BOJWaC+M=wny|>s&%#pssb!w7e$d>5sP!#u4P^B z<;0s(&-jqII3-8~l#Bx&v(d^cuyi{BrPg%f^%A1uo6GE!H~ zZaZFuUgn>9hwMhmDKD@Xjv@#L5C#wqB7^~;L-MGsgz;0wi##VQroA#>D7)F&>Pt(p z30&6GZSbJm5q2O<1!*TvcL6*Z6E6x8wy;D^e~wx&r`D1xuMLw>fm-SZtvOr3HMlN!t&FX6cZ zburDkW>(eEKBpU06kSrHaaqyr4@9GZyl-?aTIb0=&8-s+Gs(-~Rz3T&AquWFdHzw` zD^dID@}-N}l1BIXg42pt4#SI%qvl404G1R?_$f5|qER{wQ^g?12NKc1FdfCklK@`f zF-Mb)E2@$pk+Drc?iQ|G4)-uylY@+=nw!>>=TIl2VT56Jp*c^T$f;|EuW4m!^9$<; z5!gE|u9gCAP(XD5Q~|9;6YBsVC8ThK9>IAxd)U%2o98;8pG6@j@foF7%{FzE7Zob( z4BOp$KHm~c+o;wt#pz-+GnKOfyU@{eqCjkpK4TP`-F_Ml(Q|kxbt10lL(-8jw25&P zn1*8V;UP(a5b#qiL!m|v$KYERmvg~*3uIn%F1PjY??#SOY6BVVN7#xxbc1o(KMW7H z+SgLgBVh)4Aq+UkhXTa#?E@d#6`zruOI`E*@VuC`;yVytXA|=_R`I;YDIAFBhmn7o zTVgR<&&ub|T7jvO3xIxIJTw%Mco*hn&#QvRmm8$~?*li8-TyD+qA1_6gnq=1&)?k9 eOUR4|s%PW)oVOP| Date: Sun, 11 Aug 2024 08:52:19 +1000 Subject: [PATCH 069/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 14786 -> 14014 bytes config.yaml | 2 +- framedecoder.py | 102 +++++++++++------------ model.py | 33 +++++--- 4 files 
changed, 71 insertions(+), 66 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index fcd22133762ccc68b1941fa32b97d12aea205c20..62ef04f55ab4818c8e47756e0c65da3033936160 100644 GIT binary patch delta 4591 zcmb7HdrVu`8NUy|Fc&|t0|DD`F;5eqFmM8>sh}a*l0lmgT1`PKR4B`?xg5SROui615KKU+DdITZTr6K2X;bH zcYXc&ob!F>d!6t6&hcO3e~y@bS5;+X;Cbs`f118?^gUA-+tUW!*O^~3LFN)GF_$<= zv&>Ah68B2QC0^oRVJ<;`fxV=?$au7homx%yx~Tw%O7agXuBRv6 z#>GTsL4si)E8YEC%y9_K?XbH&f5=icj&_{Ir zCncsljGuEW)i9eSI*{3572!qK`+W7ojkSGS9FBAA2FFg!h)NT5h8N@e|2 z8?R{MFi2crS*h^(0%3nV?(@YF7({U?mjfq_7}~j4S$eb0v3SqbxvY^5^`E;MZ(QCH z>eu?$`K-_)3oY0gShr<`R#|9$h*Kukzs_ZaW?5+7tJHNF?=!lF$}f?qn!mCz&W0N& zcLdYQL~b45pTB+Ky@fw6KyNYY7?EMch0%lg-w0IkX-$W$`>>R;59Ned+gs!|%&fYYgZhftwDW!`D% zn$T#*IDiLUnNnJHQlp~?{-j@>kDF9|Qat@u@L0_yrY!kcPIs9%Rt&K)4U|_G#HeBy z>aVING_q4%N;|_z`b}F7R@D`OBiJpvSJ{;-Wjc?PBO%MsPdmz&vIZyWQXBweEj5{f z9UFj@fDg0&xuAz52-FHY{V-_ts-X3<*~xn7$EM~Atfbxl!bAUN9$u_7o+q(DFdh#@ zXT(w%5eb4)(l$Ds2$FJ6k{cG2H6`=*!A0Cfa88!MjhHlvq3GOvLR3A4nPxFIpU4l; zvK?Hmb36QmD3CV1T-%|H}B9E_4C?6x35 zL?n30QnZN}3Ct=ySSKDB@aRYvjvLU+c@m5(?1I9E6r-9BHD*PFZATo7R8$`oNfrHc zj9m8$*v1LH4BazZe$l_0kZaqswhr0W z@rk)3WA50o9LdxawAKtjJ8SXE7H`Jlz3Xbtsa&6yUEQljRrKuo=(_*rsZ7mTc(;=B zi67>g*+t7a*>WypId`|VQ?-(2<=W2laJKfiTzh;~qs|^H%pS|sjNR7UF~SeLS<7kJ zaynx|Y`O#OKJLi*aqMQFhL zk>;<)Ono0Vv-SzuK9R9ce2!tAO*d_}=iIwJb>~vv(Sn~l)NFHJK9Tls&SYv%z?-!U z%9g>5W$>HN4OJOaSJu!i8@e-w?z@8J%_o2McmUQEqBbtol4q3P{F$4+L zOnd^!*Gq7_$$j>?M(y!E=DWz_Y?Icy$2>%;O9oZXh=z`~jIrW2ck~ZzvmH(1Zr|X@ zQEb-(nSyPnhawL913@xCKdJLt9LQ5cX~0%Ht*!s@)9Mze3`Ko`S${Md4967#Hd?Wl zix1<0@-N&YpqH#Izr*wtE;tx*{5Ocix*{SDA{nQNhNpC@ps40g z*tR!@SgDvlau%ef$T^@M19>1_ZWtx)NG>3G3W%aXtn%=A94TZlm$L`1UZ08etdIU% zd=pZ#(iC7D_fv8ORk3H(tavo}3{X?EQ>R{CM`&+z9UG$OnqNF5O8kdFkX)hvY(Afh zQc$C8-F~fQfo;G7fvY(dh^O+w;M%FPt(5s2_-7*#;ztpEJT{$(_!meBnPDIvp^&Ss&eP5SjKnbkDo*D~ZcmO->#=(F zh3&3mPq2riuoV>}xklgZ%BQeZO<^bXIqJ42yC+z7fUfs6KT@m>*eztHxLXt+4kSer z@h4`XHiDrr@uAKkdcUWwqYImAe+gSEa2&9)T=c?#-B{dD%k=Ui@31!2f$CnOso&46 z7uD<+?*M&w;Mig@y}A4uBEk6V3<*i(RghM=L`>01!RJE&3a{Zn#e6gle>Q@!5KE|T z7zcT{O~5%Bi<4_O_7V_JRj$Z&aLUJHgqTr|omJTDa16)dzJ=V+3v}Gu=~kDb3M5LG zLd>rP48ySrG15)1fnD8x-}~`li_!`51j*|(ed4@R-M!cahtvWwsHb{QG)fw%eX#X3 z)(UwUh@zW_N%P^L3dTglVXO{yfmMa17rUxL+-Kh!9*w9I9wWb|KOLMLIK{Gny9d*? VAKZX59eH1tN{my@&<7e+`9H$iP*DH? delta 5078 zcmai2du&tJ8NbKxIM=VliSw|-4S5hdB*DB2X&Hnj;SEg!<(XP0*9i`Zog6zrGkc+0 zyM~I5vAFC=!-;Bcrt z)E5;))T_J0X*iDR!SfxC0Gz~{7GQUu4ls8WK*fRGpLPSDBhMRN=Hi6?nv}6!CH3r2 z#Oh3F5?laHlq;cymX394oSH!cd&XC)gSPRWb_KLeIqP_FuEufk687h+iik+YVg8rk zovy&IdbM-C-%aWgI#F{}c{y5`APHT0?lvcC$6+QBGOuEfwtFnxfYG z-<)u#zn9R`#&}_F2*(W?a&=s+1QB6TpVO3cY@N1@{kYtUj9FF2EcmB+(3Tqj#>R!W zRi~ZMh0=;oMe-F(O(MnST(>GC%N>BLu&aj_@nJNT$z?~j~ETgyv@I*^bXsEv}9uGz0;b_FGp~#i2#fRbh3j!^|!Nsh}xOzhc zRAWs5C%Jh}v%PcW&MTq*ed(PQF<y8uOYhy6=J#j#{Sv>QIm~4q zp7Z#xRm`-#@5#8=OKyd?^0WyHb+Xa2a?OmE^_thQoh2@|tlVxEIrz~&r-hw)*Wxnj zP|_0`HfFA?)`*6Mji1!sy8%UGLOZBoNwbFq%FAlB_bbo?wx+DDTNejyN3CEbd)t3EiulXjc4>THHKHZy`!AN+0`0Nf~S5jZ5BC+~?)@ zKKGj32vwNa9PcPVmzIPcrUx#o?iFLYOmT$=lyW6>>=BE%(5k$JEjz@ssU&wFbT_ZT za7gI~G-BaZyQ=j$n+LdZ+&?43>3=4;6wkU!oVEXli=y0$=(QlI#q6>*RIe!3WdiDP zBsdbfoK-5)|A0>UZL3HF8?e-&r?%6h{=30=fp^tK*xNVmFE% zC~M;`wak#|wyT@+O-0lciD%5#y42Fv9}iJM)hc1UaH()YP#)PTP)tY0O{nZ+G$ed! 
zVw>&0iOYn#u-~Z;1c~blf(Qpug@tivXE+iZ5>zS7Ir3U${Rj<5;oxslMg*-*o%0HDuh6O72Hz+TrK$OuOW6gN-|9 zvZSqDSsK^vgi5Aymjon3r?o3%>XuC1X;b%Y-g7=E@dC>@0uz>kPLM=IkVJdkln&0A zlwY#frGgoIgJf@*Y=c-&7p+ZI130WovknSc(zV=r7N~VnGru2RBIXMHp#gy?cDYWOBK7T z+QFi<9m>JJR-!FK)BH6qi&sr@AigOY{nYGy`O{((o z--5#j^NFMe(n-ks+@L=19D*Y|sZHwQ`E(L86&-AT$WsK79K{ozqld5yYj_<3)1*FW zNE%s7VHx|CuaFqnhH|(01WB5bph+nMZde|(Wl?tttgmRfii{11DC`t9#4+q;Ep7?ja4>!Zj`mR+myM8J43EV_vhG+obev*% z(sk&dP>FTgHi5%Txj0v1o3>*gIV@|A9G6L-OpeIpxNOPUD0?A>uc$`8f@(7G4d@)N z`?{}$=IrSFM%o0Xv5z=zf;-907MDz#=3GLm3=V&Yy?M@5uo=P}@zkctw#oRMt!QfH z+167NrzcX?H*Ga(Tg`0cq6@7rnx@Fq@~P3evdY)mFKkI2er3mn9howpROUXB zR;YcOf0Li?x<;<~Gr}rKSanNiy(zTL_-BGQ95;H}GD4>$bWT~OEVqT)Da$PHJhSVC zUFUaXc%Q`k(!5XEP@Sn@x~LP%HIL+PSF0I*lf-XID`nQ^nzSiO=ryu(XWfJcCP+IF zkQLdIO``fjLqmZ8NZtx00CV3BDt8z+6ah6QHVOdAZ0=`v@#?fvZurx?_sw%v`VGL( z!*KQd22!PWpive!kV1XgJohkw4fT<83u(CSEcpScQQ{F58x9yu7-%KHZa8*~kfRO- z`$Dvk{jGGB-HrE@>0$~_8=2R&jVxnhu0>8|k~k>{l6=BA?Q)Yn?6PjK{+iO*vHOxV!dAI4wH#I74#rl(L;b<6MbOjUUbul(1$Pp zAnPz=_iD2zOCNkcM?ym}b@f3(fS4RF&=@c2gi{5!yzlTL*>uWo@QWn+iMvEY!bg#$E}T?s*rI1ElWUhjQkf0P~%QgM{E z3CUW8h6osag)b|1BJ{Jr3QwqfSrxuYuP3WdkSdHCFHxcfl;#1N8d-6TfBAYak@e#t z8jZ>Nf#C3PFv}8Mj^35Xxt}2RbWJ50XXk3p00VWk|Iw6SM)QDT(#P1ECAB;A+<*>- z;USnk((MS_5Wb0^yn*)c@Gy--Q#Qt;{qf=837Wv0sEb}cOC@`ENr(Fw^u%}oh~@m1 zWs=z&D#$a}8yj|#htP(fsTAR9_L?_KTd6`@Gp6v@yL>wcDP*rV)->lSlldqe2;}2t zGrHIU0D%<>4^fKk_3UnAeS@+$l=c*Me%;oC{&5fM@;lA>*qUO!OJ60*C&cxlCaSAg z;Pf$cvk>`(?EC)4i9E3?eWqbZzXoVnqztTOO*|?a#L%%Yh{IXzC>V>x;F~i9OQ5{f z2JF 0 else input_dims[i] # Double the input channels for concatenation - out_dim = input_dims[i+1] - self.upconv_blocks.append(UpConvResBlock(in_dim, out_dim)) - self.feat_blocks.append(nn.Sequential(*[FeatResBlock(out_dim) for _ in range(3)])) + self.feat_blocks = nn.ModuleList([ + nn.Sequential(*[FeatResBlock(512) for _ in range(3)]), + nn.Sequential(*[FeatResBlock(256) for _ in range(3)]), + nn.Sequential(*[FeatResBlock(128) for _ in range(3)]) + ]) - # Add attention layers if specified if use_attention: self.attention_layers = nn.ModuleList([ - SelfAttention(dim) for dim in input_dims[1:] + SelfAttention(512), + SelfAttention(256), + SelfAttention(128) ]) - # Final convolution self.final_conv = nn.Sequential( - nn.Conv2d(input_dims[-1] * 2, output_dim, kernel_size=3, stride=1, padding=1), + nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1), nn.Sigmoid() ) - # Learnable positional encoding - self.pos_encoding = PositionalEncoding2D(input_dims[0]) + self.pos_encoding = PositionalEncoding2D(512) def forward(self, features): - print(f"EnhancedFrameDecoder input features shapes: {[f.shape for f in features]}") + debug_print(f"🎲 EnhancedFrameDecoder input features shapes: {[f.shape for f in features]}") - # Reshape and reverse features list - reshaped_features = self.reshape_features(features)[::-1] - print(f"Reshaped features shapes: {[f.shape for f in reshaped_features]}") - - x = reshaped_features[0] # Start with the smallest feature map - print(f"Initial x shape: {x.shape}") + x = features[-1] # Start with the smallest feature map + debug_print(f"Initial x shape: {x.shape}") x = self.pos_encoding(x) # Add positional encoding - print(f"After positional encoding, x shape: {x.shape}") + debug_print(f"After positional encoding, x shape: {x.shape}") for i in range(len(self.upconv_blocks)): + debug_print(f"\nProcessing upconv_block {i+1}") x = self.upconv_blocks[i](x) - print(f"After upconv {i}, x shape: {x.shape}") + debug_print(f"After upconv_block {i+1}: {x.shape}") - if i + 1 < len(reshaped_features): - feat = 
self.feat_blocks[i](reshaped_features[i+1]) - print(f"Feature {i} shape: {feat.shape}") + if i < len(self.feat_blocks): + debug_print(f"Processing feat_block {i+1}") + feat_input = features[-(i+2)] + debug_print(f"feat_block {i+1} input shape: {feat_input.shape}") + feat = self.feat_blocks[i](feat_input) + debug_print(f"feat_block {i+1} output shape: {feat.shape}") if self.use_attention: feat = self.attention_layers[i](feat) - print(f"After attention {i}, feat shape: {feat.shape}") + debug_print(f"After attention {i+1}, feat shape: {feat.shape}") + debug_print(f"Concatenating: x {x.shape} and feat {feat.shape}") x = torch.cat([x, feat], dim=1) - print(f"After concatenation {i}, x shape: {x.shape}") + debug_print(f"After concatenation: {x.shape}") + debug_print("\nApplying final convolution") x = self.final_conv(x) - print(f"Final output shape: {x.shape}") + debug_print(f"EnhancedFrameDecoder final output shape: {x.shape}") + return x - def reshape_features(self, features): - reshaped = [] - for i, feat in enumerate(features): - if len(feat.shape) == 3: # (batch, hw, channels) - b, hw, c = feat.shape - h = w = int(math.sqrt(hw)) - reshaped_feat = feat.permute(0, 2, 1).view(b, c, h, w) - else: # Already in (batch, channels, height, width) format - reshaped_feat = feat - reshaped.append(reshaped_feat) - print(f"Reshaped feature {i} shape: {reshaped_feat.shape}") - return reshaped class UpConvResBlock(nn.Module): def __init__(self, in_channels, out_channels): @@ -92,11 +88,11 @@ def __init__(self, in_channels, out_channels): self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) def forward(self, x): - print(f"UpConvResBlock input shape: {x.shape}") + debug_print(f"UpConvResBlock input shape: {x.shape}") x = self.upsample(x) x = self.relu(self.bn1(self.conv1(x))) x = self.relu(self.bn2(self.conv2(x))) - print(f"UpConvResBlock output shape: {x.shape}") + debug_print(f"UpConvResBlock output shape: {x.shape}") return x class FeatResBlock(nn.Module): @@ -111,11 +107,11 @@ def __init__(self, channels): ) def forward(self, x): - print(f"FeatResBlock input shape: {x.shape}") + debug_print(f"FeatResBlock input shape: {x.shape}") residual = self.conv(x) - print(f"FeatResBlock residual shape: {residual.shape}") + debug_print(f"FeatResBlock residual shape: {residual.shape}") out = F.relu(x + residual) - print(f"FeatResBlock output shape: {out.shape}") + debug_print(f"FeatResBlock output shape: {out.shape}") return out class SelfAttention(nn.Module): @@ -127,7 +123,7 @@ def __init__(self, in_dim): self.gamma = nn.Parameter(torch.zeros(1)) def forward(self, x): - print(f"SelfAttention input shape: {x.shape}") + debug_print(f"SelfAttention input shape: {x.shape}") B, C, W, H = x.size() proj_query = self.query_conv(x).view(B, -1, W*H).permute(0, 2, 1) proj_key = self.key_conv(x).view(B, -1, W*H) @@ -137,7 +133,7 @@ def forward(self, x): out = torch.bmm(proj_value, attention.permute(0, 2, 1)) out = out.view(B, C, W, H) out = self.gamma * out + x - print(f"SelfAttention output shape: {out.shape}") + debug_print(f"SelfAttention output shape: {out.shape}") return out class PositionalEncoding2D(nn.Module): @@ -150,7 +146,7 @@ def __init__(self, channels): self.register_buffer('inv_freq', inv_freq) def forward(self, tensor): - print(f"PositionalEncoding2D input shape: {tensor.shape}") + debug_print(f"PositionalEncoding2D input shape: {tensor.shape}") _, _, h, w = tensor.shape pos_x, pos_y = torch.meshgrid(torch.arange(w), torch.arange(h), indexing='ij') pos_x = 
pos_x.to(tensor.device).float() @@ -164,5 +160,5 @@ def forward(self, tensor): emb = torch.cat((emb_x, emb_y), dim=-1).permute(0, 3, 1, 2) out = tensor + emb[:, :self.org_channels, :, :] - print(f"PositionalEncoding2D output shape: {out.shape}") + debug_print(f"PositionalEncoding2D output shape: {out.shape}") return out \ No newline at end of file diff --git a/model.py b/model.py index eb946c0..9493564 100644 --- a/model.py +++ b/model.py @@ -306,34 +306,43 @@ def __init__(self): def forward(self, features): debug_print(f"🎒 FrameDecoder input shapes") - # for f in features: - # print(f"f:{f.shape}") + for f in features: + print(f"f:{f.shape}") # Reshape features reshaped_features = [] + for feat in features: + if len(feat.shape) == 3: # (batch, hw, channels) + b, hw, c = feat.shape + h = w = int(math.sqrt(hw)) + reshaped_feat = feat.permute(0, 2, 1).view(b, c, h, w) + else: # Already in (batch, channels, height, width) format + reshaped_feat = feat + reshaped_features.append(reshaped_feat) + print(f"Reshaped features: {[f.shape for f in reshaped_features]}") x = reshaped_features[-1] # Start with the smallest feature map - debug_print(f" Initial x shape: {x.shape}") + print(f" Initial x shape: {x.shape}") for i in range(len(self.upconv_blocks)): - debug_print(f"\n Processing upconv_block {i+1}") + print(f"\n Processing upconv_block {i+1}") x = self.upconv_blocks[i](x) - debug_print(f" After upconv_block {i+1}: {x.shape}") + print(f" After upconv_block {i+1}: {x.shape}") if i < len(self.feat_blocks): - debug_print(f" Processing feat_block {i+1}") + print(f" Processing feat_block {i+1}") feat_input = reshaped_features[-(i+2)] - debug_print(f" feat_block {i+1} input shape: {feat_input.shape}") + print(f" feat_block {i+1} input shape: {feat_input.shape}") feat = self.feat_blocks[i](feat_input) - debug_print(f" feat_block {i+1} output shape: {feat.shape}") + print(f" feat_block {i+1} output shape: {feat.shape}") - debug_print(f" Concatenating: x {x.shape} and feat {feat.shape}") + print(f" Concatenating: x {x.shape} and feat {feat.shape}") x = torch.cat([x, feat], dim=1) - debug_print(f" After concatenation: {x.shape}") + print(f" After concatenation: {x.shape}") - debug_print("\n Applying final convolution") + print("\n Applying final convolution") x = self.final_conv(x) - debug_print(f" FrameDecoder final output shape: {x.shape}") + print(f" FrameDecoder final output shape: {x.shape}") return x ''' From c8a64766d56b0037f1357baaf64fd706f11ac1f5 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 09:10:24 +1000 Subject: [PATCH 070/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 14014 -> 14261 bytes config.yaml | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 62ef04f55ab4818c8e47756e0c65da3033936160..3497f50f3a6d4d76fcd15c8ea9b8e92e53ea4811 100644 GIT binary patch delta 3025 zcmai0Z)_CD72mnP}kjY;v@=NNIeV}h|U5K4-y<8BQHzB`yb zJ1I(Jq=d)<6ig;0mZ}z1A=OqB_0n@yKA?&Zv`X7Z-K|!tbrKRE!iRil>nNXERaM{H zJKKOzdV2fYH*em|ynVlUGxxjXZ_?ghl$Ch|_{snL?bOe=-Sk$I|J-cts*c-Ey5siu zg_G1K3<+_^`$F8Q+Rl}HM3i!mWpaVZy>2SCOcYEd1yi-xvUtE!y6#kFnM5#^6-{&A zIGGRv-fu91(ZLwL)_+b=%HjF$Wq?nG%(@-k%$u=qHs28v!aL+H`H&DHE=*WYfQ$xg zdFO%1{$oRV=LMEXWuWbdvNKxVb^a}UYc@gPSDbB!@}+V0^vsz_%Zru3#BO-B4uH$T zBe%FF5a;$sv}``3)B1&3fTv~64!xGkcNkKxnY`*CUamTuoQ#5jf9~uj+xXv|Z8qlN zifi15AH^yWssJzy!SA0kxY0aw)fi)Hi`4Uotrv 
zn~rMQmQ>+!6I}tE?_7ZkDGieq(o43cJ8MYe?0T=7!eHB4MdAeff9QtS}NFeY; zD|J~dE_(7bD8_KigqXl6s|KAi>6A{V9*gqAk?XI|6U#riL_TxtauCFzohI$nY3E~6 z`kPqySFvs}YKVSQ^cQ^(+zZ}+MVAkN7&Kzih)(AsPb#+F(iZoawH=1C%T#t56}!!f z-D|>@lKyWI<_8{$ieBB59nW6;VgeKpLkyW>NEbt^QpfSrcp(M|po`&ER(3eE`|T~1bOc?XqQR5blSBlN;eK)J-j$zi1nseU*zt!2jn5}{2GWsM@%}R(-Fwv zJcYekA$N2);)aXhhAS)9OiZ1go}64JtP8go0eH1gJR-Uxc(^7Xcj3{=ME)1~0NIg~ zt3D@9R`J$>dD+J5vDXi9beZJciPVMZXiQ~2{OhgVQUmr`3pbi z)Ha_)^CCd8gnF~V71-vzAO`I-X`fE}_#bPc1CZtd0dgN`1`V4utkdvnA&t1gAi_bM z@2Pz4KqJB!KVN&Skgpw9zT{k~FG1!u^2J8M>nJ-67@+LA^t9nLc-pYz2pDzV5lg3L z19t24@-%ifno6nDT7j`=4?OU`R5hqZGe=Zy|8zQbe%It=B9+Jh*?7xnAMigZb-b!R z3Auf*K1Sv?a~s2z$e@+pfMYY2dCF8R5ub@p=k4j4j8>$pc#h**fvOVpNm53}~`3yG2;CWdet?yRrmhA zyA5xXdfe3kV(scc>GKSA@XDrOcNkpq&PyswYkB9HXfhdHC&-mZO~w<+7tzCi+*C_u z`K_iKkpGtE|JthdVK8v63#12{kp^IBfPi2zPw^<51F9HO_BAfhItr6c@uxMDYD zXBUYJWgAvnbvXoL&|Z`F>a;gE+U6x#DdB9_%07637_{G{{W|SmE&qN3+F=2PFvEWu z7%QZ|-Aez<*3CS=^AI6D{C0bjt56*LOHkuDD$Kua-_d53xz)ac#uuF;V82wy$AhwG zqXMq;Q!o9bc=>PXV2+y2+=p-5{WlVVYF6vu9W$&aQ%Y-j+${U2%0b z0q5K`G?dJww2L#U3RGE9^`p_QB~oyyr!_W?25zS--Q^mfiN;8?jRJm`e)=UIURb|uqFHBX{Fs><$?ffh^qjU40;HKva zHvOoPGqEeCPWBV7n|<3f!r9m(Q?LGCI`&sn$WaMiv<9ITK+*#^om4KpYMOEoRR7n4ro#n=5n|-p+rUt)3WuY~h$Mx>k^#m}=Spc42a8m7RALs=9gYoq zTMO@2r8v&Q>ie1to|^MU{gQ{>savp_F$xzyl3}$VcDr$L3&4s_e(7pS{@{8T*_k^S zLhdo)t6WpzLqXtz#fk-*NN4FOSVQ{|IsuA?WGqcxIP5~ewCM{tQIXBi=zP%-jbtfq z<~C6evo)OJ1UAude7P4)QVN(`(e061*?g^>!fkjeFI`n3C$O2XDdf6JuFK^5W8S)U z@y)4KPIK?ubp5nPHur%j#G?|AOgxWy;ZJVF zh1`s*ZQiwN&2@4!uo-!KOm3V6*Q`I+Yr#Ehv)nS33*{EK zWav>)5QUdmKw$g zDn&4{Dccp3MH{mEjv(hOd}@1->$IW;n{Z1{?WkjGT-xG_m5KC1G7^o^e)j7oQE0_E zZR2*-dYIn+3g=A-N8+>AKQOkZq8`)Z@A?_T&nfaRn)E^Ttv8l za0y`sz^$hq)Q16{>EKZ?EHz$U7|Wz@UW`d+l9}i?PQ)|xR)o%CFlhsz+*bp+P#zBd z0Ej{cRWc|OZjimx8VNuNO2sh`f~Jrml?=&bXcsA_8bbjHLtn;#U!c;7MuaIAbX+b` z>C~v?3x9DWxY+|L=_GhvqGti!gdSn81FW{#`@^e5yLwLN=JV+z*U6r_?#=jGXv4n# znquxy%^j--&37<2mY)Q6pC!0qwVGP+_N32cn8z7}*V(UJ7fKBC8pGX%D;>AFc08wU zPCKXWnxn+_uzz%(3i-jMSg{nN8L3z?7fGce`@EVG>F{hK^*mN#%~In`EEx}Ev$1qG zkx8EjhZE^UHXNqF^T}jaX7o>3X8cwenK{jAPN~!$$(4-3O#N!hrKYAo5_C#9v}+wtHdu8 z|8AXDDhTuZN_FE5!XmqQyj-bnty0Ge^?h-U>u0w;ou<;!XK#ZV#ua|{N6!nL9#pmQ z3Tn^2J}ll2(<;nL_n6Y|V6S?Y%J<0Mz0}xpSO*a!95!A*yY;2M`GaDV*e0o%G(SfZ zv-A#l7xmdp(KH*onSi_PduXV-k(L&3#9~mk6%6-fG#aEt8ouEfi7uk;4FGp-sX;W{+2#7*PRWd4*(f?PCy({}cW*nJL9 z9{mnL(KMZzy^)M*UkL4P)7np~tQL^Ih^|^dyg&X!hEv*t)ATJE!Yk<{ABzu6=sz>E K4+owY(euA}y4?Z* diff --git a/config.yaml b/config.yaml index 2b97107..5a7aede 100644 --- a/config.yaml +++ b/config.yaml @@ -3,7 +3,7 @@ model: latent_dim: 32 base_channels: 64 num_layers: 4 - use_resnet_feature: False + use_resnet_feature: True use_mlgffn: False use_enhanced_generator: True # Training parameters From ac7f8de513f45e6f17883015db9278471d74c1b5 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 09:23:07 +1000 Subject: [PATCH 071/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index 5a7aede..5ad60f0 100644 --- a/config.yaml +++ b/config.yaml @@ -10,7 +10,7 @@ model: training: initial_video_repeat: 5 final_video_repeat: 2 - use_ema: True + use_ema: False use_r1_reg: True batch_size: 2 # need to redo emodataset to remove the cache npz - smaller numbers won't work... 
num_epochs: 1000 From 9e3bc170641d684b6a537cba4cdf8d7f224478fc Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 09:24:49 +1000 Subject: [PATCH 072/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 0c412c7..03747f9 100644 --- a/train.py +++ b/train.py @@ -117,9 +117,11 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato total_g_loss = 0 total_d_loss = 0 - + + current_decay = get_ema_decay(epoch, config.training.num_epochs) - ema.decay = current_decay + if ema: + ema.decay = current_decay for batch_idx, batch in enumerate(train_dataloader): # Repeat the current video for the specified number of times From bb2cd03b3385bba10ffc9c768fb69e90bd82357f Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 10:17:12 +1000 Subject: [PATCH 073/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- framedecoder.py | 38 +------------------------------------- 1 file changed, 1 insertion(+), 37 deletions(-) diff --git a/framedecoder.py b/framedecoder.py index 7a62e71..4a3cf3b 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -2,6 +2,7 @@ import torch.nn as nn import torch.nn.functional as F import math +from resblock import UpConvResBlock,FeatResBlock DEBUG = False def debug_print(*args, **kwargs): @@ -77,43 +78,6 @@ def forward(self, features): return x -class UpConvResBlock(nn.Module): - def __init__(self, in_channels, out_channels): - super().__init__() - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) - self.bn1 = nn.BatchNorm2d(out_channels) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) - self.bn2 = nn.BatchNorm2d(out_channels) - self.relu = nn.ReLU(inplace=True) - self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - - def forward(self, x): - debug_print(f"UpConvResBlock input shape: {x.shape}") - x = self.upsample(x) - x = self.relu(self.bn1(self.conv1(x))) - x = self.relu(self.bn2(self.conv2(x))) - debug_print(f"UpConvResBlock output shape: {x.shape}") - return x - -class FeatResBlock(nn.Module): - def __init__(self, channels): - super().__init__() - self.conv = nn.Sequential( - nn.Conv2d(channels, channels, kernel_size=3, padding=1), - nn.BatchNorm2d(channels), - nn.ReLU(inplace=True), - nn.Conv2d(channels, channels, kernel_size=3, padding=1), - nn.BatchNorm2d(channels) - ) - - def forward(self, x): - debug_print(f"FeatResBlock input shape: {x.shape}") - residual = self.conv(x) - debug_print(f"FeatResBlock residual shape: {residual.shape}") - out = F.relu(x + residual) - debug_print(f"FeatResBlock output shape: {out.shape}") - return out - class SelfAttention(nn.Module): def __init__(self, in_dim): super().__init__() From 9217e8a9acc777f33e9ae650997b14a4b4b03dd7 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 10:18:42 +1000 Subject: [PATCH 074/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 14261 -> 10985 bytes config.yaml | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 3497f50f3a6d4d76fcd15c8ea9b8e92e53ea4811..d5ed227493abde243bb1adfdd3f8c30a15ad5c67 100644 
z(+PSpL?h?WdllfaZcA*=JBHT#H(!Sr{1wClMe#shJaA8HxOVdD$yIMb>L^MbIf9Dg zSC216KAibrW;OZq^!4;j??($CFBDpOi!Hqcsjn#YQT@I8A&PNc}7^eDm`I9c52@0S*abU_hLq2UC;zc0f=O1}r(QKR83d7yg%HVKBt z2fCL*Y2+S_nH&&2jPzc4kQ_8-I#y<;7T^q*on%NIKLNW~43$2R7I(evRwOm;(b(p$ zcgOR3WovrffHmK)CfE(I+>7LPB@tcmNPJB@kia6vj; zln$ddu=?GC)LoRiu`#+jl9wR6yhX|T4Z1wE2u|}O9RSvS4u0B6)g6x>vqn3guXVL0 z&tj6DaBSVOoBfG1@63=4mj=)J9bVw;%9+pbS~>HuyZS)7t#;8*7jvDB%1Shn3dO3m z-d<;Zs?rTLl^({!AV;-gFH;B399nYze3BkmF}n=q*CXz&{s(Ul%|ewY13`!4$}wbj4^E?FqwiGJg1C-acN*NAyTc3XMU$&_~g zXqALjy%J5!GzPI${d`ous2b;FI-5!YfQ}BMfj>VxOQ%sXDv9ajZ0KDYN8<(d_x4k6 zIs>wTcL(^ORv)-E9`D#j7B-?C0n+_q()%_BLP)W%I#0nYcXt(uP^xVvczgyL0V5mm z^md^@-$QubC!8&Kd}NWWd42{#hP>DHLL0hvvcK)v1GhPAkNeH5pisGFLKPzNd=zdk z@1w$$iYpgVvJBgx1^*r@c_kWG`EWv^^BCo)0B~*5?iNB)jw=aDT`bvi(5)@r3g+s{ zQGnI|xu=J$Y&`0@`?A&V!Qcq*vw!S7)uC+@HvRU}&Wmt~3eWV#~bk#y*Qv638f=d~67^c5~Ci1W{%2fX-+qdhD ssRjaHk!N9PY~}1j)IWM+>~18GtJz{u3TY(ve@rt!6#xJL diff --git a/config.yaml b/config.yaml index 5ad60f0..f556138 100644 --- a/config.yaml +++ b/config.yaml @@ -3,7 +3,7 @@ model: latent_dim: 32 base_channels: 64 num_layers: 4 - use_resnet_feature: True + use_resnet_feature: False use_mlgffn: False use_enhanced_generator: True # Training parameters From e94e53cc316e5cccafe897f1c5b039875c7b8766 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 15:23:57 +1000 Subject: [PATCH 075/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- resblock.py | 152 ++++++++++++++++++++++++---------------------------- 1 file changed, 71 insertions(+), 81 deletions(-) diff --git a/resblock.py b/resblock.py index b4fcc96..b0c66e8 100644 --- a/resblock.py +++ b/resblock.py @@ -10,36 +10,25 @@ def debug_print(*args, **kwargs): print(*args, **kwargs) class UpConvResBlock(nn.Module): - def __init__(self, in_channels, out_channels, dropout_rate=0.1): + def __init__(self, in_channels, out_channels): super().__init__() self.upsample = nn.Upsample(scale_factor=2, mode='nearest') self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.bn2 = nn.BatchNorm2d(out_channels) - self.dropout = nn.Dropout2d(dropout_rate) - self.feat_res_block1 = FeatResBlock(out_channels) self.feat_res_block2 = FeatResBlock(out_channels) - - self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) def forward(self, x): x = self.upsample(x) - residual = self.residual_conv(x) - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - out = self.conv2(out) - out = self.bn2(out) - out = out + residual - out = self.relu(out) - out = self.dropout(out) - out = self.feat_res_block1(out) - out = self.feat_res_block2(out) - return out + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.feat_res_block1(x) + x = self.feat_res_block2(x) + return x class DownConvResBlock(nn.Module): @@ -50,8 +39,8 @@ def __init__(self, in_channels, out_channels, dropout_rate=0.1): self.relu = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.bn2 = nn.BatchNorm2d(out_channels) - self.dropout = nn.Dropout2d(dropout_rate) + # self.bn2 = nn.BatchNorm2d(out_channels) + # self.dropout = nn.Dropout2d(dropout_rate) self.feat_res_block1 = FeatResBlock(out_channels) 
self.feat_res_block2 = FeatResBlock(out_channels) @@ -61,75 +50,69 @@ def forward(self, x): out = self.relu(out) out = self.avgpool(out) out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - out = self.dropout(out) + # out = self.bn2(out) + # out = self.relu(out) + # out = self.dropout(out) out = self.feat_res_block1(out) out = self.feat_res_block2(out) return out class FeatResBlock(nn.Module): - def __init__(self, channels, dropout_rate=0.1): + def __init__(self, channels): super().__init__() - self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) self.bn1 = nn.BatchNorm2d(channels) self.relu1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) self.bn2 = nn.BatchNorm2d(channels) - self.dropout = nn.Dropout2d(dropout_rate) self.relu2 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + self.relu3 = nn.ReLU(inplace=True) def forward(self, x): residual = x - out = self.conv1(x) - out = self.bn1(out) + out = self.bn1(x) out = self.relu1(out) - out = self.conv2(out) + out = self.conv1(out) out = self.bn2(out) - out = self.dropout(out) - out += residual out = self.relu2(out) + out = self.conv2(out) + out += residual + out = self.relu3(out) return out - + + +class ConvLayer(nn.Module): + def __init__(self, in_channels, out_channels, downsample=False): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2 if downsample else 1, padding=1) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.relu(self.bn(self.conv(x))) class ResBlock(nn.Module): - def __init__(self, in_channels, out_channels, downsample=False, dropout_rate=0.1): + def __init__(self, in_channels, out_channels): super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels # Add this line - self.downsample = downsample - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, - stride=2 if downsample else 1, padding=1) - self.bn1 = nn.BatchNorm2d(out_channels) - self.relu = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.bn2 = nn.BatchNorm2d(out_channels) - self.dropout = nn.Dropout2d(dropout_rate) + # Main path + self.conv1 = ConvLayer(in_channels, out_channels, downsample=True) + self.conv2 = ConvLayer(out_channels, out_channels) - if downsample or in_channels != out_channels: - self.shortcut = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, - stride=2 if downsample else 1, padding=1), - nn.BatchNorm2d(out_channels) - ) - else: - self.shortcut = nn.Identity() + # Skip connection path + self.skip_conv = ConvLayer(in_channels, out_channels, downsample=True) def forward(self, x): - residual = self.shortcut(x) + # Main path + main = self.conv1(x) + main = self.conv2(main) - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - out = self.conv2(out) - out = self.bn2(out) - out = self.dropout(out) - out += residual - out = self.relu(out) + # Skip connection path + skip = self.skip_conv(x) - return out + # Combine paths + return main + skip class ModulatedConv2d(nn.Module): @@ -262,11 +245,11 @@ def test_resblock(resblock, input_shape): output = resblock(x) print(f"Input shape: {x.shape}, Output shape: {output.shape}") + # Check output shape expected_output_shape = 
list(input_shape) - expected_output_shape[1] = resblock.out_channels - if resblock.downsample: - expected_output_shape[2] //= 2 - expected_output_shape[3] //= 2 + expected_output_shape[1] = resblock.conv2.conv.out_channels # Use conv2's out_channels + expected_output_shape[2] //= 2 # Always downsample + expected_output_shape[3] //= 2 assert tuple(output.shape) == tuple(expected_output_shape), f"Expected shape {expected_output_shape}, got {output.shape}" # Test gradient flow @@ -279,9 +262,9 @@ def test_resblock(resblock, input_shape): resblock.eval() with torch.no_grad(): residual_output = resblock(x) - identity = resblock.shortcut(x) - main_path = resblock.bn2(resblock.conv2(resblock.relu(resblock.bn1(resblock.conv1(x))))) - direct_output = resblock.relu(main_path + identity) + main_path = resblock.conv2(resblock.conv1(x)) + skip_path = resblock.skip_conv(x) + direct_output = main_path + skip_path assert torch.allclose(residual_output, direct_output, atol=1e-6), "Residual connection not working correctly" print("ResBlock test passed successfully!") @@ -326,20 +309,27 @@ def visualize_feature_maps(block, input_shape, num_channels=4, latent_dim=None): # Determine the number of intermediate outputs if isinstance(block, UpConvResBlock): intermediate_outputs = [ - block.conv2(block.relu(block.bn1(block.conv1(x)))), - block.feat_res_block1(block.dropout(block.relu(block.bn2(block.conv2(block.relu(block.bn1(block.conv1(x))))) + block.residual_conv(x)))), - block.feat_res_block2(block.feat_res_block1(block.dropout(block.relu(block.bn2(block.conv2(block.relu(block.bn1(block.conv1(x))))) + block.residual_conv(x))))), + block.conv1(block.upsample(x)), + block.conv2(block.relu(block.bn1(block.conv1(block.upsample(x))))), + block.feat_res_block1(block.conv2(block.relu(block.bn1(block.conv1(block.upsample(x)))))), output ] - titles = ['After Conv2', 'After FeatResBlock1', 'After FeatResBlock2', 'Final Output'] + titles = ['After Conv1', 'After Conv2', 'After FeatResBlock1', 'Final Output'] elif isinstance(block, DownConvResBlock): intermediate_outputs = [ block.conv2(block.avgpool(block.relu(block.bn1(block.conv1(x))))), - block.feat_res_block1(block.dropout(block.relu(block.bn2(block.conv2(block.avgpool(block.relu(block.bn1(block.conv1(x))))))))), - block.feat_res_block2(block.feat_res_block1(block.dropout(block.relu(block.bn2(block.conv2(block.avgpool(block.relu(block.bn1(block.conv1(x)))))))))), + block.feat_res_block1(block.conv2(block.avgpool(block.relu(block.bn1(block.conv1(x)))))), + block.feat_res_block2(block.feat_res_block1(block.conv2(block.avgpool(block.relu(block.bn1(block.conv1(x))))))), output ] titles = ['After Conv2', 'After FeatResBlock1', 'After FeatResBlock2', 'Final Output'] + elif isinstance(block, FeatResBlock): + intermediate_outputs = [ + block.conv1(block.relu1(block.bn1(x))), + block.conv2(block.relu2(block.bn2(block.conv1(block.relu1(block.bn1(x)))))), + output + ] + titles = ['After Conv1', 'After Conv2', 'Final Output'] else: intermediate_outputs = [output] titles = ['Output'] @@ -384,7 +374,7 @@ def visualize_feature_maps(block, input_shape, num_channels=4, latent_dim=None): visualize_feature_maps(styledconv, (1, 64, 56, 56), num_channels=4, latent_dim=32) - resblock = ResBlock(64, 128, downsample=True) + resblock = ResBlock(64, 128) test_resblock(resblock, (1, 64, 56, 56)) visualize_feature_maps(resblock, (1, 64, 56, 56)) @@ -394,14 +384,14 @@ def visualize_feature_maps(block, input_shape, num_channels=4, latent_dim=None): # dropout - upconv = UpConvResBlock(64, 128, 
dropout_rate=0.1) + upconv = UpConvResBlock(64, 128) test_block_with_dropout(upconv, (1, 64, 56, 56), "UpConvResBlock") - downconv = DownConvResBlock(128, 256, dropout_rate=0.1) + downconv = DownConvResBlock(128, 256) test_block_with_dropout(downconv, (1, 128, 56, 56), "DownConvResBlock") - featres = FeatResBlock(256, dropout_rate=0.1) + featres = FeatResBlock(256) test_block_with_dropout(featres, (1, 256, 28, 28), "FeatResBlock") - resblock = ResBlock(64, 128, downsample=True, dropout_rate=0.1) + resblock = ResBlock(64, 128) test_block_with_dropout(resblock, (1, 64, 56, 56), "ResBlock") From 45de1201b116bcf936b53b135d5c1d7f4c3d6f0d Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 15:24:36 +1000 Subject: [PATCH 076/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/resblock.cpython-311.pyc | Bin 29620 -> 27853 bytes model.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index 9283553816254d62699993a4b425c861e082a079..44d444173735b2916ef19ef3ae4bfaa84a3e8a9b 100644 GIT binary patch delta 7417 zcmbVQ3s76vnby?{y^s)Gfdm319tQCckJz!nHhyAkUbX>0Lcp@jy~cnL_)6G@V3X}Q zEv6(UIZorcNrT;PY$u*1yKJ_OcW2|Kn{~Rg+331=$Bs6hrk&a8dS`Zv)9kkOOeXvP z=OSKK(wSbkpU(ZyfBy6MALl>kZ+^pk^a*3WoS0}(!1JH=*AM%upEui?^1mv`XADY3 zoq|g`t>CProXMz5=8{h<+Ecg`K$c?xZJZ4-l}iOo#aNnf7!(gUjes zqU5KyC>0d)PgA@j& zz$-uPS-@@Mvp9ELDRcC(g>I>ME9Zm(<-Dc>QSi=*$_7!Xd^VTGIhWSZHpy}uE|<%@ zmj9XtTEaWm(giKkco($HjngunAHymKmwRo+Yiei+ul!KKNho;f=-b^O(9XNLtfln1 zwTVwhInJg0F*%iS1%Zkl>DUP!J9sCTMV?cojm1}*1C_%i#F=S)^*B&HoIVc!*G=7l zQ|A)Xmp9>3v@K_5U+8h|@>LUZGsmf~%dA78LtjjkLwe~B2gt{|GUZh@vFLNs^*&!$z#j_vd?D1Y%jJ4n&_h9Rl6JkD z`CBroKSw5#H!>w;(y(8{j4Gy;OV-5)0zThWM(`4&IkTY->i+OFj8y)c zRq8a=vx+#t>NHlY_JKG}(()yvOnlD&9y{tZNohaFi^94WLbMV9^5LQ3|P8 z#ufEggbXaxDP>~GMf&ZTgtM~SsZI)zHm?|uhAKuVL24;L)Wit`g9AfhpWqMkf)_|3 zj4EI}3)b_tHA{=C4qXQm$|Mi0$29+|Rk*;6l540jc}c{bPAn`-K29kj$;MP#ek!Gn z{4u4lNx5J;%dDmbgf253k4he41yn^J*j~ zaxSs7z{3b@QQl(+YmftD^hUW|H4+hWh{b6mQCsac{OW>t{`oVUStuZ1RXCJ47u@%% zA&J*HEjX0Bp8f~ znvklbauo<8y-68k)HJlsbGES5Ks<@dEg=VHCc9`Jl%g7e4-8cZE~KI!qN+o|N|{R#r&({;|3HyAg)X+QBQpz2z zWRs%;-jtck4xSl9D-yOMP)n%A(l!KYtLSzf^&1yTWS36%*=lw%?0X8Ci}G6 ziFZXYcez_7|z!g<@L2m-;Gpalk6i}Ts>^|zw25#VZFy!wa z3@kWv5xUP3#U;;=eqL*PL1w5IwbsI>p4)N0@5;KHyCZD}B0KpHkA8SG!ts%T!3cX) zWRFVh(Q);p{$4VB&O7ZClf6>17YGUWpkdwg)(GnnS&zhekZBzLAnKx%oKTQw4VmP& z+pb#Jcr4@}I-Fy%0HKmMlbtxfW`uGa75e48|3i7ix-k4Xjdvhl5ddi1l5%!<*bD*BQ^W{Q&$>qwnc39qOD%C)q{x`?%A@=k6!5zZPk*k8VSId zo;$tvZPSNr#L+A|nk7dw(9PrUM+OZU$_ccj(nWhtDkdyrDww7UGmHgunjfsN2k1*; zi@M+66O5{H4zZCh4lGHrRrb!ASTD{5!_d=;&%ne+CPvP?&zzh*Ij%;FgU{ykd7>p- zvSg2|XSIqX_Ux&NQ|JAn#U)u>bEGNm%u^9vx@-+`!%i!7ARzd_?9#UGWH7ski>zLg zF{}6{kh^JxfK^1M^GX?({A=D9#zpA@`6xes9cB@N9UyAp`2GQIC;*9oN4G3bSjadi zuW%4sd0!|?&BM%N#NIU5kVDt3Xs1gv4(M;jfl3J#HZV8<;M@k*c>Vr_( zec}7~{CtDQu*gAZTN=196^m*cdxE@Q5c=Um$R_W4Gm9`SUdX=as)8SdhXcHi>*<$M z>x0OHYnx8jHl-D->B7!0>x(nxWxJR$+(M$dLp=f7&oh6)?R{8f>s-+hGxF1ctZ(CnKxE$wBT%-ufpLWs`Qo%73AT{&CHu*{VHCI>Uvl~{%3P0 z`C!#1cP9!ufWQGnHQ^!L?q~{uIjzt`3Lam_R1og*8qG^UxH@zB@lRC7zreb0Ay~+F zt81L6vG@$aw*i)HGvnkxs!N%_Cn;+Wr&EJDgKQHB-$j^2c%G28#hMFH{K?GE)~;7+ z)u=~qksdlkzB0GApn)jbe) z&6y?PcXIoaWpW==04J!*+%?>?C!s>dzYH|P%^k#h%fVON#mEnxi|-cX3#+tvCc zZ2p<<=?u|XDmhCdYRBE|f|pN8*_BfE2Ka(21GV!m{zlYJ_?Og9IQFY^mBg+Z*WYEGci0uT*%i}}=F$vSf7kAw-XYpc z$^DkUwB~Ipt8=Mtxzu%e3AYNEy1bcJTRT!5X zf?^l$qp=6K12nPU9WjtV`)|m_^mLNlVJB~OWRktU3bMhUPHuI2;k%ybf&K(o?E2zX`X}eW)VrtqPyt|t@nZRiPb>aNjloX2Ilvp|JAqTtCb&z6f 
zxdi#cH3ExqSb0si^rEl>Qk;`|@Ne28nbXTR(}HT$0OAedVh~iol)^P%I=R&4X3}VI z_@pbDJjUgc(_AVBptIxIm0z0Y3GWW4B%pwoQb^5t!A9G$$c(Kl{lJ6ruN#1 zm3LAqTf~`16Mvw1MY;5*no&%CI{9G9)G|p$jorM2l!YYuroqxIV5;U(KQwMS45@+( zle>z=AC(VNCNUUj+ZQq`3`Z$0_?lh;B+a>03QVu7Yc*(@y zH6_pLH7-L$=bSYulJ_wY>t6UD*PhYd)2B?;{#!<|s4tQ9B@Yx<;^}P|ll$y{z5jac zd-XT!Zw}oWet$T!d+(jy9k+LPL^=xjA#v8_)wYQqyZ;G_+yVKHfy9IL9aJyTG zgo?vr_mI>*Bz7K?I**Ah$EB9z;?@(=))V5IVQI~(Nrm!Di`#$z{J1R_j=#;{P&7) z6yM};^}OE`*|qzpfqx8$)vZ!>tLSZ$ylrCM9w~2+Xx}T@_ll;BeUfP(E^>leUSxv; zZr&r-4e&ox6_{qm==S~zn9%(-Xk!14B?VX$_zSIy#;CL^M*S4o@LV9=L%`T{-vxjN zcC>#uGezJ>K(mE1?g-7E9O^$&U3fs~I^qomg-md-y}4hN_kHI0O?;W zRI)4sr{qUrT$mhwfb=gGD!t^n!S~rp#xko|3cWMj(Mm=uALYqKVV_n^%##xHh$@r= bRF7owkn)gAE1I(8Ul_3Q8-oqDf%w70$TEKc!m_ac5`JK)+%;>l z=Q@e6dqNX;rCI9Osk=C#TYPfX*v;vhG&@}+q6VU~Wa$>4E@#^rPQ0z%IXm6GfAJC_ zNzeYE-*@l*?)SX+yWjo)fAv25$#0qN1+7*k!sq%culC&D^pwuZ94{9YurQbUxTrmi zGawqdG(;0;L^N|IL>uNnjzGrbGf%y+>l~B{3RZtW4$n$yxRrIE!!Q~XS&ysv|C|pH@jWQ6?1axEHim0 zmhEUxUoaC_c)MuH;AAQ*nS_%qt|?wH3TI!YnN4VBAygF#(IhFz)5&YQ!1_( z!^w<_9H>aKt5Xxbh(rmV8Bi@hUv_)R zyoW{Zv^h;qgtJh=k(5paGi264DjrhPamZ z27TP1$M5bQ2%Nqpg9b&OQ!}gL<~<>ATg{F=yF`}qOt_MIb{&qmk#e(%%Dqef#zDGjc_zw z3P;MaA!0GXqr(=VT+x;&P3=^sbxK-AJUP5m#XD%iD+%%lRuB{t6cM0riTDzNG6EOE zEh%4)MVy_DE9N#SF2OwvQjZP?e5@)iw;>t)TZ9XuX-(=AO_!R+a-$ksOk*2m)5i2E zW5J}cVB8Zm7RQXmql%BtCOD*GFKG&QS+AKS0Vd(-^|r#%gNxuX>!nG`qOzIe5Hp% zwQ{TYMRLAVa3%FB_;M^*n!sTwhEES$1)7}fN@`kIJm*}>r7l%R(_5H#vR^qRf?x5K zq*O)oqbYGr*tm@Ep|!Zy{U|!A}iMaMUV= zI;$T3(U3!qRKV9@9lw!aE2+{|wI$KllY%4AEaskhM|ho~ zBe)rZY;6@hqtwIe#yl9+mcrL^mC|80OcfbSiP6ac&nk6rPoBZl@KUOE&R;OlWv|iB z^((l9e0i5}>Xvwb-7n+4{=texJ2~$GaiMCpbM}lXy8AqVfY%@7*Wgh6T7=nCIv=V~ zG-R&xdP4JC>UR4AzL4A9K-#a6kEkqmcwpX!KSS!$*fmjTImap&V#jV8Xo-#YkZe75 zN2dx+I%iY(xActER;svwfU&}88dabH^Dq|Z@t3XkU-FNwiRyA5ko=+1{S&2Wg@@H4*c-L(}`-$B18Pc#Axp^lT zmLNAD8$xPqHv9C;;pZ8ybaD~lNvcx_wl6bjwToM*pgPmuIxk6TA)o69#}{*qlW>yT z(&P*WXQG+YN@^>Nro-xa2Mej;WP=HOZK+#OJeSLfh9yBcu2y_pSRdbz)XTzMJ3m8_ zNLXcwE=f@st}Hph!BLU|UQ&ZJI8|+FUaI0unLPkB@Y2BOG%h*?6{N5lA+F*uAOt-p zPWrvP5FNuyFRz88p=$M>8uSK2K94`H2=)!|q3*#@T+zUhEOgE(PsE@3&iQ5goj6;qipow))peEkD+`n9&P^EVss zd!|d*MG7~Ld&WI?EAD${B+Syl9*PJb35e>d2XH}O0Cr_|_$veTM{4toC`DN^mUUAz znXxvKvGxO7_VuO@Qq!&}$4bYC?;EBI%Okm!b&o^{qdP7)jar&UoBmTH!bD}ycbl#^2{IH3smW2g9dxSmwaQAi-F65QOQU8v{4mL zuJ05lg2wXq$3E2H36A0~*ocsnmzZTGGooZzv%s&^P)hP;Ak2W@he7NRGlM?&k$Q1H zpvx=RNNEC}fqV+Mh-_IKpCrkIU zFa^1|;+W4b3{;2N552MQMiOg1eqV0@-#qw$mk(lKa!^$tRjGY}ke5F>KsHXyE<|-_ z5N0oZA%Tj^Ng)stRF>Yls5@jzTQs%VbM#|qb@4eHslTO35>G;cAHv73cE(R^Uil+2 zONDB)~Q!wXypm+5~Rte9qu9bZU1+Pu~cJ=!bHN9W+7YI_}@7CAa{)EVH5Il|$Bww6M zD82~=l`i%xXsztY6x`%nq(&ahUnICh@C3YDSt=Vt@~sKUhD{QAVgRVCS-Ear70(aF zdp>#2d=JSe)bQUXxI*v*+^Q*(J%i+X6Thz6AhA(@L;5m@u)~gEpXa2v+L2g+8JwE~ zo0oJ-{=TSZ>J)7**!3x@qWG3biv@SQ3h!+$kzGgOKTpWEFlmu6UTT@6Q84+E0jaHU zp_;(8!xdgwpk}!R#-)B=Fq9Zun3jXuFSBZ=a$$mTRl?bXb5KbQMhLDV#C6`YC%xSv zFXtBaA%zF3x&|NZD3<*I$(af5&R3Y4@)*iNKD#7^nT-uFw)Oz@?JBVhOF1byW_6_0@ zC49Xzqd+;V7{WOuf9)Dkg@p7eF}r>f)b9SCdBJG~Kb*C|%x=3pVF57i;d;%>nSlRu z>`2IN5(&tk5xa~9;!^Kv%w$pm1MXhl!zFIiVp0}PR3T9-2=d^aJ(>FjPZEqH*u^|= z7kqO+IT%WqXVyB{Dz4|}K7b??;tJnvq7+x2@c05y-B_jiF=|Q}9>SY2(%9;LU~`OT zMQyH_%@vVmOy?BEa@Hb&ul`cobc#Nfk~t$1%ZjEI_AyUH;lTfx!ZF?)RTL(;sKW7O zbx_*$Y1o)GWn3|7T!GDuMKNR1sAAe=d%E#z?}+Z{jP2=+)*gw~9*LHB#mc*)x$ao5J8C)_GaVgOPFo${*FK{i_uVOu7FI{C zH8E?=sP+TB>GIdc25yerUmMZaNA>kFeSHLu@4b?iyRD|)DS68&t#_roRn8FM!&1vc z?!Gz+y8wMH-E&5cwiLC@A9f6QJZ_v2{P^d|GCw4^L9mQ{F#KMK4x&xjQ zUW~B2TXW>I=V+jHm23hv@+bbH)hc$lkxcq@LD!UUu(n_(^Rx(D2d(%*st57GzhQ4e zb6W-k4q4&tw#=*esc1__6;!ol!B5((aPp7^UaC(A*WrqV>Rq737BdI+YzW2=mBTab 
zR@j2=|HuX|_`^ZHLEbAJN$C}F;^)|i;z<_x518mEmf|PscZK=*o!JuMEaoZ(lRH0!pwwqUAdUZeO+Y+e^(l_ zz}D(~IM_w8OFV9c->l1%a*EqZ`1h_-c&^JV=2WoCT?xPFGQ)YdL82bkLV1H3o_E_w zR>En9wa<#aCr%DYOmz7lE=S|s+&h|$Yy9x&ssr|g8#J!+1$}{F$P?)H3b_PNucx1j zRwOGUC?OC68hx!2zC8-xY4|$udrlND91cxj@PvE*OUL&~dC@Hd?AR(+1r^6C)lsaf zrEq1EL-5#fX`0OT5&ZA5jIiNW*X^!o{>E7T#;A2u%(^M6sf}rBr!|HtP41*7cf8;Q z*K@8rvR75FsP3-2SM^%e{R3}yywMS9J~-9fG1=S^>GVu>dM7)*k)GpIJ*OsnPDS`& zv?mnn2}L^xW1WN1=F_p})6t!0Vmr@7YtF`M&d%Vi#5RJHK##4EflX;bQW%snNY43aA?0=(Qm~H!H zb9*iApP%H<;{-<7Hk#H5tEX8}fx-syZW?KX?V+Kl zKrYA{iK7C!KGsAW707k6y~I(0TnnxZIV!NT7V#k(p+)Sa5nB4NL}`z>S5OuA3948K z4TWMC^O2AW6jt>UgNStgaiS5C*7}J?oYthxC}j>+L}8oJh|=6FV(7wu81FVQ75r8A zp70mjvum=sE8^y++`dV-FVa6S)gPSf4@N@grb6M#P#7n=RlJQx8)VyQI4Y1k&UO$- z1#(B(PU5IQu7%YRM+I_qY?nBUES1P^tfx9eq$?VTMnpP)H_?blYlU4${Qnp@G}#-9 z3=T~Vh9?IT)2|bUar?KTNpbtBz(ws5?!5x3D?s6cM3xScpE zkUPyfh@%3zKGsPb707k6Bg9dGTnpUT0w;qk0$5@+TYUtpA!Q5gu!N0TQZWb5fydPWm3H*{EuYz(lMA;V~!Usl<2H zx3SiQvy#=ff#(Q(v zPvL%V*0$}`Q7M7&rQi@zJi&Je=rx3=cL$yldj1CllvVL`sq;~SO_0~OR{jE!-+<0O ztLkGUak==B3BK0X=gMaI3E}hbWXrX-ha`WrP`Qmkn(we{1-1A*jPjSyJtX;~h047Q z-tfI)DrNc^Q8IK+jPYF Date: Sun, 11 Aug 2024 15:34:38 +1000 Subject: [PATCH 077/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index f556138..c77a787 100644 --- a/config.yaml +++ b/config.yaml @@ -21,7 +21,7 @@ training: # learning_rate_d: 5.0e-4 # Increased learning rate for discriminator ema_decay: 0.999 style_mixing_prob: 0.9 - initial_noise_magnitude: 0.1 + initial_noise_magnitude: 0.01 final_noise_magnitude: 0.001 gradient_accumulation_steps: 1 lambda_pixel: 10 # in paper lambda-pixel = 10 Adjust this value as needed From b915399727b7006c990e907818d0a9bff5bac8ec Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 15:46:52 +1000 Subject: [PATCH 078/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- helper.py | 40 +++++++++------------------------------- 1 file changed, 9 insertions(+), 31 deletions(-) diff --git a/helper.py b/helper.py index df5bfc0..9045b7b 100644 --- a/helper.py +++ b/helper.py @@ -107,7 +107,6 @@ def log_loss_landscape(model, loss_fns, dataloader, step): # Global variable to store the current table structure current_table_columns = None - def log_grad_flow(named_parameters, global_step): global current_table_columns @@ -124,15 +123,21 @@ def log_grad_flow(named_parameters, global_step): # Normalize gradients max_grad = max(grads) - normalized_grads = [g / max_grad for g in grads] + if max_grad == 0: + print("Warning: All gradients are zero.") + normalized_grads = grads # Use unnormalized grads if max is zero + else: + normalized_grads = [g / max_grad for g in grads] # Create the matplotlib figure plt.figure(figsize=(12, 6)) plt.bar(range(len(grads)), normalized_grads, alpha=0.5) plt.xticks(range(len(grads)), layers, rotation="vertical") plt.xlabel("Layers") - plt.ylabel("Normalized Gradient Magnitude") - plt.title(f"Normalized Gradient Flow (Step {global_step})") + plt.ylabel("Gradient Magnitude") + plt.title(f"Gradient Flow (Step {global_step})") + if max_grad == 0: + plt.title(f"Gradient Flow (Step {global_step}) - All Gradients Zero") plt.tight_layout() # Save the figure to a bytes buffer @@ -145,28 +150,6 @@ def log_grad_flow(named_parameters, global_step): plt.close() - # Check if the table structure has changed - # 
new_columns = ["step"] + layers - # if current_table_columns != new_columns: - # # If the structure has changed, delete the existing table and create a new one - # if wandb.run is not None: - # wandb.run.config.update({"gradient_flow_columns": new_columns}, allow_val_change=True) - - # # Delete the existing table artifact - # api = wandb.Api() - # try: - # artifact = api.artifact(f"{wandb.run.entity}/{wandb.run.project}/gradient_flow_table:latest") - # artifact.delete(delete_aliases=True) - # print("Deleted existing gradient_flow_table artifact.") - # except wandb.errors.CommError: - # print("No existing gradient_flow_table artifact found.") - - # current_table_columns = new_columns - - # # Create or update the wandb.Table - # data = [[global_step] + normalized_grads] - # table = wandb.Table(data=data, columns=current_table_columns) - # Calculate statistics stats = { "max_gradient": max_grad, @@ -187,11 +170,6 @@ def log_grad_flow(named_parameters, global_step): "step": global_step } - # Log the table as an artifact - # artifact = wandb.Artifact('gradient_flow_table', type='table') - # artifact.add(table, 'gradient_flow_data') - # wandb.log_artifact(artifact) - # Log other metrics wandb.log(log_dict) From ff29550b3626697dadcaa84f7692d5b032211d62 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 15:59:49 +1000 Subject: [PATCH 079/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 10985 -> 10985 bytes __pycache__/helper.cpython-311.pyc | Bin 25147 -> 25384 bytes config.yaml | 2 +- framedecoder.py | 2 +- helper.py | 2 +- model.py | 2 +- test.py | 81 ------------ test_gradient_flow.py | 152 +++++++++++++++++++++++ 8 files changed, 156 insertions(+), 85 deletions(-) delete mode 100644 test.py create mode 100644 test_gradient_flow.py diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index d5ed227493abde243bb1adfdd3f8c30a15ad5c67..1cdb2c91880296496f05a536bd79439e6da0cc48 100644 GIT binary patch delta 35 pcmaDE`ZAPvIWI340}$8+??`*Mk(ZT`F=VqaqlP9U&tw~|wE(}{3Dy7r delta 35 pcmaDE`ZAPvIWI340}!wvp_)=bUDP+qEO^s(Vlnmn;oSYg*}`8DtbigF9ulu5q0V#DRzp z1wp}iF)Fe@)Iy=wc_v`?Rd@=E>AVY6VHA^H3HAb0}z{ z;F63@J>U6n96_zi{3b`&T7Lm7(6xYS2a$^VUm33RNYQI2j~2lcT|V>7R>2+7V zmQ&~FsbQSWkit8c8Oek#auLjYYxpPg%WE zPG(hVih_GlVoGLeUWtNlVtQU?Noh)Il{_|Sx19WP1&!d6)B=THjG9H(K#PiOKm^m~ zCZ!*YZ22JTlP5Q*te1pY4zzSdDv-D+ZgK@E#v(E~Q?*=8?;?x-6&C#qEc$m@*jqd= z@Y!wP1`-f-k;VQBi~R)_`^nmBsf;F*C#touNHMT-O*WKP-E60x!^C)e^J>jFF2*;T zxlA}18UIffHhs#(z&QDr=}#su#>rpJZZQclZoX*#pONv_=8Kl9Y>Wz HB+UT;5=5v% diff --git a/config.yaml b/config.yaml index c77a787..11be71e 100644 --- a/config.yaml +++ b/config.yaml @@ -5,7 +5,7 @@ model: num_layers: 4 use_resnet_feature: False use_mlgffn: False - use_enhanced_generator: True + use_enhanced_generator: False # Training parameters training: initial_video_repeat: 5 diff --git a/framedecoder.py b/framedecoder.py index 4a3cf3b..f43b75f 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -4,7 +4,7 @@ import math from resblock import UpConvResBlock,FeatResBlock -DEBUG = False +DEBUG = True def debug_print(*args, **kwargs): if DEBUG: print(*args, **kwargs) diff --git a/helper.py b/helper.py index 9045b7b..7a0ee32 100644 --- a/helper.py +++ b/helper.py @@ -124,7 +124,7 @@ def log_grad_flow(named_parameters, global_step): # Normalize gradients max_grad = max(grads) if max_grad == 0: - print("Warning: All gradients are zero.") + print("☠☠☠ Warning: 
All gradients are zero. ☠☠☠") normalized_grads = grads # Use unnormalized grads if max is zero else: normalized_grads = [g / max_grad for g in grads] diff --git a/model.py b/model.py index 5c06618..45ccf38 100644 --- a/model.py +++ b/model.py @@ -16,7 +16,7 @@ # from common import DownConvResBlock,UpConvResBlock import colored_traceback.auto # makes terminal show color coded output when crash from framedecoder import EnhancedFrameDecoder -DEBUG = False +DEBUG = True def debug_print(*args, **kwargs): if DEBUG: print(*args, **kwargs) diff --git a/test.py b/test.py deleted file mode 100644 index ee76cbd..0000000 --- a/test.py +++ /dev/null @@ -1,81 +0,0 @@ -import unittest -import torch -import torch.nn as nn -import sys -import os - -# Add the directory containing the module to the Python path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from vit import PositionalEncoding, TransformerBlock, ImplicitMotionAlignment, CrossAttentionModule -from model import LatentTokenEncoder, LatentTokenDecoder, DenseFeatureEncoder, FrameDecoder - -class TestNeuralNetworkComponents(unittest.TestCase): - def setUp(self): - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.B, self.C_f, self.C_m, self.H, self.W = 2, 256, 256, 64, 64 - self.feature_dim = self.C_f - self.motion_dim = self.C_m - self.heads = 8 - self.dim_head = 64 - self.mlp_dim = 1024 - self.dm = 32 - self.input_size = 256 - self.input_channels = 3 - - def test_positional_encoding(self): - pe = PositionalEncoding(d_model=self.motion_dim).to(self.device) - x = torch.randn(100, 1, self.motion_dim).to(self.device) - output = pe(x) - self.assertEqual(output.shape, (100, 1, self.motion_dim)) - - - - - def test_latent_token_encoder(self): - et = LatentTokenEncoder(dm=self.dm).to(self.device) - x = torch.randn(self.B, self.input_channels, self.input_size, self.input_size).to(self.device) - output = et(x) - - self.assertEqual(len(output.shape), 2, f"ET output should be a 2D tensor (batch_size, dm), but got shape {output.shape}") - self.assertEqual(output.shape[0], self.B, f"First dimension should be batch_size ({self.B}), but got {output.shape[0]}") - self.assertTrue(32 <= output.shape[1] <= 1024, f"dm should be between 32 and 1024, got {output.shape[1]}") - - def test_dense_feature_encoder(self): - ef = DenseFeatureEncoder().to(self.device) - x = torch.randn(self.B, self.input_channels, self.input_size, self.input_size).to(self.device) - outputs = ef(x) - - self.assertEqual(len(outputs), 4, f"EF should produce 4 outputs, but got {len(outputs)}") - - expected_shapes = [ - (self.B, 128, 64, 64), - (self.B, 256, 32, 32), - (self.B, 512, 16, 16), - (self.B, 512, 8, 8) - ] - - for i, (output, expected_shape) in enumerate(zip(outputs, expected_shapes)): - self.assertEqual(output.shape, expected_shape, f"fr{i+1} should be {expected_shape}, but got {output.shape}") - - def test_latent_token_decoder(self): - et = LatentTokenEncoder(dm=self.dm).to(self.device) - latent_token_decoder = LatentTokenDecoder(latent_dim=self.dm).to(self.device) - - x_current = torch.randn(self.B, self.input_channels, self.input_size, self.input_size).to(self.device) - x_ref = torch.randn(self.B, self.input_channels, self.input_size, self.input_size).to(self.device) - - t_c = et(x_current) - t_r = et(x_ref) - - m_r = latent_token_decoder(t_r) - m_c = latent_token_decoder(t_c) - - self.assertEqual(len(m_r), len(m_c), "Number of outputs from LatentTokenDecoder should be the same for reference and current") - - for m_r_x, 
m_c_x in zip(m_r, m_c): - self.assertEqual(m_r_x.shape, m_c_x.shape, "Shapes of reference and current outputs should match") - - -if __name__ == '__main__': - unittest.main() \ No newline at end of file diff --git a/test_gradient_flow.py b/test_gradient_flow.py new file mode 100644 index 0000000..a9e80fb --- /dev/null +++ b/test_gradient_flow.py @@ -0,0 +1,152 @@ +import torch +import torch.nn as nn +from model import IMFModel, DenseFeatureEncoder, LatentTokenEncoder, LatentTokenDecoder, ImplicitMotionAlignment, FrameDecoder, ResNetFeatureExtractor + +def test_gradient_flow(model, input_shape): + print(f"\nTesting gradient flow for {type(model).__name__}") + x = torch.randn(input_shape) + x.requires_grad_(True) + + if isinstance(model, IMFModel): + x_current = x + x_reference = torch.randn_like(x) + output = model(x_current, x_reference) + elif isinstance(model, (DenseFeatureEncoder, LatentTokenEncoder, ResNetFeatureExtractor)): + output = model(x) + elif isinstance(model, LatentTokenDecoder): + latent = torch.randn(input_shape[0], model.const.shape[1]) # Adjust size as needed + output = model(latent) + elif isinstance(model, ImplicitMotionAlignment): + m_c = torch.randn_like(x) + m_r = torch.randn_like(x) + f_r = torch.randn_like(x) + output = model(m_c, m_r, f_r) + elif isinstance(model, FrameDecoder): + # Assuming FrameDecoder expects a list of tensors + features = [torch.randn_like(x) for _ in range(4)] # Adjust number as needed + output = model(features) + else: + raise ValueError(f"Unsupported model type: {type(model).__name__}") + + if isinstance(output, tuple): + loss = sum(o.sum() for o in output if isinstance(o, torch.Tensor)) + elif isinstance(output, list): + loss = sum(o.sum() for o in output) + else: + loss = output.sum() + + loss.backward() + + for name, param in model.named_parameters(): + if param.grad is None: + print(f"Warning: No gradient for {name}") + else: + grad_norm = param.grad.norm().item() + print(f"{name}: gradient norm = {grad_norm:.6f}") + assert grad_norm != 0, f"Zero gradient for {name}" + assert not torch.isnan(param.grad).any(), f"NaN gradient for {name}" + assert not torch.isinf(param.grad).any(), f"Inf gradient for {name}" + + print(f"Gradient flow test passed for {type(model).__name__}") + +def test_feature_extractor(feature_extractor, input_shape): + print(f"\nTesting {type(feature_extractor).__name__}") + x = torch.randn(input_shape) + features = feature_extractor(x) + + assert isinstance(features, list), f"Expected list output, got {type(features)}" + print(f"Number of feature maps: {len(features)}") + for i, feature in enumerate(features): + print(f"Feature map {i} shape: {feature.shape}") + + # Test gradient flow + loss = sum(feature.sum() for feature in features) + loss.backward() + + for name, param in feature_extractor.named_parameters(): + if param.grad is None: + print(f"Warning: No gradient for {name}") + else: + grad_norm = param.grad.norm().item() + print(f"{name}: gradient norm = {grad_norm:.6f}") + assert grad_norm != 0, f"Zero gradient for {name}" + assert not torch.isnan(param.grad).any(), f"NaN gradient for {name}" + assert not torch.isinf(param.grad).any(), f"Inf gradient for {name}" + + print(f"Gradient flow test passed for {type(feature_extractor).__name__}") + +def test_implicit_motion_alignment_modules(use_mlgffn=False): + print("\nTesting ImplicitMotionAlignment modules") + motion_dims = [128, 256, 512, 512] + batch_size = 1 + spatial_sizes = [64, 32, 16, 8] # Example spatial sizes for different levels + + for i, dim in 
enumerate(motion_dims): + model = ImplicitMotionAlignment( + feature_dim=dim, + motion_dim=dim, + depth=4, + num_heads=8, + window_size=8, + mlp_ratio=4, + use_mlgffn=use_mlgffn + ) + + spatial_size = spatial_sizes[i] + input_shape = (batch_size, dim, spatial_size, spatial_size) + + m_c = torch.randn(input_shape) + m_r = torch.randn(input_shape) + f_r = torch.randn(input_shape) + + output = model(m_c, m_r, f_r) + loss = output.sum() + loss.backward() + + print(f"Testing ImplicitMotionAlignment for dim={dim}, spatial_size={spatial_size}") + for name, param in model.named_parameters(): + if param.grad is None: + print(f"Warning: No gradient for {name}") + else: + grad_norm = param.grad.norm().item() + print(f"{name}: gradient norm = {grad_norm:.6f}") + assert grad_norm != 0, f"Zero gradient for {name}" + assert not torch.isnan(param.grad).any(), f"NaN gradient for {name}" + assert not torch.isinf(param.grad).any(), f"Inf gradient for {name}" + + print(f"Gradient flow test passed for ImplicitMotionAlignment with dim={dim}") + +def run_all_gradient_flow_tests(): + # Test IMFModel + model = IMFModel() + test_gradient_flow(model, (1, 3, 256, 256)) + + # Test DenseFeatureEncoder + dense_feature_encoder = DenseFeatureEncoder(output_channels=[128, 256, 512, 512]) + test_feature_extractor(dense_feature_encoder, (1, 3, 256, 256)) + + # Test ResNetFeatureExtractor + resnet_feature_extractor = ResNetFeatureExtractor(output_channels=[128, 256, 512, 512]) + test_feature_extractor(resnet_feature_extractor, (1, 3, 256, 256)) + + # Test LatentTokenEncoder + latent_token_encoder = LatentTokenEncoder( + initial_channels=64, + output_channels=[128, 256, 512, 512, 512, 512], + dm=32 + ) + test_gradient_flow(latent_token_encoder, (1, 3, 256, 256)) + + # Test LatentTokenDecoder + latent_token_decoder = LatentTokenDecoder() + test_gradient_flow(latent_token_decoder, (1, 32)) # Adjust latent dim as needed + + # Test ImplicitMotionAlignment modules + test_implicit_motion_alignment_modules() + + # Test FrameDecoder + frame_decoder = FrameDecoder() + test_gradient_flow(frame_decoder, (1, 512, 32, 32)) + +if __name__ == "__main__": + run_all_gradient_flow_tests() \ No newline at end of file From c7042255cc522d8372a272bc0172f098aead9de6 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 16:02:15 +1000 Subject: [PATCH 080/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- framedecoder.py | 2 +- gradients.txt | 1449 +++++++++++++++++++++++++++++++++++++++++ model.py | 2 +- test_gradient_flow.py | 5 +- 4 files changed, 1454 insertions(+), 4 deletions(-) create mode 100644 gradients.txt diff --git a/framedecoder.py b/framedecoder.py index f43b75f..4a3cf3b 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -4,7 +4,7 @@ import math from resblock import UpConvResBlock,FeatResBlock -DEBUG = True +DEBUG = False def debug_print(*args, **kwargs): if DEBUG: print(*args, **kwargs) diff --git a/gradients.txt b/gradients.txt new file mode 100644 index 0000000..b37b5c9 --- /dev/null +++ b/gradients.txt @@ -0,0 +1,1449 @@ + +Testing gradient flow for IMFModel +⚾ DenseFeatureEncoder input shape: torch.Size([1, 3, 256, 256]) + After initial conv: torch.Size([1, 64, 256, 256]) + After down_block 1: torch.Size([1, 64, 128, 128]) + After down_block 2: torch.Size([1, 128, 64, 64]) + After down_block 3: torch.Size([1, 256, 32, 32]) + After down_block 4: torch.Size([1, 512, 16, 16]) + After down_block 5: torch.Size([1, 512, 8, 8]) + DenseFeatureEncoder 
output shapes: [torch.Size([1, 128, 64, 64]), torch.Size([1, 256, 32, 32]), torch.Size([1, 512, 16, 16]), torch.Size([1, 512, 8, 8])] +LatentTokenEncoder input shape: torch.Size([1, 3, 256, 256]) +After initial conv and activation: torch.Size([1, 64, 256, 256]) +After res_block 1: torch.Size([1, 128, 128, 128]) +After res_block 2: torch.Size([1, 256, 64, 64]) +After res_block 3: torch.Size([1, 512, 32, 32]) +After res_block 4: torch.Size([1, 512, 16, 16]) +After res_block 5: torch.Size([1, 512, 8, 8]) +After res_block 6: torch.Size([1, 512, 4, 4]) +After equalconv: torch.Size([1, 512, 4, 4]) +After global average pooling: torch.Size([1, 512]) +After linear layer 1: torch.Size([1, 512]) +After linear layer 2: torch.Size([1, 512]) +After linear layer 3: torch.Size([1, 512]) +After linear layer 4: torch.Size([1, 512]) +Final output: torch.Size([1, 32]) +LatentTokenEncoder input shape: torch.Size([1, 3, 256, 256]) +After initial conv and activation: torch.Size([1, 64, 256, 256]) +After res_block 1: torch.Size([1, 128, 128, 128]) +After res_block 2: torch.Size([1, 256, 64, 64]) +After res_block 3: torch.Size([1, 512, 32, 32]) +After res_block 4: torch.Size([1, 512, 16, 16]) +After res_block 5: torch.Size([1, 512, 8, 8]) +After res_block 6: torch.Size([1, 512, 4, 4]) +After equalconv: torch.Size([1, 512, 4, 4]) +After global average pooling: torch.Size([1, 512]) +After linear layer 1: torch.Size([1, 512]) +After linear layer 2: torch.Size([1, 512]) +After linear layer 3: torch.Size([1, 512]) +After linear layer 4: torch.Size([1, 512]) +Final output: torch.Size([1, 32]) +🎒 FrameDecoder input shapes +f:torch.Size([1, 128, 64, 64]) +f:torch.Size([1, 256, 32, 32]) +f:torch.Size([1, 512, 16, 16]) +f:torch.Size([1, 512, 8, 8]) +Reshaped features: [torch.Size([1, 128, 64, 64]), torch.Size([1, 256, 32, 32]), torch.Size([1, 512, 16, 16]), torch.Size([1, 512, 8, 8])] + Initial x shape: torch.Size([1, 512, 8, 8]) + + Processing upconv_block 1 + After upconv_block 1: torch.Size([1, 512, 16, 16]) + Processing feat_block 1 + feat_block 1 input shape: torch.Size([1, 512, 16, 16]) + feat_block 1 output shape: torch.Size([1, 512, 16, 16]) + Concatenating: x torch.Size([1, 512, 16, 16]) and feat torch.Size([1, 512, 16, 16]) + After concatenation: torch.Size([1, 1024, 16, 16]) + + Processing upconv_block 2 + After upconv_block 2: torch.Size([1, 512, 32, 32]) + Processing feat_block 2 + feat_block 2 input shape: torch.Size([1, 256, 32, 32]) + feat_block 2 output shape: torch.Size([1, 256, 32, 32]) + Concatenating: x torch.Size([1, 512, 32, 32]) and feat torch.Size([1, 256, 32, 32]) + After concatenation: torch.Size([1, 768, 32, 32]) + + Processing upconv_block 3 + After upconv_block 3: torch.Size([1, 256, 64, 64]) + Processing feat_block 3 + feat_block 3 input shape: torch.Size([1, 128, 64, 64]) + feat_block 3 output shape: torch.Size([1, 128, 64, 64]) + Concatenating: x torch.Size([1, 256, 64, 64]) and feat torch.Size([1, 128, 64, 64]) + After concatenation: torch.Size([1, 384, 64, 64]) + + Processing upconv_block 4 + After upconv_block 4: torch.Size([1, 128, 128, 128]) + + Processing upconv_block 5 + After upconv_block 5: torch.Size([1, 64, 256, 256]) + + Applying final convolution + FrameDecoder final output shape: torch.Size([1, 3, 256, 256]) +latent_token_encoder.conv1.weight: gradient norm = 11.320090 +latent_token_encoder.conv1.bias: gradient norm = 1.302591 +latent_token_encoder.res_blocks.0.conv1.conv.weight: gradient norm = 35.192299 +latent_token_encoder.res_blocks.0.conv1.conv.bias: gradient norm = 
0.000006 +latent_token_encoder.res_blocks.0.conv1.bn.weight: gradient norm = 0.598824 +latent_token_encoder.res_blocks.0.conv1.bn.bias: gradient norm = 0.571394 +latent_token_encoder.res_blocks.0.conv2.conv.weight: gradient norm = 40.653366 +latent_token_encoder.res_blocks.0.conv2.conv.bias: gradient norm = 0.000003 +latent_token_encoder.res_blocks.0.conv2.bn.weight: gradient norm = 0.596908 +latent_token_encoder.res_blocks.0.conv2.bn.bias: gradient norm = 0.471818 +latent_token_encoder.res_blocks.0.skip_conv.conv.weight: gradient norm = 29.697824 +latent_token_encoder.res_blocks.0.skip_conv.conv.bias: gradient norm = 0.000006 +latent_token_encoder.res_blocks.0.skip_conv.bn.weight: gradient norm = 0.647219 +latent_token_encoder.res_blocks.0.skip_conv.bn.bias: gradient norm = 0.563982 +latent_token_encoder.res_blocks.1.conv1.conv.weight: gradient norm = 37.058834 +latent_token_encoder.res_blocks.1.conv1.conv.bias: gradient norm = 0.000001 +latent_token_encoder.res_blocks.1.conv1.bn.weight: gradient norm = 0.529294 +latent_token_encoder.res_blocks.1.conv1.bn.bias: gradient norm = 0.432100 +latent_token_encoder.res_blocks.1.conv2.conv.weight: gradient norm = 43.283562 +latent_token_encoder.res_blocks.1.conv2.conv.bias: gradient norm = 0.000001 +latent_token_encoder.res_blocks.1.conv2.bn.weight: gradient norm = 0.460657 +latent_token_encoder.res_blocks.1.conv2.bn.bias: gradient norm = 0.343939 +latent_token_encoder.res_blocks.1.skip_conv.conv.weight: gradient norm = 30.393908 +latent_token_encoder.res_blocks.1.skip_conv.conv.bias: gradient norm = 0.000001 +latent_token_encoder.res_blocks.1.skip_conv.bn.weight: gradient norm = 0.408579 +latent_token_encoder.res_blocks.1.skip_conv.bn.bias: gradient norm = 0.340287 +latent_token_encoder.res_blocks.2.conv1.conv.weight: gradient norm = 39.043488 +latent_token_encoder.res_blocks.2.conv1.conv.bias: gradient norm = 0.000000 +latent_token_encoder.res_blocks.2.conv1.bn.weight: gradient norm = 0.392250 +latent_token_encoder.res_blocks.2.conv1.bn.bias: gradient norm = 0.339671 +latent_token_encoder.res_blocks.2.conv2.conv.weight: gradient norm = 45.689297 +latent_token_encoder.res_blocks.2.conv2.conv.bias: gradient norm = 0.000000 +latent_token_encoder.res_blocks.2.conv2.bn.weight: gradient norm = 0.328779 +latent_token_encoder.res_blocks.2.conv2.bn.bias: gradient norm = 0.265127 +latent_token_encoder.res_blocks.2.skip_conv.conv.weight: gradient norm = 32.098759 +latent_token_encoder.res_blocks.2.skip_conv.conv.bias: gradient norm = 0.000000 +latent_token_encoder.res_blocks.2.skip_conv.bn.weight: gradient norm = 0.305273 +latent_token_encoder.res_blocks.2.skip_conv.bn.bias: gradient norm = 0.261218 +latent_token_encoder.res_blocks.3.conv1.conv.weight: gradient norm = 41.664135 +latent_token_encoder.res_blocks.3.conv1.conv.bias: gradient norm = 0.000000 +latent_token_encoder.res_blocks.3.conv1.bn.weight: gradient norm = 0.282246 +latent_token_encoder.res_blocks.3.conv1.bn.bias: gradient norm = 0.257811 +latent_token_encoder.res_blocks.3.conv2.conv.weight: gradient norm = 34.451809 +latent_token_encoder.res_blocks.3.conv2.conv.bias: gradient norm = 0.000000 +latent_token_encoder.res_blocks.3.conv2.bn.weight: gradient norm = 0.250673 +latent_token_encoder.res_blocks.3.conv2.bn.bias: gradient norm = 0.212399 +latent_token_encoder.res_blocks.3.skip_conv.conv.weight: gradient norm = 34.148746 +latent_token_encoder.res_blocks.3.skip_conv.conv.bias: gradient norm = 0.000000 +latent_token_encoder.res_blocks.3.skip_conv.bn.weight: gradient norm = 0.249347 
+latent_token_encoder.res_blocks.3.skip_conv.bn.bias: gradient norm = 0.207034 +latent_token_encoder.res_blocks.4.conv1.conv.weight: gradient norm = 31.698423 +latent_token_encoder.res_blocks.4.conv1.conv.bias: gradient norm = 0.000000 +latent_token_encoder.res_blocks.4.conv1.bn.weight: gradient norm = 0.229714 +latent_token_encoder.res_blocks.4.conv1.bn.bias: gradient norm = 0.203783 +latent_token_encoder.res_blocks.4.conv2.conv.weight: gradient norm = 26.827261 +latent_token_encoder.res_blocks.4.conv2.conv.bias: gradient norm = 0.000000 +latent_token_encoder.res_blocks.4.conv2.bn.weight: gradient norm = 0.198378 +latent_token_encoder.res_blocks.4.conv2.bn.bias: gradient norm = 0.168535 +latent_token_encoder.res_blocks.4.skip_conv.conv.weight: gradient norm = 26.452568 +latent_token_encoder.res_blocks.4.skip_conv.conv.bias: gradient norm = 0.000000 +latent_token_encoder.res_blocks.4.skip_conv.bn.weight: gradient norm = 0.204645 +latent_token_encoder.res_blocks.4.skip_conv.bn.bias: gradient norm = 0.180799 +latent_token_encoder.res_blocks.5.conv1.conv.weight: gradient norm = 25.184984 +latent_token_encoder.res_blocks.5.conv1.conv.bias: gradient norm = 0.000000 +latent_token_encoder.res_blocks.5.conv1.bn.weight: gradient norm = 0.194987 +latent_token_encoder.res_blocks.5.conv1.bn.bias: gradient norm = 0.181515 +latent_token_encoder.res_blocks.5.conv2.conv.weight: gradient norm = 23.415251 +latent_token_encoder.res_blocks.5.conv2.conv.bias: gradient norm = 0.000000 +latent_token_encoder.res_blocks.5.conv2.bn.weight: gradient norm = 0.679544 +latent_token_encoder.res_blocks.5.conv2.bn.bias: gradient norm = 0.828077 +latent_token_encoder.res_blocks.5.skip_conv.conv.weight: gradient norm = 23.278402 +latent_token_encoder.res_blocks.5.skip_conv.conv.bias: gradient norm = 0.000000 +latent_token_encoder.res_blocks.5.skip_conv.bn.weight: gradient norm = 0.669718 +latent_token_encoder.res_blocks.5.skip_conv.bn.bias: gradient norm = 0.814979 +latent_token_encoder.equalconv.conv.bias: gradient norm = 1.513508 +latent_token_encoder.equalconv.conv.weight_orig: gradient norm = 1.269233 +latent_token_encoder.linear_layers.0.linear.bias: gradient norm = 1.082299 +latent_token_encoder.linear_layers.0.linear.weight_orig: gradient norm = 1.281194 +latent_token_encoder.linear_layers.1.linear.bias: gradient norm = 1.083161 +latent_token_encoder.linear_layers.1.linear.weight_orig: gradient norm = 1.365153 +latent_token_encoder.linear_layers.2.linear.bias: gradient norm = 1.175855 +latent_token_encoder.linear_layers.2.linear.weight_orig: gradient norm = 1.566979 +latent_token_encoder.linear_layers.3.linear.bias: gradient norm = 1.122202 +latent_token_encoder.linear_layers.3.linear.weight_orig: gradient norm = 1.604452 +latent_token_encoder.final_linear.linear.bias: gradient norm = 1.085345 +latent_token_encoder.final_linear.linear.weight_orig: gradient norm = 1.654073 +latent_token_decoder.const: gradient norm = 97.677322 +latent_token_decoder.style_conv_layers.0.conv.weight: gradient norm = 103.677361 +latent_token_decoder.style_conv_layers.0.style.weight: gradient norm = 637.967407 +latent_token_decoder.style_conv_layers.0.style.bias: gradient norm = 816.861206 +latent_token_decoder.style_conv_layers.1.conv.weight: gradient norm = 116.547768 +latent_token_decoder.style_conv_layers.1.style.weight: gradient norm = 712.670349 +latent_token_decoder.style_conv_layers.1.style.bias: gradient norm = 903.644470 +latent_token_decoder.style_conv_layers.2.conv.weight: gradient norm = 112.437096 
+latent_token_decoder.style_conv_layers.2.style.weight: gradient norm = 706.033447 +latent_token_decoder.style_conv_layers.2.style.bias: gradient norm = 912.103760 +latent_token_decoder.style_conv_layers.3.conv.weight: gradient norm = 108.083435 +latent_token_decoder.style_conv_layers.3.style.weight: gradient norm = 646.467590 +latent_token_decoder.style_conv_layers.3.style.bias: gradient norm = 814.138550 +latent_token_decoder.style_conv_layers.4.conv.weight: gradient norm = 114.735664 +latent_token_decoder.style_conv_layers.4.style.weight: gradient norm = 702.971863 +latent_token_decoder.style_conv_layers.4.style.bias: gradient norm = 884.571960 +latent_token_decoder.style_conv_layers.5.conv.weight: gradient norm = 112.581055 +latent_token_decoder.style_conv_layers.5.style.weight: gradient norm = 683.582458 +latent_token_decoder.style_conv_layers.5.style.bias: gradient norm = 865.454529 +latent_token_decoder.style_conv_layers.6.conv.weight: gradient norm = 106.337189 +latent_token_decoder.style_conv_layers.6.style.weight: gradient norm = 667.077759 +latent_token_decoder.style_conv_layers.6.style.bias: gradient norm = 826.534668 +latent_token_decoder.style_conv_layers.7.conv.weight: gradient norm = 13.221072 +latent_token_decoder.style_conv_layers.7.style.weight: gradient norm = 70.980080 +latent_token_decoder.style_conv_layers.7.style.bias: gradient norm = 88.377274 +latent_token_decoder.style_conv_layers.8.conv.weight: gradient norm = 13.097474 +latent_token_decoder.style_conv_layers.8.style.weight: gradient norm = 74.613533 +latent_token_decoder.style_conv_layers.8.style.bias: gradient norm = 90.828224 +latent_token_decoder.style_conv_layers.9.conv.weight: gradient norm = 12.330044 +latent_token_decoder.style_conv_layers.9.style.weight: gradient norm = 77.935249 +latent_token_decoder.style_conv_layers.9.style.bias: gradient norm = 94.628799 +latent_token_decoder.style_conv_layers.10.conv.weight: gradient norm = 1.301443 +latent_token_decoder.style_conv_layers.10.style.weight: gradient norm = 8.505334 +latent_token_decoder.style_conv_layers.10.style.bias: gradient norm = 11.019262 +latent_token_decoder.style_conv_layers.11.conv.weight: gradient norm = 1.418244 +latent_token_decoder.style_conv_layers.11.style.weight: gradient norm = 6.965727 +latent_token_decoder.style_conv_layers.11.style.bias: gradient norm = 8.955790 +latent_token_decoder.style_conv_layers.12.conv.weight: gradient norm = 1.536332 +latent_token_decoder.style_conv_layers.12.style.weight: gradient norm = 8.992867 +latent_token_decoder.style_conv_layers.12.style.bias: gradient norm = 11.907469 +dense_feature_encoder.initial_conv.0.weight: gradient norm = 1261954.125000 +dense_feature_encoder.initial_conv.0.bias: gradient norm = 0.528200 +dense_feature_encoder.initial_conv.1.weight: gradient norm = 42890.601562 +dense_feature_encoder.initial_conv.1.bias: gradient norm = 40748.906250 +dense_feature_encoder.down_blocks.0.conv1.weight: gradient norm = 2045675.625000 +dense_feature_encoder.down_blocks.0.conv1.bias: gradient norm = 0.577734 +dense_feature_encoder.down_blocks.0.bn1.weight: gradient norm = 43814.976562 +dense_feature_encoder.down_blocks.0.bn1.bias: gradient norm = 40440.164062 +dense_feature_encoder.down_blocks.0.conv2.weight: gradient norm = 1723571.750000 +dense_feature_encoder.down_blocks.0.conv2.bias: gradient norm = 37497.964844 +dense_feature_encoder.down_blocks.0.feat_res_block1.bn1.weight: gradient norm = 34612.054688 +dense_feature_encoder.down_blocks.0.feat_res_block1.bn1.bias: gradient norm = 
28645.166016 +dense_feature_encoder.down_blocks.0.feat_res_block1.conv1.weight: gradient norm = 1316390.125000 +dense_feature_encoder.down_blocks.0.feat_res_block1.conv1.bias: gradient norm = 0.181676 +dense_feature_encoder.down_blocks.0.feat_res_block1.bn2.weight: gradient norm = 30390.326172 +dense_feature_encoder.down_blocks.0.feat_res_block1.bn2.bias: gradient norm = 27264.349609 +dense_feature_encoder.down_blocks.0.feat_res_block1.conv2.weight: gradient norm = 1147479.625000 +dense_feature_encoder.down_blocks.0.feat_res_block1.conv2.bias: gradient norm = 37497.945312 +dense_feature_encoder.down_blocks.0.feat_res_block2.bn1.weight: gradient norm = 16371.284180 +dense_feature_encoder.down_blocks.0.feat_res_block2.bn1.bias: gradient norm = 9709.594727 +dense_feature_encoder.down_blocks.0.feat_res_block2.conv1.weight: gradient norm = 751728.500000 +dense_feature_encoder.down_blocks.0.feat_res_block2.conv1.bias: gradient norm = 0.087378 +dense_feature_encoder.down_blocks.0.feat_res_block2.bn2.weight: gradient norm = 17468.033203 +dense_feature_encoder.down_blocks.0.feat_res_block2.bn2.bias: gradient norm = 14630.524414 +dense_feature_encoder.down_blocks.0.feat_res_block2.conv2.weight: gradient norm = 659186.937500 +dense_feature_encoder.down_blocks.0.feat_res_block2.conv2.bias: gradient norm = 21049.035156 +dense_feature_encoder.down_blocks.1.conv1.weight: gradient norm = 726457.062500 +dense_feature_encoder.down_blocks.1.conv1.bias: gradient norm = 0.203926 +dense_feature_encoder.down_blocks.1.bn1.weight: gradient norm = 15529.459961 +dense_feature_encoder.down_blocks.1.bn1.bias: gradient norm = 13290.531250 +dense_feature_encoder.down_blocks.1.conv2.weight: gradient norm = 884565.125000 +dense_feature_encoder.down_blocks.1.conv2.bias: gradient norm = 15043.881836 +dense_feature_encoder.down_blocks.1.feat_res_block1.bn1.weight: gradient norm = 9918.655273 +dense_feature_encoder.down_blocks.1.feat_res_block1.bn1.bias: gradient norm = 8931.614258 +dense_feature_encoder.down_blocks.1.feat_res_block1.conv1.weight: gradient norm = 674186.562500 +dense_feature_encoder.down_blocks.1.feat_res_block1.conv1.bias: gradient norm = 0.026552 +dense_feature_encoder.down_blocks.1.feat_res_block1.bn2.weight: gradient norm = 10514.752930 +dense_feature_encoder.down_blocks.1.feat_res_block1.bn2.bias: gradient norm = 9401.004883 +dense_feature_encoder.down_blocks.1.feat_res_block1.conv2.weight: gradient norm = 592075.250000 +dense_feature_encoder.down_blocks.1.feat_res_block1.conv2.bias: gradient norm = 15043.884766 +dense_feature_encoder.down_blocks.1.feat_res_block2.bn1.weight: gradient norm = 6842.926270 +dense_feature_encoder.down_blocks.1.feat_res_block2.bn1.bias: gradient norm = 4114.761230 +dense_feature_encoder.down_blocks.1.feat_res_block2.conv1.weight: gradient norm = 390124.812500 +dense_feature_encoder.down_blocks.1.feat_res_block2.conv1.bias: gradient norm = 0.013663 +dense_feature_encoder.down_blocks.1.feat_res_block2.bn2.weight: gradient norm = 5605.596191 +dense_feature_encoder.down_blocks.1.feat_res_block2.bn2.bias: gradient norm = 5269.239746 +dense_feature_encoder.down_blocks.1.feat_res_block2.conv2.weight: gradient norm = 344331.781250 +dense_feature_encoder.down_blocks.1.feat_res_block2.conv2.bias: gradient norm = 8606.774414 +dense_feature_encoder.down_blocks.2.conv1.weight: gradient norm = 392616.437500 +dense_feature_encoder.down_blocks.2.conv1.bias: gradient norm = 0.043205 +dense_feature_encoder.down_blocks.2.bn1.weight: gradient norm = 5644.379395 
+dense_feature_encoder.down_blocks.2.bn1.bias: gradient norm = 5424.253418 +dense_feature_encoder.down_blocks.2.conv2.weight: gradient norm = 482836.343750 +dense_feature_encoder.down_blocks.2.conv2.bias: gradient norm = 5319.142090 +dense_feature_encoder.down_blocks.2.feat_res_block1.bn1.weight: gradient norm = 4363.792969 +dense_feature_encoder.down_blocks.2.feat_res_block1.bn1.bias: gradient norm = 3858.229736 +dense_feature_encoder.down_blocks.2.feat_res_block1.conv1.weight: gradient norm = 369035.843750 +dense_feature_encoder.down_blocks.2.feat_res_block1.conv1.bias: gradient norm = 0.005482 +dense_feature_encoder.down_blocks.2.feat_res_block1.bn2.weight: gradient norm = 3692.497070 +dense_feature_encoder.down_blocks.2.feat_res_block1.bn2.bias: gradient norm = 3563.835693 +dense_feature_encoder.down_blocks.2.feat_res_block1.conv2.weight: gradient norm = 321398.812500 +dense_feature_encoder.down_blocks.2.feat_res_block1.conv2.bias: gradient norm = 5319.141602 +dense_feature_encoder.down_blocks.2.feat_res_block2.bn1.weight: gradient norm = 2592.855713 +dense_feature_encoder.down_blocks.2.feat_res_block2.bn1.bias: gradient norm = 1626.711792 +dense_feature_encoder.down_blocks.2.feat_res_block2.conv1.weight: gradient norm = 218663.015625 +dense_feature_encoder.down_blocks.2.feat_res_block2.conv1.bias: gradient norm = 0.002826 +dense_feature_encoder.down_blocks.2.feat_res_block2.bn2.weight: gradient norm = 2311.634277 +dense_feature_encoder.down_blocks.2.feat_res_block2.bn2.bias: gradient norm = 2285.262939 +dense_feature_encoder.down_blocks.2.feat_res_block2.conv2.weight: gradient norm = 194630.562500 +dense_feature_encoder.down_blocks.2.feat_res_block2.conv2.bias: gradient norm = 3578.514160 +dense_feature_encoder.down_blocks.3.conv1.weight: gradient norm = 216363.968750 +dense_feature_encoder.down_blocks.3.conv1.bias: gradient norm = 0.007331 +dense_feature_encoder.down_blocks.3.bn1.weight: gradient norm = 2331.440186 +dense_feature_encoder.down_blocks.3.bn1.bias: gradient norm = 2060.851807 +dense_feature_encoder.down_blocks.3.conv2.weight: gradient norm = 278709.406250 +dense_feature_encoder.down_blocks.3.conv2.bias: gradient norm = 2808.988525 +dense_feature_encoder.down_blocks.3.feat_res_block1.bn1.weight: gradient norm = 1753.826050 +dense_feature_encoder.down_blocks.3.feat_res_block1.bn1.bias: gradient norm = 1576.393188 +dense_feature_encoder.down_blocks.3.feat_res_block1.conv1.weight: gradient norm = 208968.453125 +dense_feature_encoder.down_blocks.3.feat_res_block1.conv1.bias: gradient norm = 0.001112 +dense_feature_encoder.down_blocks.3.feat_res_block1.bn2.weight: gradient norm = 1643.449463 +dense_feature_encoder.down_blocks.3.feat_res_block1.bn2.bias: gradient norm = 1535.514404 +dense_feature_encoder.down_blocks.3.feat_res_block1.conv2.weight: gradient norm = 188005.406250 +dense_feature_encoder.down_blocks.3.feat_res_block1.conv2.bias: gradient norm = 2808.988525 +dense_feature_encoder.down_blocks.3.feat_res_block2.bn1.weight: gradient norm = 1016.050171 +dense_feature_encoder.down_blocks.3.feat_res_block2.bn1.bias: gradient norm = 594.455688 +dense_feature_encoder.down_blocks.3.feat_res_block2.conv1.weight: gradient norm = 123049.718750 +dense_feature_encoder.down_blocks.3.feat_res_block2.conv1.bias: gradient norm = 0.000549 +dense_feature_encoder.down_blocks.3.feat_res_block2.bn2.weight: gradient norm = 1030.693726 +dense_feature_encoder.down_blocks.3.feat_res_block2.bn2.bias: gradient norm = 1057.232300 +dense_feature_encoder.down_blocks.3.feat_res_block2.conv2.weight: 
gradient norm = 124434.140625 +dense_feature_encoder.down_blocks.3.feat_res_block2.conv2.bias: gradient norm = 2772.211426 +dense_feature_encoder.down_blocks.4.conv1.weight: gradient norm = 17605.542969 +dense_feature_encoder.down_blocks.4.conv1.bias: gradient norm = 0.000225 +dense_feature_encoder.down_blocks.4.bn1.weight: gradient norm = 186.101883 +dense_feature_encoder.down_blocks.4.bn1.bias: gradient norm = 206.837646 +dense_feature_encoder.down_blocks.4.conv2.weight: gradient norm = 22415.679688 +dense_feature_encoder.down_blocks.4.conv2.bias: gradient norm = 683.605408 +dense_feature_encoder.down_blocks.4.feat_res_block1.bn1.weight: gradient norm = 102.955162 +dense_feature_encoder.down_blocks.4.feat_res_block1.bn1.bias: gradient norm = 85.672539 +dense_feature_encoder.down_blocks.4.feat_res_block1.conv1.weight: gradient norm = 12537.026367 +dense_feature_encoder.down_blocks.4.feat_res_block1.conv1.bias: gradient norm = 0.000046 +dense_feature_encoder.down_blocks.4.feat_res_block1.bn2.weight: gradient norm = 170.146027 +dense_feature_encoder.down_blocks.4.feat_res_block1.bn2.bias: gradient norm = 184.894623 +dense_feature_encoder.down_blocks.4.feat_res_block1.conv2.weight: gradient norm = 19280.949219 +dense_feature_encoder.down_blocks.4.feat_res_block1.conv2.bias: gradient norm = 683.605408 +dense_feature_encoder.down_blocks.4.feat_res_block2.bn1.weight: gradient norm = 64.609283 +dense_feature_encoder.down_blocks.4.feat_res_block2.bn1.bias: gradient norm = 38.744553 +dense_feature_encoder.down_blocks.4.feat_res_block2.conv1.weight: gradient norm = 7630.462402 +dense_feature_encoder.down_blocks.4.feat_res_block2.conv1.bias: gradient norm = 0.000031 +dense_feature_encoder.down_blocks.4.feat_res_block2.bn2.weight: gradient norm = 222.754822 +dense_feature_encoder.down_blocks.4.feat_res_block2.bn2.bias: gradient norm = 271.142609 +dense_feature_encoder.down_blocks.4.feat_res_block2.conv2.weight: gradient norm = 24712.933594 +dense_feature_encoder.down_blocks.4.feat_res_block2.conv2.bias: gradient norm = 1020.370422 +implicit_motion_alignment.0.cross_attention.to_q.weight: gradient norm = 7172.759277 +implicit_motion_alignment.0.cross_attention.to_q.bias: gradient norm = 299.454865 +implicit_motion_alignment.0.cross_attention.to_k.weight: gradient norm = 8212.823242 +implicit_motion_alignment.0.cross_attention.to_k.bias: gradient norm = 0.000764 +implicit_motion_alignment.0.cross_attention.to_v.weight: gradient norm = 5681.493652 +implicit_motion_alignment.0.cross_attention.to_v.bias: gradient norm = 866.093262 +implicit_motion_alignment.0.cross_attention.to_out.weight: gradient norm = 6450.944336 +implicit_motion_alignment.0.cross_attention.to_out.bias: gradient norm = 1444.495117 +implicit_motion_alignment.0.blocks.0.attention.in_proj_weight: gradient norm = 3082.971436 +implicit_motion_alignment.0.blocks.0.attention.in_proj_bias: gradient norm = 272.653687 +implicit_motion_alignment.0.blocks.0.attention.out_proj.weight: gradient norm = 3865.456055 +implicit_motion_alignment.0.blocks.0.attention.out_proj.bias: gradient norm = 452.623993 +implicit_motion_alignment.0.blocks.0.mlp.0.weight: gradient norm = 2978.205566 +implicit_motion_alignment.0.blocks.0.mlp.0.bias: gradient norm = 174.592377 +implicit_motion_alignment.0.blocks.0.mlp.2.weight: gradient norm = 5632.219238 +implicit_motion_alignment.0.blocks.0.mlp.2.bias: gradient norm = 393.436218 +implicit_motion_alignment.0.blocks.0.norm1.weight: gradient norm = 204.647385 +implicit_motion_alignment.0.blocks.0.norm1.bias: gradient 
norm = 194.814240 +implicit_motion_alignment.0.blocks.0.norm2.weight: gradient norm = 149.777344 +implicit_motion_alignment.0.blocks.0.norm2.bias: gradient norm = 103.171776 +implicit_motion_alignment.0.blocks.1.attention.in_proj_weight: gradient norm = 1947.355347 +implicit_motion_alignment.0.blocks.1.attention.in_proj_bias: gradient norm = 172.135742 +implicit_motion_alignment.0.blocks.1.attention.out_proj.weight: gradient norm = 2403.857666 +implicit_motion_alignment.0.blocks.1.attention.out_proj.bias: gradient norm = 315.627258 +implicit_motion_alignment.0.blocks.1.mlp.0.weight: gradient norm = 2625.773682 +implicit_motion_alignment.0.blocks.1.mlp.0.bias: gradient norm = 147.157471 +implicit_motion_alignment.0.blocks.1.mlp.2.weight: gradient norm = 4314.665527 +implicit_motion_alignment.0.blocks.1.mlp.2.bias: gradient norm = 296.901855 +implicit_motion_alignment.0.blocks.1.norm1.weight: gradient norm = 113.220306 +implicit_motion_alignment.0.blocks.1.norm1.bias: gradient norm = 107.932404 +implicit_motion_alignment.0.blocks.1.norm2.weight: gradient norm = 130.719666 +implicit_motion_alignment.0.blocks.1.norm2.bias: gradient norm = 88.862740 +implicit_motion_alignment.0.blocks.2.attention.in_proj_weight: gradient norm = 1651.054443 +implicit_motion_alignment.0.blocks.2.attention.in_proj_bias: gradient norm = 145.939697 +implicit_motion_alignment.0.blocks.2.attention.out_proj.weight: gradient norm = 1908.379150 +implicit_motion_alignment.0.blocks.2.attention.out_proj.bias: gradient norm = 262.678467 +implicit_motion_alignment.0.blocks.2.mlp.0.weight: gradient norm = 1903.335938 +implicit_motion_alignment.0.blocks.2.mlp.0.bias: gradient norm = 111.045517 +implicit_motion_alignment.0.blocks.2.mlp.2.weight: gradient norm = 3564.890137 +implicit_motion_alignment.0.blocks.2.mlp.2.bias: gradient norm = 232.527222 +implicit_motion_alignment.0.blocks.2.norm1.weight: gradient norm = 86.301308 +implicit_motion_alignment.0.blocks.2.norm1.bias: gradient norm = 95.616722 +implicit_motion_alignment.0.blocks.2.norm2.weight: gradient norm = 93.158737 +implicit_motion_alignment.0.blocks.2.norm2.bias: gradient norm = 68.619881 +implicit_motion_alignment.0.blocks.3.attention.in_proj_weight: gradient norm = 1267.642090 +implicit_motion_alignment.0.blocks.3.attention.in_proj_bias: gradient norm = 112.047752 +implicit_motion_alignment.0.blocks.3.attention.out_proj.weight: gradient norm = 1529.242432 +implicit_motion_alignment.0.blocks.3.attention.out_proj.bias: gradient norm = 202.893173 +implicit_motion_alignment.0.blocks.3.mlp.0.weight: gradient norm = 1553.427856 +implicit_motion_alignment.0.blocks.3.mlp.0.bias: gradient norm = 86.845612 +implicit_motion_alignment.0.blocks.3.mlp.2.weight: gradient norm = 2833.397217 +implicit_motion_alignment.0.blocks.3.mlp.2.bias: gradient norm = 182.641647 +implicit_motion_alignment.0.blocks.3.norm1.weight: gradient norm = 79.112938 +implicit_motion_alignment.0.blocks.3.norm1.bias: gradient norm = 76.958519 +implicit_motion_alignment.0.blocks.3.norm2.weight: gradient norm = 81.617523 +implicit_motion_alignment.0.blocks.3.norm2.bias: gradient norm = 57.876919 +implicit_motion_alignment.1.cross_attention.to_q.weight: gradient norm = 9160.849609 +implicit_motion_alignment.1.cross_attention.to_q.bias: gradient norm = 147.528275 +implicit_motion_alignment.1.cross_attention.to_k.weight: gradient norm = 8823.687500 +implicit_motion_alignment.1.cross_attention.to_k.bias: gradient norm = 0.003274 +implicit_motion_alignment.1.cross_attention.to_v.weight: gradient norm = 
15215.882812 +implicit_motion_alignment.1.cross_attention.to_v.bias: gradient norm = 2034.159180 +implicit_motion_alignment.1.cross_attention.to_out.weight: gradient norm = 14636.551758 +implicit_motion_alignment.1.cross_attention.to_out.bias: gradient norm = 3487.177490 +implicit_motion_alignment.1.blocks.0.attention.in_proj_weight: gradient norm = 8962.036133 +implicit_motion_alignment.1.blocks.0.attention.in_proj_bias: gradient norm = 560.366089 +implicit_motion_alignment.1.blocks.0.attention.out_proj.weight: gradient norm = 10652.517578 +implicit_motion_alignment.1.blocks.0.attention.out_proj.bias: gradient norm = 924.748474 +implicit_motion_alignment.1.blocks.0.mlp.0.weight: gradient norm = 6645.636230 +implicit_motion_alignment.1.blocks.0.mlp.0.bias: gradient norm = 330.510010 +implicit_motion_alignment.1.blocks.0.mlp.2.weight: gradient norm = 12979.674805 +implicit_motion_alignment.1.blocks.0.mlp.2.bias: gradient norm = 867.364197 +implicit_motion_alignment.1.blocks.0.norm1.weight: gradient norm = 359.250061 +implicit_motion_alignment.1.blocks.0.norm1.bias: gradient norm = 384.146454 +implicit_motion_alignment.1.blocks.0.norm2.weight: gradient norm = 242.397125 +implicit_motion_alignment.1.blocks.0.norm2.bias: gradient norm = 199.539566 +implicit_motion_alignment.1.blocks.1.attention.in_proj_weight: gradient norm = 6216.146973 +implicit_motion_alignment.1.blocks.1.attention.in_proj_bias: gradient norm = 388.518127 +implicit_motion_alignment.1.blocks.1.attention.out_proj.weight: gradient norm = 7833.398926 +implicit_motion_alignment.1.blocks.1.attention.out_proj.bias: gradient norm = 672.448364 +implicit_motion_alignment.1.blocks.1.mlp.0.weight: gradient norm = 4997.746582 +implicit_motion_alignment.1.blocks.1.mlp.0.bias: gradient norm = 247.142502 +implicit_motion_alignment.1.blocks.1.mlp.2.weight: gradient norm = 9446.717773 +implicit_motion_alignment.1.blocks.1.mlp.2.bias: gradient norm = 613.891663 +implicit_motion_alignment.1.blocks.1.norm1.weight: gradient norm = 258.326782 +implicit_motion_alignment.1.blocks.1.norm1.bias: gradient norm = 258.351593 +implicit_motion_alignment.1.blocks.1.norm2.weight: gradient norm = 182.523193 +implicit_motion_alignment.1.blocks.1.norm2.bias: gradient norm = 141.543350 +implicit_motion_alignment.1.blocks.2.attention.in_proj_weight: gradient norm = 5159.766113 +implicit_motion_alignment.1.blocks.2.attention.in_proj_bias: gradient norm = 322.489685 +implicit_motion_alignment.1.blocks.2.attention.out_proj.weight: gradient norm = 5649.913086 +implicit_motion_alignment.1.blocks.2.attention.out_proj.bias: gradient norm = 525.410339 +implicit_motion_alignment.1.blocks.2.mlp.0.weight: gradient norm = 3930.154785 +implicit_motion_alignment.1.blocks.2.mlp.0.bias: gradient norm = 195.256027 +implicit_motion_alignment.1.blocks.2.mlp.2.weight: gradient norm = 7043.860840 +implicit_motion_alignment.1.blocks.2.mlp.2.bias: gradient norm = 493.602417 +implicit_motion_alignment.1.blocks.2.norm1.weight: gradient norm = 231.309479 +implicit_motion_alignment.1.blocks.2.norm1.bias: gradient norm = 234.348450 +implicit_motion_alignment.1.blocks.2.norm2.weight: gradient norm = 163.887009 +implicit_motion_alignment.1.blocks.2.norm2.bias: gradient norm = 117.587914 +implicit_motion_alignment.1.blocks.3.attention.in_proj_weight: gradient norm = 4141.009277 +implicit_motion_alignment.1.blocks.3.attention.in_proj_bias: gradient norm = 258.815552 +implicit_motion_alignment.1.blocks.3.attention.out_proj.weight: gradient norm = 5263.742676 
+implicit_motion_alignment.1.blocks.3.attention.out_proj.bias: gradient norm = 444.193542 +implicit_motion_alignment.1.blocks.3.mlp.0.weight: gradient norm = 3443.606934 +implicit_motion_alignment.1.blocks.3.mlp.0.bias: gradient norm = 171.457214 +implicit_motion_alignment.1.blocks.3.mlp.2.weight: gradient norm = 6313.919434 +implicit_motion_alignment.1.blocks.3.mlp.2.bias: gradient norm = 436.431549 +implicit_motion_alignment.1.blocks.3.norm1.weight: gradient norm = 189.105286 +implicit_motion_alignment.1.blocks.3.norm1.bias: gradient norm = 187.003067 +implicit_motion_alignment.1.blocks.3.norm2.weight: gradient norm = 130.518982 +implicit_motion_alignment.1.blocks.3.norm2.bias: gradient norm = 98.740410 +implicit_motion_alignment.2.cross_attention.to_q.weight: gradient norm = 57918.445312 +implicit_motion_alignment.2.cross_attention.to_q.bias: gradient norm = 616.383240 +implicit_motion_alignment.2.cross_attention.to_k.weight: gradient norm = 56672.425781 +implicit_motion_alignment.2.cross_attention.to_k.bias: gradient norm = 0.001941 +implicit_motion_alignment.2.cross_attention.to_v.weight: gradient norm = 65670.023438 +implicit_motion_alignment.2.cross_attention.to_v.bias: gradient norm = 6943.051270 +implicit_motion_alignment.2.cross_attention.to_out.weight: gradient norm = 68886.468750 +implicit_motion_alignment.2.cross_attention.to_out.bias: gradient norm = 12067.115234 +implicit_motion_alignment.2.blocks.0.attention.in_proj_weight: gradient norm = 44729.753906 +implicit_motion_alignment.2.blocks.0.attention.in_proj_bias: gradient norm = 1977.625977 +implicit_motion_alignment.2.blocks.0.attention.out_proj.weight: gradient norm = 55193.621094 +implicit_motion_alignment.2.blocks.0.attention.out_proj.bias: gradient norm = 3414.075439 +implicit_motion_alignment.2.blocks.0.mlp.0.weight: gradient norm = 31777.855469 +implicit_motion_alignment.2.blocks.0.mlp.0.bias: gradient norm = 1163.918091 +implicit_motion_alignment.2.blocks.0.mlp.2.weight: gradient norm = 56904.546875 +implicit_motion_alignment.2.blocks.0.mlp.2.bias: gradient norm = 3038.209229 +implicit_motion_alignment.2.blocks.0.norm1.weight: gradient norm = 1271.711304 +implicit_motion_alignment.2.blocks.0.norm1.bias: gradient norm = 1395.896851 +implicit_motion_alignment.2.blocks.0.norm2.weight: gradient norm = 771.165527 +implicit_motion_alignment.2.blocks.0.norm2.bias: gradient norm = 640.822021 +implicit_motion_alignment.2.blocks.1.attention.in_proj_weight: gradient norm = 30322.958984 +implicit_motion_alignment.2.blocks.1.attention.in_proj_bias: gradient norm = 1340.140869 +implicit_motion_alignment.2.blocks.1.attention.out_proj.weight: gradient norm = 37086.679688 +implicit_motion_alignment.2.blocks.1.attention.out_proj.bias: gradient norm = 2287.069336 +implicit_motion_alignment.2.blocks.1.mlp.0.weight: gradient norm = 22606.705078 +implicit_motion_alignment.2.blocks.1.mlp.0.bias: gradient norm = 843.317383 +implicit_motion_alignment.2.blocks.1.mlp.2.weight: gradient norm = 41132.113281 +implicit_motion_alignment.2.blocks.1.mlp.2.bias: gradient norm = 2140.683594 +implicit_motion_alignment.2.blocks.1.norm1.weight: gradient norm = 971.328491 +implicit_motion_alignment.2.blocks.1.norm1.bias: gradient norm = 954.319824 +implicit_motion_alignment.2.blocks.1.norm2.weight: gradient norm = 622.454407 +implicit_motion_alignment.2.blocks.1.norm2.bias: gradient norm = 487.610809 +implicit_motion_alignment.2.blocks.2.attention.in_proj_weight: gradient norm = 23788.753906 +implicit_motion_alignment.2.blocks.2.attention.in_proj_bias: 
gradient norm = 1051.346802 +implicit_motion_alignment.2.blocks.2.attention.out_proj.weight: gradient norm = 28509.636719 +implicit_motion_alignment.2.blocks.2.attention.out_proj.bias: gradient norm = 1809.996948 +implicit_motion_alignment.2.blocks.2.mlp.0.weight: gradient norm = 18232.632812 +implicit_motion_alignment.2.blocks.2.mlp.0.bias: gradient norm = 672.084900 +implicit_motion_alignment.2.blocks.2.mlp.2.weight: gradient norm = 33210.898438 +implicit_motion_alignment.2.blocks.2.mlp.2.bias: gradient norm = 1748.233521 +implicit_motion_alignment.2.blocks.2.norm1.weight: gradient norm = 739.391113 +implicit_motion_alignment.2.blocks.2.norm1.bias: gradient norm = 730.960815 +implicit_motion_alignment.2.blocks.2.norm2.weight: gradient norm = 455.869904 +implicit_motion_alignment.2.blocks.2.norm2.bias: gradient norm = 376.657715 +implicit_motion_alignment.2.blocks.3.attention.in_proj_weight: gradient norm = 20037.000000 +implicit_motion_alignment.2.blocks.3.attention.in_proj_bias: gradient norm = 885.533203 +implicit_motion_alignment.2.blocks.3.attention.out_proj.weight: gradient norm = 23792.023438 +implicit_motion_alignment.2.blocks.3.attention.out_proj.bias: gradient norm = 1512.649170 +implicit_motion_alignment.2.blocks.3.mlp.0.weight: gradient norm = 15432.144531 +implicit_motion_alignment.2.blocks.3.mlp.0.bias: gradient norm = 558.646545 +implicit_motion_alignment.2.blocks.3.mlp.2.weight: gradient norm = 28600.699219 +implicit_motion_alignment.2.blocks.3.mlp.2.bias: gradient norm = 1457.054077 +implicit_motion_alignment.2.blocks.3.norm1.weight: gradient norm = 642.360168 +implicit_motion_alignment.2.blocks.3.norm1.bias: gradient norm = 619.804382 +implicit_motion_alignment.2.blocks.3.norm2.weight: gradient norm = 409.155731 +implicit_motion_alignment.2.blocks.3.norm2.bias: gradient norm = 322.425232 +implicit_motion_alignment.3.cross_attention.to_q.weight: gradient norm = 1614.265991 +implicit_motion_alignment.3.cross_attention.to_q.bias: gradient norm = 37.667973 +implicit_motion_alignment.3.cross_attention.to_k.weight: gradient norm = 1529.219116 +implicit_motion_alignment.3.cross_attention.to_k.bias: gradient norm = 0.000025 +implicit_motion_alignment.3.cross_attention.to_v.weight: gradient norm = 20481.875000 +implicit_motion_alignment.3.cross_attention.to_v.bias: gradient norm = 2778.997559 +implicit_motion_alignment.3.cross_attention.to_out.weight: gradient norm = 20950.460938 +implicit_motion_alignment.3.cross_attention.to_out.bias: gradient norm = 4880.413086 +implicit_motion_alignment.3.blocks.0.attention.in_proj_weight: gradient norm = 16689.347656 +implicit_motion_alignment.3.blocks.0.attention.in_proj_bias: gradient norm = 739.324341 +implicit_motion_alignment.3.blocks.0.attention.out_proj.weight: gradient norm = 20577.023438 +implicit_motion_alignment.3.blocks.0.attention.out_proj.bias: gradient norm = 1304.239380 +implicit_motion_alignment.3.blocks.0.mlp.0.weight: gradient norm = 9882.261719 +implicit_motion_alignment.3.blocks.0.mlp.0.bias: gradient norm = 435.892151 +implicit_motion_alignment.3.blocks.0.mlp.2.weight: gradient norm = 19033.810547 +implicit_motion_alignment.3.blocks.0.mlp.2.bias: gradient norm = 1189.909180 +implicit_motion_alignment.3.blocks.0.norm1.weight: gradient norm = 513.257996 +implicit_motion_alignment.3.blocks.0.norm1.bias: gradient norm = 510.524811 +implicit_motion_alignment.3.blocks.0.norm2.weight: gradient norm = 283.421783 +implicit_motion_alignment.3.blocks.0.norm2.bias: gradient norm = 256.920807 
+implicit_motion_alignment.3.blocks.1.attention.in_proj_weight: gradient norm = 11401.266602 +implicit_motion_alignment.3.blocks.1.attention.in_proj_bias: gradient norm = 503.946625 +implicit_motion_alignment.3.blocks.1.attention.out_proj.weight: gradient norm = 14275.846680 +implicit_motion_alignment.3.blocks.1.attention.out_proj.bias: gradient norm = 878.419495 +implicit_motion_alignment.3.blocks.1.mlp.0.weight: gradient norm = 6948.125000 +implicit_motion_alignment.3.blocks.1.mlp.0.bias: gradient norm = 306.458954 +implicit_motion_alignment.3.blocks.1.mlp.2.weight: gradient norm = 13208.842773 +implicit_motion_alignment.3.blocks.1.mlp.2.bias: gradient norm = 846.214172 +implicit_motion_alignment.3.blocks.1.norm1.weight: gradient norm = 374.208740 +implicit_motion_alignment.3.blocks.1.norm1.bias: gradient norm = 352.549347 +implicit_motion_alignment.3.blocks.1.norm2.weight: gradient norm = 181.769089 +implicit_motion_alignment.3.blocks.1.norm2.bias: gradient norm = 180.216583 +implicit_motion_alignment.3.blocks.2.attention.in_proj_weight: gradient norm = 9266.676758 +implicit_motion_alignment.3.blocks.2.attention.in_proj_bias: gradient norm = 409.569305 +implicit_motion_alignment.3.blocks.2.attention.out_proj.weight: gradient norm = 11665.613281 +implicit_motion_alignment.3.blocks.2.attention.out_proj.bias: gradient norm = 741.341492 +implicit_motion_alignment.3.blocks.2.mlp.0.weight: gradient norm = 5668.476562 +implicit_motion_alignment.3.blocks.2.mlp.0.bias: gradient norm = 249.987656 +implicit_motion_alignment.3.blocks.2.mlp.2.weight: gradient norm = 11216.748047 +implicit_motion_alignment.3.blocks.2.mlp.2.bias: gradient norm = 721.428223 +implicit_motion_alignment.3.blocks.2.norm1.weight: gradient norm = 284.530792 +implicit_motion_alignment.3.blocks.2.norm1.bias: gradient norm = 281.622528 +implicit_motion_alignment.3.blocks.2.norm2.weight: gradient norm = 137.971054 +implicit_motion_alignment.3.blocks.2.norm2.bias: gradient norm = 138.916519 +implicit_motion_alignment.3.blocks.3.attention.in_proj_weight: gradient norm = 8320.942383 +implicit_motion_alignment.3.blocks.3.attention.in_proj_bias: gradient norm = 367.763306 +implicit_motion_alignment.3.blocks.3.attention.out_proj.weight: gradient norm = 9882.198242 +implicit_motion_alignment.3.blocks.3.attention.out_proj.bias: gradient norm = 627.369751 +implicit_motion_alignment.3.blocks.3.mlp.0.weight: gradient norm = 5058.784180 +implicit_motion_alignment.3.blocks.3.mlp.0.bias: gradient norm = 223.123825 +implicit_motion_alignment.3.blocks.3.mlp.2.weight: gradient norm = 9312.669922 +implicit_motion_alignment.3.blocks.3.mlp.2.bias: gradient norm = 612.114624 +implicit_motion_alignment.3.blocks.3.norm1.weight: gradient norm = 274.168304 +implicit_motion_alignment.3.blocks.3.norm1.bias: gradient norm = 261.761627 +implicit_motion_alignment.3.blocks.3.norm2.weight: gradient norm = 128.616287 +implicit_motion_alignment.3.blocks.3.norm2.bias: gradient norm = 126.300209 +frame_decoder.upconv_blocks.0.conv1.weight: gradient norm = 60592.308594 +frame_decoder.upconv_blocks.0.conv1.bias: gradient norm = 0.001808 +frame_decoder.upconv_blocks.0.bn1.weight: gradient norm = 363.514557 +frame_decoder.upconv_blocks.0.bn1.bias: gradient norm = 327.577881 +frame_decoder.upconv_blocks.0.conv2.weight: gradient norm = 45056.128906 +frame_decoder.upconv_blocks.0.conv2.bias: gradient norm = 411.357971 +frame_decoder.upconv_blocks.0.feat_res_block1.bn1.weight: gradient norm = 267.631134 +frame_decoder.upconv_blocks.0.feat_res_block1.bn1.bias: gradient 
norm = 229.369659 +frame_decoder.upconv_blocks.0.feat_res_block1.conv1.weight: gradient norm = 32624.078125 +frame_decoder.upconv_blocks.0.feat_res_block1.conv1.bias: gradient norm = 0.000216 +frame_decoder.upconv_blocks.0.feat_res_block1.bn2.weight: gradient norm = 234.035675 +frame_decoder.upconv_blocks.0.feat_res_block1.bn2.bias: gradient norm = 237.045746 +frame_decoder.upconv_blocks.0.feat_res_block1.conv2.weight: gradient norm = 28679.394531 +frame_decoder.upconv_blocks.0.feat_res_block1.conv2.bias: gradient norm = 411.357941 +frame_decoder.upconv_blocks.0.feat_res_block2.bn1.weight: gradient norm = 172.110153 +frame_decoder.upconv_blocks.0.feat_res_block2.bn1.bias: gradient norm = 102.468239 +frame_decoder.upconv_blocks.0.feat_res_block2.conv1.weight: gradient norm = 20867.496094 +frame_decoder.upconv_blocks.0.feat_res_block2.conv1.bias: gradient norm = 0.000092 +frame_decoder.upconv_blocks.0.feat_res_block2.bn2.weight: gradient norm = 156.510025 +frame_decoder.upconv_blocks.0.feat_res_block2.bn2.bias: gradient norm = 152.596313 +frame_decoder.upconv_blocks.0.feat_res_block2.conv2.weight: gradient norm = 18507.460938 +frame_decoder.upconv_blocks.0.feat_res_block2.conv2.bias: gradient norm = 258.065430 +frame_decoder.upconv_blocks.1.conv1.weight: gradient norm = 52354.042969 +frame_decoder.upconv_blocks.1.conv1.bias: gradient norm = 0.000584 +frame_decoder.upconv_blocks.1.bn1.weight: gradient norm = 270.254639 +frame_decoder.upconv_blocks.1.bn1.bias: gradient norm = 249.904663 +frame_decoder.upconv_blocks.1.conv2.weight: gradient norm = 30800.742188 +frame_decoder.upconv_blocks.1.conv2.bias: gradient norm = 221.023911 +frame_decoder.upconv_blocks.1.feat_res_block1.bn1.weight: gradient norm = 179.602097 +frame_decoder.upconv_blocks.1.feat_res_block1.bn1.bias: gradient norm = 154.816086 +frame_decoder.upconv_blocks.1.feat_res_block1.conv1.weight: gradient norm = 20695.675781 +frame_decoder.upconv_blocks.1.feat_res_block1.conv1.bias: gradient norm = 0.000206 +frame_decoder.upconv_blocks.1.feat_res_block1.bn2.weight: gradient norm = 155.480011 +frame_decoder.upconv_blocks.1.feat_res_block1.bn2.bias: gradient norm = 139.650452 +frame_decoder.upconv_blocks.1.feat_res_block1.conv2.weight: gradient norm = 18071.796875 +frame_decoder.upconv_blocks.1.feat_res_block1.conv2.bias: gradient norm = 221.023926 +frame_decoder.upconv_blocks.1.feat_res_block2.bn1.weight: gradient norm = 120.884003 +frame_decoder.upconv_blocks.1.feat_res_block2.bn1.bias: gradient norm = 67.898956 +frame_decoder.upconv_blocks.1.feat_res_block2.conv1.weight: gradient norm = 14481.229492 +frame_decoder.upconv_blocks.1.feat_res_block2.conv1.bias: gradient norm = 0.000118 +frame_decoder.upconv_blocks.1.feat_res_block2.bn2.weight: gradient norm = 106.740341 +frame_decoder.upconv_blocks.1.feat_res_block2.bn2.bias: gradient norm = 96.545944 +frame_decoder.upconv_blocks.1.feat_res_block2.conv2.weight: gradient norm = 12700.120117 +frame_decoder.upconv_blocks.1.feat_res_block2.conv2.bias: gradient norm = 161.285233 +frame_decoder.upconv_blocks.2.conv1.weight: gradient norm = 25584.107422 +frame_decoder.upconv_blocks.2.conv1.bias: gradient norm = 0.000641 +frame_decoder.upconv_blocks.2.bn1.weight: gradient norm = 144.211197 +frame_decoder.upconv_blocks.2.bn1.bias: gradient norm = 124.190735 +frame_decoder.upconv_blocks.2.conv2.weight: gradient norm = 12412.494141 +frame_decoder.upconv_blocks.2.conv2.bias: gradient norm = 122.391388 +frame_decoder.upconv_blocks.2.feat_res_block1.bn1.weight: gradient norm = 106.961502 
+frame_decoder.upconv_blocks.2.feat_res_block1.bn1.bias: gradient norm = 82.072807 +frame_decoder.upconv_blocks.2.feat_res_block1.conv1.weight: gradient norm = 8421.174805 +frame_decoder.upconv_blocks.2.feat_res_block1.conv1.bias: gradient norm = 0.000242 +frame_decoder.upconv_blocks.2.feat_res_block1.bn2.weight: gradient norm = 94.342499 +frame_decoder.upconv_blocks.2.feat_res_block1.bn2.bias: gradient norm = 82.140717 +frame_decoder.upconv_blocks.2.feat_res_block1.conv2.weight: gradient norm = 7411.430176 +frame_decoder.upconv_blocks.2.feat_res_block1.conv2.bias: gradient norm = 122.391388 +frame_decoder.upconv_blocks.2.feat_res_block2.bn1.weight: gradient norm = 71.750580 +frame_decoder.upconv_blocks.2.feat_res_block2.bn1.bias: gradient norm = 38.870594 +frame_decoder.upconv_blocks.2.feat_res_block2.conv1.weight: gradient norm = 6024.892090 +frame_decoder.upconv_blocks.2.feat_res_block2.conv1.bias: gradient norm = 0.000128 +frame_decoder.upconv_blocks.2.feat_res_block2.bn2.weight: gradient norm = 65.086143 +frame_decoder.upconv_blocks.2.feat_res_block2.bn2.bias: gradient norm = 55.891445 +frame_decoder.upconv_blocks.2.feat_res_block2.conv2.weight: gradient norm = 5477.744141 +frame_decoder.upconv_blocks.2.feat_res_block2.conv2.bias: gradient norm = 100.115128 +frame_decoder.upconv_blocks.3.conv1.weight: gradient norm = 13759.450195 +frame_decoder.upconv_blocks.3.conv1.bias: gradient norm = 0.000994 +frame_decoder.upconv_blocks.3.bn1.weight: gradient norm = 143.252609 +frame_decoder.upconv_blocks.3.bn1.bias: gradient norm = 107.856239 +frame_decoder.upconv_blocks.3.conv2.weight: gradient norm = 7888.992188 +frame_decoder.upconv_blocks.3.conv2.bias: gradient norm = 176.174347 +frame_decoder.upconv_blocks.3.feat_res_block1.bn1.weight: gradient norm = 114.358330 +frame_decoder.upconv_blocks.3.feat_res_block1.bn1.bias: gradient norm = 72.426872 +frame_decoder.upconv_blocks.3.feat_res_block1.conv1.weight: gradient norm = 5829.777832 +frame_decoder.upconv_blocks.3.feat_res_block1.conv1.bias: gradient norm = 0.000334 +frame_decoder.upconv_blocks.3.feat_res_block1.bn2.weight: gradient norm = 114.239845 +frame_decoder.upconv_blocks.3.feat_res_block1.bn2.bias: gradient norm = 74.872246 +frame_decoder.upconv_blocks.3.feat_res_block1.conv2.weight: gradient norm = 6110.258789 +frame_decoder.upconv_blocks.3.feat_res_block1.conv2.bias: gradient norm = 176.174438 +frame_decoder.upconv_blocks.3.feat_res_block2.bn1.weight: gradient norm = 104.959328 +frame_decoder.upconv_blocks.3.feat_res_block2.bn1.bias: gradient norm = 38.032383 +frame_decoder.upconv_blocks.3.feat_res_block2.conv1.weight: gradient norm = 5385.477051 +frame_decoder.upconv_blocks.3.feat_res_block2.conv1.bias: gradient norm = 0.000245 +frame_decoder.upconv_blocks.3.feat_res_block2.bn2.weight: gradient norm = 122.739899 +frame_decoder.upconv_blocks.3.feat_res_block2.bn2.bias: gradient norm = 61.308399 +frame_decoder.upconv_blocks.3.feat_res_block2.conv2.weight: gradient norm = 7331.392090 +frame_decoder.upconv_blocks.3.feat_res_block2.conv2.bias: gradient norm = 112.174393 +frame_decoder.upconv_blocks.4.conv1.weight: gradient norm = 14409.164062 +frame_decoder.upconv_blocks.4.conv1.bias: gradient norm = 0.004688 +frame_decoder.upconv_blocks.4.bn1.weight: gradient norm = 1713.187866 +frame_decoder.upconv_blocks.4.bn1.bias: gradient norm = 2384.170654 +frame_decoder.upconv_blocks.4.conv2.weight: gradient norm = 71547.648438 +frame_decoder.upconv_blocks.4.conv2.bias: gradient norm = 8161.481445 
+frame_decoder.upconv_blocks.4.feat_res_block1.bn1.weight: gradient norm = 196.721848 +frame_decoder.upconv_blocks.4.feat_res_block1.bn1.bias: gradient norm = 129.611633 +frame_decoder.upconv_blocks.4.feat_res_block1.conv1.weight: gradient norm = 8289.049805 +frame_decoder.upconv_blocks.4.feat_res_block1.conv1.bias: gradient norm = 0.002942 +frame_decoder.upconv_blocks.4.feat_res_block1.bn2.weight: gradient norm = 1934.109497 +frame_decoder.upconv_blocks.4.feat_res_block1.bn2.bias: gradient norm = 2687.408691 +frame_decoder.upconv_blocks.4.feat_res_block1.conv2.weight: gradient norm = 70619.046875 +frame_decoder.upconv_blocks.4.feat_res_block1.conv2.bias: gradient norm = 8161.479980 +frame_decoder.upconv_blocks.4.feat_res_block2.bn1.weight: gradient norm = 168.980469 +frame_decoder.upconv_blocks.4.feat_res_block2.bn1.bias: gradient norm = 66.042244 +frame_decoder.upconv_blocks.4.feat_res_block2.conv1.weight: gradient norm = 6858.655762 +frame_decoder.upconv_blocks.4.feat_res_block2.conv1.bias: gradient norm = 0.001895 +frame_decoder.upconv_blocks.4.feat_res_block2.bn2.weight: gradient norm = 2277.798340 +frame_decoder.upconv_blocks.4.feat_res_block2.bn2.bias: gradient norm = 3192.185303 +frame_decoder.upconv_blocks.4.feat_res_block2.conv2.weight: gradient norm = 92337.460938 +frame_decoder.upconv_blocks.4.feat_res_block2.conv2.bias: gradient norm = 10983.710938 +frame_decoder.feat_blocks.0.0.bn1.weight: gradient norm = 914.015198 +frame_decoder.feat_blocks.0.0.bn1.bias: gradient norm = 1762.393555 +frame_decoder.feat_blocks.0.0.conv1.weight: gradient norm = 112018.046875 +frame_decoder.feat_blocks.0.0.conv1.bias: gradient norm = 0.001068 +frame_decoder.feat_blocks.0.0.bn2.weight: gradient norm = 882.649109 +frame_decoder.feat_blocks.0.0.bn2.bias: gradient norm = 709.563904 +frame_decoder.feat_blocks.0.0.conv2.weight: gradient norm = 99750.523438 +frame_decoder.feat_blocks.0.0.conv2.bias: gradient norm = 1457.059082 +frame_decoder.feat_blocks.0.1.bn1.weight: gradient norm = 304.796997 +frame_decoder.feat_blocks.0.1.bn1.bias: gradient norm = 164.815933 +frame_decoder.feat_blocks.0.1.conv1.weight: gradient norm = 40088.449219 +frame_decoder.feat_blocks.0.1.conv1.bias: gradient norm = 0.000141 +frame_decoder.feat_blocks.0.1.bn2.weight: gradient norm = 287.639709 +frame_decoder.feat_blocks.0.1.bn2.bias: gradient norm = 246.042969 +frame_decoder.feat_blocks.0.1.conv2.weight: gradient norm = 34703.417969 +frame_decoder.feat_blocks.0.1.conv2.bias: gradient norm = 348.631317 +frame_decoder.feat_blocks.0.2.bn1.weight: gradient norm = 203.449966 +frame_decoder.feat_blocks.0.2.bn1.bias: gradient norm = 134.795212 +frame_decoder.feat_blocks.0.2.conv1.weight: gradient norm = 24610.287109 +frame_decoder.feat_blocks.0.2.conv1.bias: gradient norm = 0.000090 +frame_decoder.feat_blocks.0.2.bn2.weight: gradient norm = 186.012344 +frame_decoder.feat_blocks.0.2.bn2.bias: gradient norm = 160.312744 +frame_decoder.feat_blocks.0.2.conv2.weight: gradient norm = 21716.519531 +frame_decoder.feat_blocks.0.2.conv2.bias: gradient norm = 234.856018 +frame_decoder.feat_blocks.1.0.bn1.weight: gradient norm = 367.177979 +frame_decoder.feat_blocks.1.0.bn1.bias: gradient norm = 916.964172 +frame_decoder.feat_blocks.1.0.conv1.weight: gradient norm = 28024.142578 +frame_decoder.feat_blocks.1.0.conv1.bias: gradient norm = 0.001249 +frame_decoder.feat_blocks.1.0.bn2.weight: gradient norm = 296.687073 +frame_decoder.feat_blocks.1.0.bn2.bias: gradient norm = 233.259674 +frame_decoder.feat_blocks.1.0.conv2.weight: gradient norm = 
24733.443359 +frame_decoder.feat_blocks.1.0.conv2.bias: gradient norm = 436.446655 +frame_decoder.feat_blocks.1.1.bn1.weight: gradient norm = 152.548508 +frame_decoder.feat_blocks.1.1.bn1.bias: gradient norm = 72.520844 +frame_decoder.feat_blocks.1.1.conv1.weight: gradient norm = 12899.326172 +frame_decoder.feat_blocks.1.1.conv1.bias: gradient norm = 0.000132 +frame_decoder.feat_blocks.1.1.bn2.weight: gradient norm = 133.931915 +frame_decoder.feat_blocks.1.1.bn2.bias: gradient norm = 110.107811 +frame_decoder.feat_blocks.1.1.conv2.weight: gradient norm = 11448.852539 +frame_decoder.feat_blocks.1.1.conv2.bias: gradient norm = 161.139908 +frame_decoder.feat_blocks.1.2.bn1.weight: gradient norm = 96.466309 +frame_decoder.feat_blocks.1.2.bn1.bias: gradient norm = 62.471256 +frame_decoder.feat_blocks.1.2.conv1.weight: gradient norm = 8062.954102 +frame_decoder.feat_blocks.1.2.conv1.bias: gradient norm = 0.000090 +frame_decoder.feat_blocks.1.2.bn2.weight: gradient norm = 85.917229 +frame_decoder.feat_blocks.1.2.bn2.bias: gradient norm = 74.801933 +frame_decoder.feat_blocks.1.2.conv2.weight: gradient norm = 7259.808594 +frame_decoder.feat_blocks.1.2.conv2.bias: gradient norm = 98.378716 +frame_decoder.feat_blocks.2.0.bn1.weight: gradient norm = 273.806976 +frame_decoder.feat_blocks.2.0.bn1.bias: gradient norm = 114.268311 +frame_decoder.feat_blocks.2.0.conv1.weight: gradient norm = 15457.373047 +frame_decoder.feat_blocks.2.0.conv1.bias: gradient norm = 0.000337 +frame_decoder.feat_blocks.2.0.bn2.weight: gradient norm = 210.567581 +frame_decoder.feat_blocks.2.0.bn2.bias: gradient norm = 99.050667 +frame_decoder.feat_blocks.2.0.conv2.weight: gradient norm = 13873.974609 +frame_decoder.feat_blocks.2.0.conv2.bias: gradient norm = 182.663452 +frame_decoder.feat_blocks.2.1.bn1.weight: gradient norm = 91.382500 +frame_decoder.feat_blocks.2.1.bn1.bias: gradient norm = 36.706001 +frame_decoder.feat_blocks.2.1.conv1.weight: gradient norm = 5163.604492 +frame_decoder.feat_blocks.2.1.conv1.bias: gradient norm = 0.000127 +frame_decoder.feat_blocks.2.1.bn2.weight: gradient norm = 85.532173 +frame_decoder.feat_blocks.2.1.bn2.bias: gradient norm = 55.942551 +frame_decoder.feat_blocks.2.1.conv2.weight: gradient norm = 4767.778809 +frame_decoder.feat_blocks.2.1.conv2.bias: gradient norm = 87.282433 +frame_decoder.feat_blocks.2.2.bn1.weight: gradient norm = 52.966972 +frame_decoder.feat_blocks.2.2.bn1.bias: gradient norm = 31.354267 +frame_decoder.feat_blocks.2.2.conv1.weight: gradient norm = 3682.902832 +frame_decoder.feat_blocks.2.2.conv1.bias: gradient norm = 0.000082 +frame_decoder.feat_blocks.2.2.bn2.weight: gradient norm = 58.731979 +frame_decoder.feat_blocks.2.2.bn2.bias: gradient norm = 46.208664 +frame_decoder.feat_blocks.2.2.conv2.weight: gradient norm = 3915.986816 +frame_decoder.feat_blocks.2.2.conv2.bias: gradient norm = 60.335125 +frame_decoder.final_conv.0.weight: gradient norm = 219061.906250 +frame_decoder.final_conv.0.bias: gradient norm = 28033.355469 +mapping_network.net.0.weight: gradient norm = 13.773472 +mapping_network.net.0.bias: gradient norm = 1.867836 +mapping_network.net.2.weight: gradient norm = 25.532295 +mapping_network.net.2.bias: gradient norm = 5.514368 +mapping_network.net.4.weight: gradient norm = 16.174578 +mapping_network.net.4.bias: gradient norm = 9.728312 +mapping_network.net.6.weight: gradient norm = 18.081602 +mapping_network.net.6.bias: gradient norm = 25.428263 +mapping_network.net.8.weight: gradient norm = 26.132113 +mapping_network.net.8.bias: gradient norm = 
58.106331 +mapping_network.net.10.weight: gradient norm = 76.844368
+mapping_network.net.10.bias: gradient norm = 128.570923
+mapping_network.net.12.weight: gradient norm = 157.641357
+mapping_network.net.12.bias: gradient norm = 335.587311
+mapping_network.net.14.weight: gradient norm = 265.274292
+mapping_network.net.14.bias: gradient norm = 709.298401
+Gradient flow test passed for IMFModel
+
+Testing DenseFeatureEncoder
+⚾ DenseFeatureEncoder input shape: torch.Size([1, 3, 256, 256])
+ After initial conv: torch.Size([1, 64, 256, 256])
+ After down_block 1: torch.Size([1, 64, 128, 128])
+ After down_block 2: torch.Size([1, 128, 64, 64])
+ After down_block 3: torch.Size([1, 256, 32, 32])
+ After down_block 4: torch.Size([1, 512, 16, 16])
+ After down_block 5: torch.Size([1, 512, 8, 8])
+ DenseFeatureEncoder output shapes: [torch.Size([1, 128, 64, 64]), torch.Size([1, 256, 32, 32]), torch.Size([1, 512, 16, 16]), torch.Size([1, 512, 8, 8])]
+Number of feature maps: 4
+Feature map 0 shape: torch.Size([1, 128, 64, 64])
+Feature map 1 shape: torch.Size([1, 256, 32, 32])
+Feature map 2 shape: torch.Size([1, 512, 16, 16])
+Feature map 3 shape: torch.Size([1, 512, 8, 8])
+initial_conv.0.weight: gradient norm = 170580.140625
+initial_conv.0.bias: gradient norm = 0.054110
+initial_conv.1.weight: gradient norm = 6165.351074
+initial_conv.1.bias: gradient norm = 6034.975586
+down_blocks.0.conv1.weight: gradient norm = 285275.687500
+down_blocks.0.conv1.bias: gradient norm = 0.091466
+down_blocks.0.bn1.weight: gradient norm = 6036.648438
+down_blocks.0.bn1.bias: gradient norm = 4927.727051
+down_blocks.0.conv2.weight: gradient norm = 237660.609375
+down_blocks.0.conv2.bias: gradient norm = 4562.112793
+down_blocks.0.feat_res_block1.bn1.weight: gradient norm = 5337.809082
+down_blocks.0.feat_res_block1.bn1.bias: gradient norm = 4314.460938
+down_blocks.0.feat_res_block1.conv1.weight: gradient norm = 182041.171875
+down_blocks.0.feat_res_block1.conv1.bias: gradient norm = 0.019406
+down_blocks.0.feat_res_block1.bn2.weight: gradient norm = 3893.249756
+down_blocks.0.feat_res_block1.bn2.bias: gradient norm = 3130.116455
+down_blocks.0.feat_res_block1.conv2.weight: gradient norm = 159164.750000
+down_blocks.0.feat_res_block1.conv2.bias: gradient norm = 4562.115723
+down_blocks.0.feat_res_block2.bn1.weight: gradient norm = 3046.994629
+down_blocks.0.feat_res_block2.bn1.bias: gradient norm = 1648.698608
+down_blocks.0.feat_res_block2.conv1.weight: gradient norm = 109667.984375
+down_blocks.0.feat_res_block2.conv1.bias: gradient norm = 0.010872
+down_blocks.0.feat_res_block2.bn2.weight: gradient norm = 2728.734375
+down_blocks.0.feat_res_block2.bn2.bias: gradient norm = 2385.563477
+down_blocks.0.feat_res_block2.conv2.weight: gradient norm = 97057.367188
+down_blocks.0.feat_res_block2.conv2.bias: gradient norm = 3406.217041
+down_blocks.1.conv1.weight: gradient norm = 104285.085938
+down_blocks.1.conv1.bias: gradient norm = 0.029957
+down_blocks.1.bn1.weight: gradient norm = 5692.065918
+down_blocks.1.bn1.bias: gradient norm = 6687.843750
+down_blocks.1.conv2.weight: gradient norm = 297176.281250
+down_blocks.1.conv2.bias: gradient norm = 20222.232422
+down_blocks.1.feat_res_block1.bn1.weight: gradient norm = 1791.831177
+down_blocks.1.feat_res_block1.bn1.bias: gradient norm = 1511.256714
+down_blocks.1.feat_res_block1.conv1.weight: gradient norm = 97327.398438
+down_blocks.1.feat_res_block1.conv1.bias: gradient norm = 0.004952
+down_blocks.1.feat_res_block1.bn2.weight: gradient norm = 6779.489746
+down_blocks.1.feat_res_block1.bn2.bias: gradient norm = 6748.512695 +down_blocks.1.feat_res_block1.conv2.weight: gradient norm = 279915.812500 +down_blocks.1.feat_res_block1.conv2.bias: gradient norm = 20222.228516 +down_blocks.1.feat_res_block2.bn1.weight: gradient norm = 1110.107666 +down_blocks.1.feat_res_block2.bn1.bias: gradient norm = 536.505493 +down_blocks.1.feat_res_block2.conv1.weight: gradient norm = 56697.941406 +down_blocks.1.feat_res_block2.conv1.bias: gradient norm = 0.002878 +down_blocks.1.feat_res_block2.bn2.weight: gradient norm = 9267.432617 +down_blocks.1.feat_res_block2.bn2.bias: gradient norm = 9886.988281 +down_blocks.1.feat_res_block2.conv2.weight: gradient norm = 411766.656250 +down_blocks.1.feat_res_block2.conv2.bias: gradient norm = 30928.339844 +down_blocks.2.conv1.weight: gradient norm = 53153.492188 +down_blocks.2.conv1.bias: gradient norm = 0.006086 +down_blocks.2.bn1.weight: gradient norm = 2233.302979 +down_blocks.2.bn1.bias: gradient norm = 2593.960449 +down_blocks.2.conv2.weight: gradient norm = 157487.921875 +down_blocks.2.conv2.bias: gradient norm = 7832.495117 +down_blocks.2.feat_res_block1.bn1.weight: gradient norm = 591.916809 +down_blocks.2.feat_res_block1.bn1.bias: gradient norm = 532.923096 +down_blocks.2.feat_res_block1.conv1.weight: gradient norm = 49761.398438 +down_blocks.2.feat_res_block1.conv1.bias: gradient norm = 0.000891 +down_blocks.2.feat_res_block1.bn2.weight: gradient norm = 2610.686279 +down_blocks.2.feat_res_block1.bn2.bias: gradient norm = 2712.192139 +down_blocks.2.feat_res_block1.conv2.weight: gradient norm = 150819.390625 +down_blocks.2.feat_res_block1.conv2.bias: gradient norm = 7832.495117 +down_blocks.2.feat_res_block2.bn1.weight: gradient norm = 323.865509 +down_blocks.2.feat_res_block2.bn1.bias: gradient norm = 193.137772 +down_blocks.2.feat_res_block2.conv1.weight: gradient norm = 28084.486328 +down_blocks.2.feat_res_block2.conv1.bias: gradient norm = 0.000526 +down_blocks.2.feat_res_block2.bn2.weight: gradient norm = 3094.646729 +down_blocks.2.feat_res_block2.bn2.bias: gradient norm = 3292.751465 +down_blocks.2.feat_res_block2.conv2.weight: gradient norm = 203646.796875 +down_blocks.2.feat_res_block2.conv2.bias: gradient norm = 10972.582031 +down_blocks.3.conv1.weight: gradient norm = 26711.464844 +down_blocks.3.conv1.bias: gradient norm = 0.000967 +down_blocks.3.bn1.weight: gradient norm = 760.239929 +down_blocks.3.bn1.bias: gradient norm = 874.010071 +down_blocks.3.conv2.weight: gradient norm = 74343.328125 +down_blocks.3.conv2.bias: gradient norm = 2671.331055 +down_blocks.3.feat_res_block1.bn1.weight: gradient norm = 228.775681 +down_blocks.3.feat_res_block1.bn1.bias: gradient norm = 200.657623 +down_blocks.3.feat_res_block1.conv1.weight: gradient norm = 26002.552734 +down_blocks.3.feat_res_block1.conv1.bias: gradient norm = 0.000190 +down_blocks.3.feat_res_block1.bn2.weight: gradient norm = 844.289490 +down_blocks.3.feat_res_block1.bn2.bias: gradient norm = 834.922668 +down_blocks.3.feat_res_block1.conv2.weight: gradient norm = 70752.148438 +down_blocks.3.feat_res_block1.conv2.bias: gradient norm = 2671.330811 +down_blocks.3.feat_res_block2.bn1.weight: gradient norm = 135.755676 +down_blocks.3.feat_res_block2.bn1.bias: gradient norm = 82.133858 +down_blocks.3.feat_res_block2.conv1.weight: gradient norm = 15811.669922 +down_blocks.3.feat_res_block2.conv1.bias: gradient norm = 0.000133 +down_blocks.3.feat_res_block2.bn2.weight: gradient norm = 1088.709961 +down_blocks.3.feat_res_block2.bn2.bias: gradient norm = 
1144.484863 +down_blocks.3.feat_res_block2.conv2.weight: gradient norm = 99523.093750
+down_blocks.3.feat_res_block2.conv2.bias: gradient norm = 3912.964355
+down_blocks.4.conv1.weight: gradient norm = 12301.476562
+down_blocks.4.conv1.bias: gradient norm = 0.000165
+down_blocks.4.bn1.weight: gradient norm = 174.082657
+down_blocks.4.bn1.bias: gradient norm = 193.936859
+down_blocks.4.conv2.weight: gradient norm = 18991.263672
+down_blocks.4.conv2.bias: gradient norm = 662.906006
+down_blocks.4.feat_res_block1.bn1.weight: gradient norm = 76.198135
+down_blocks.4.feat_res_block1.bn1.bias: gradient norm = 63.357056
+down_blocks.4.feat_res_block1.conv1.weight: gradient norm = 8722.328125
+down_blocks.4.feat_res_block1.conv1.bias: gradient norm = 0.000040
+down_blocks.4.feat_res_block1.bn2.weight: gradient norm = 226.620544
+down_blocks.4.feat_res_block1.bn2.bias: gradient norm = 219.217957
+down_blocks.4.feat_res_block1.conv2.weight: gradient norm = 17212.603516
+down_blocks.4.feat_res_block1.conv2.bias: gradient norm = 662.906006
+down_blocks.4.feat_res_block2.bn1.weight: gradient norm = 46.697407
+down_blocks.4.feat_res_block2.bn1.bias: gradient norm = 29.666679
+down_blocks.4.feat_res_block2.conv1.weight: gradient norm = 5199.561523
+down_blocks.4.feat_res_block2.conv1.bias: gradient norm = 0.000033
+down_blocks.4.feat_res_block2.bn2.weight: gradient norm = 258.914001
+down_blocks.4.feat_res_block2.bn2.bias: gradient norm = 268.383759
+down_blocks.4.feat_res_block2.conv2.weight: gradient norm = 23115.939453
+down_blocks.4.feat_res_block2.conv2.bias: gradient norm = 970.836731
+Gradient flow test passed for DenseFeatureEncoder
+Initializing ResNetFeatureExtractor with output_channels: [128, 256, 512, 512]
+
+Testing ResNetFeatureExtractor
+👟 ResNetFeatureExtractor input shape: torch.Size([1, 3, 256, 256])
+After layer0: torch.Size([1, 64, 64, 64])
+After layer1: torch.Size([1, 256, 64, 64])
+After adjust1: torch.Size([1, 128, 64, 64])
+After layer2: torch.Size([1, 512, 32, 32])
+After adjust2: torch.Size([1, 256, 32, 32])
+After layer3: torch.Size([1, 1024, 16, 16])
+After adjust3: torch.Size([1, 512, 16, 16])
+After layer4: torch.Size([1, 2048, 8, 8])
+After adjust4: torch.Size([1, 512, 8, 8])
+ResNetFeatureExtractor output: 4 features
+ Feature 1 shape: torch.Size([1, 128, 64, 64])
+ Feature 2 shape: torch.Size([1, 256, 32, 32])
+ Feature 3 shape: torch.Size([1, 512, 16, 16])
+ Feature 4 shape: torch.Size([1, 512, 8, 8])
+Number of feature maps: 4
+Feature map 0 shape: torch.Size([1, 128, 64, 64])
+Feature map 1 shape: torch.Size([1, 256, 32, 32])
+Feature map 2 shape: torch.Size([1, 512, 16, 16])
+Feature map 3 shape: torch.Size([1, 512, 8, 8])
+resnet.conv1.weight: gradient norm = 4800.786621
+resnet.bn1.weight: gradient norm = 1046.911987
+resnet.bn1.bias: gradient norm = 144.213150
+resnet.layer1.0.conv1.weight: gradient norm = 1544.625000
+resnet.layer1.0.bn1.weight: gradient norm = 451.380371
+resnet.layer1.0.bn1.bias: gradient norm = 287.169098
+resnet.layer1.0.conv2.weight: gradient norm = 1826.635620
+resnet.layer1.0.bn2.weight: gradient norm = 788.603394
+resnet.layer1.0.bn2.bias: gradient norm = 308.688843
+resnet.layer1.0.conv3.weight: gradient norm = 2771.479492
+resnet.layer1.0.bn3.weight: gradient norm = 5670.572266
+resnet.layer1.0.bn3.bias: gradient norm = 14051.648438
+resnet.layer1.0.downsample.0.weight: gradient norm = 2745.248535
+resnet.layer1.0.downsample.1.weight: gradient norm = 7435.537109
+resnet.layer1.0.downsample.1.bias: gradient norm = 14051.648438
+resnet.layer1.1.conv1.weight: gradient norm = 686.393311 +resnet.layer1.1.bn1.weight: gradient norm = 324.551392 +resnet.layer1.1.bn1.bias: gradient norm = 198.787979 +resnet.layer1.1.conv2.weight: gradient norm = 1350.724609 +resnet.layer1.1.bn2.weight: gradient norm = 494.487000 +resnet.layer1.1.bn2.bias: gradient norm = 359.002625 +resnet.layer1.1.conv3.weight: gradient norm = 1761.334351 +resnet.layer1.1.bn3.weight: gradient norm = 5458.712402 +resnet.layer1.1.bn3.bias: gradient norm = 18945.775391 +resnet.layer1.2.conv1.weight: gradient norm = 858.605835 +resnet.layer1.2.bn1.weight: gradient norm = 332.684875 +resnet.layer1.2.bn1.bias: gradient norm = 257.751526 +resnet.layer1.2.conv2.weight: gradient norm = 1588.494629 +resnet.layer1.2.bn2.weight: gradient norm = 538.706299 +resnet.layer1.2.bn2.bias: gradient norm = 387.532867 +resnet.layer1.2.conv3.weight: gradient norm = 1951.876831 +resnet.layer1.2.bn3.weight: gradient norm = 4898.749512 +resnet.layer1.2.bn3.bias: gradient norm = 21995.453125 +resnet.layer2.0.conv1.weight: gradient norm = 2023.863647 +resnet.layer2.0.bn1.weight: gradient norm = 510.163147 +resnet.layer2.0.bn1.bias: gradient norm = 446.813141 +resnet.layer2.0.conv2.weight: gradient norm = 3274.617188 +resnet.layer2.0.bn2.weight: gradient norm = 625.217224 +resnet.layer2.0.bn2.bias: gradient norm = 421.059875 +resnet.layer2.0.conv3.weight: gradient norm = 2471.518799 +resnet.layer2.0.bn3.weight: gradient norm = 2157.436523 +resnet.layer2.0.bn3.bias: gradient norm = 3166.544922 +resnet.layer2.0.downsample.0.weight: gradient norm = 2595.494385 +resnet.layer2.0.downsample.1.weight: gradient norm = 2252.645508 +resnet.layer2.0.downsample.1.bias: gradient norm = 3166.544922 +resnet.layer2.1.conv1.weight: gradient norm = 1008.842529 +resnet.layer2.1.bn1.weight: gradient norm = 361.865540 +resnet.layer2.1.bn1.bias: gradient norm = 206.605804 +resnet.layer2.1.conv2.weight: gradient norm = 1664.446411 +resnet.layer2.1.bn2.weight: gradient norm = 388.908600 +resnet.layer2.1.bn2.bias: gradient norm = 298.457458 +resnet.layer2.1.conv3.weight: gradient norm = 1785.173706 +resnet.layer2.1.bn3.weight: gradient norm = 1831.097534 +resnet.layer2.1.bn3.bias: gradient norm = 3933.596191 +resnet.layer2.2.conv1.weight: gradient norm = 2064.628174 +resnet.layer2.2.bn1.weight: gradient norm = 493.342285 +resnet.layer2.2.bn1.bias: gradient norm = 389.481293 +resnet.layer2.2.conv2.weight: gradient norm = 2773.337646 +resnet.layer2.2.bn2.weight: gradient norm = 518.756042 +resnet.layer2.2.bn2.bias: gradient norm = 398.764740 +resnet.layer2.2.conv3.weight: gradient norm = 2332.246826 +resnet.layer2.2.bn3.weight: gradient norm = 2333.118652 +resnet.layer2.2.bn3.bias: gradient norm = 4345.148438 +resnet.layer2.3.conv1.weight: gradient norm = 2639.072021 +resnet.layer2.3.bn1.weight: gradient norm = 460.071503 +resnet.layer2.3.bn1.bias: gradient norm = 331.195435 +resnet.layer2.3.conv2.weight: gradient norm = 3077.151123 +resnet.layer2.3.bn2.weight: gradient norm = 585.600281 +resnet.layer2.3.bn2.bias: gradient norm = 456.634857 +resnet.layer2.3.conv3.weight: gradient norm = 2201.639404 +resnet.layer2.3.bn3.weight: gradient norm = 2194.386963 +resnet.layer2.3.bn3.bias: gradient norm = 5649.107422 +resnet.layer3.0.conv1.weight: gradient norm = 3144.827881 +resnet.layer3.0.bn1.weight: gradient norm = 695.798462 +resnet.layer3.0.bn1.bias: gradient norm = 537.638306 +resnet.layer3.0.conv2.weight: gradient norm = 4617.796387 +resnet.layer3.0.bn2.weight: gradient norm = 536.139526 
+resnet.layer3.0.bn2.bias: gradient norm = 412.010681 +resnet.layer3.0.conv3.weight: gradient norm = 2923.624756 +resnet.layer3.0.bn3.weight: gradient norm = 848.097534 +resnet.layer3.0.bn3.bias: gradient norm = 709.832764 +resnet.layer3.0.downsample.0.weight: gradient norm = 3570.838135 +resnet.layer3.0.downsample.1.weight: gradient norm = 797.042908 +resnet.layer3.0.downsample.1.bias: gradient norm = 709.832764 +resnet.layer3.1.conv1.weight: gradient norm = 2758.773926 +resnet.layer3.1.bn1.weight: gradient norm = 504.663666 +resnet.layer3.1.bn1.bias: gradient norm = 387.509064 +resnet.layer3.1.conv2.weight: gradient norm = 3717.543213 +resnet.layer3.1.bn2.weight: gradient norm = 584.846558 +resnet.layer3.1.bn2.bias: gradient norm = 439.839874 +resnet.layer3.1.conv3.weight: gradient norm = 2612.463623 +resnet.layer3.1.bn3.weight: gradient norm = 800.792175 +resnet.layer3.1.bn3.bias: gradient norm = 802.272583 +resnet.layer3.2.conv1.weight: gradient norm = 2662.187500 +resnet.layer3.2.bn1.weight: gradient norm = 593.774841 +resnet.layer3.2.bn1.bias: gradient norm = 457.772217 +resnet.layer3.2.conv2.weight: gradient norm = 3410.258301 +resnet.layer3.2.bn2.weight: gradient norm = 625.713013 +resnet.layer3.2.bn2.bias: gradient norm = 470.226471 +resnet.layer3.2.conv3.weight: gradient norm = 2476.610107 +resnet.layer3.2.bn3.weight: gradient norm = 787.189880 +resnet.layer3.2.bn3.bias: gradient norm = 865.619385 +resnet.layer3.3.conv1.weight: gradient norm = 2778.352539 +resnet.layer3.3.bn1.weight: gradient norm = 555.105225 +resnet.layer3.3.bn1.bias: gradient norm = 411.569366 +resnet.layer3.3.conv2.weight: gradient norm = 3350.435303 +resnet.layer3.3.bn2.weight: gradient norm = 473.241455 +resnet.layer3.3.bn2.bias: gradient norm = 367.464722 +resnet.layer3.3.conv3.weight: gradient norm = 2241.077393 +resnet.layer3.3.bn3.weight: gradient norm = 763.453979 +resnet.layer3.3.bn3.bias: gradient norm = 904.271545 +resnet.layer3.4.conv1.weight: gradient norm = 2628.679443 +resnet.layer3.4.bn1.weight: gradient norm = 563.440918 +resnet.layer3.4.bn1.bias: gradient norm = 419.874725 +resnet.layer3.4.conv2.weight: gradient norm = 3023.427734 +resnet.layer3.4.bn2.weight: gradient norm = 491.295654 +resnet.layer3.4.bn2.bias: gradient norm = 338.390381 +resnet.layer3.4.conv3.weight: gradient norm = 2069.013916 +resnet.layer3.4.bn3.weight: gradient norm = 728.373901 +resnet.layer3.4.bn3.bias: gradient norm = 954.144836 +resnet.layer3.5.conv1.weight: gradient norm = 2297.656738 +resnet.layer3.5.bn1.weight: gradient norm = 476.655090 +resnet.layer3.5.bn1.bias: gradient norm = 351.150116 +resnet.layer3.5.conv2.weight: gradient norm = 2732.465088 +resnet.layer3.5.bn2.weight: gradient norm = 375.037262 +resnet.layer3.5.bn2.bias: gradient norm = 284.882965 +resnet.layer3.5.conv3.weight: gradient norm = 1963.578247 +resnet.layer3.5.bn3.weight: gradient norm = 821.036682 +resnet.layer3.5.bn3.bias: gradient norm = 1305.015503 +resnet.layer4.0.conv1.weight: gradient norm = 2335.634277 +resnet.layer4.0.bn1.weight: gradient norm = 477.492004 +resnet.layer4.0.bn1.bias: gradient norm = 315.976013 +resnet.layer4.0.conv2.weight: gradient norm = 3563.329834 +resnet.layer4.0.bn2.weight: gradient norm = 314.583069 +resnet.layer4.0.bn2.bias: gradient norm = 236.488907 +resnet.layer4.0.conv3.weight: gradient norm = 2241.919678 +resnet.layer4.0.bn3.weight: gradient norm = 215.211945 +resnet.layer4.0.bn3.bias: gradient norm = 213.119644 +resnet.layer4.0.downsample.0.weight: gradient norm = 2420.404297 
+resnet.layer4.0.downsample.1.weight: gradient norm = 175.796448
+resnet.layer4.0.downsample.1.bias: gradient norm = 213.119644
+resnet.layer4.1.conv1.weight: gradient norm = 2086.645508
+resnet.layer4.1.bn1.weight: gradient norm = 389.913391
+resnet.layer4.1.bn1.bias: gradient norm = 279.497681
+resnet.layer4.1.conv2.weight: gradient norm = 3087.079834
+resnet.layer4.1.bn2.weight: gradient norm = 300.808746
+resnet.layer4.1.bn2.bias: gradient norm = 219.829819
+resnet.layer4.1.conv3.weight: gradient norm = 1802.238037
+resnet.layer4.1.bn3.weight: gradient norm = 255.878143
+resnet.layer4.1.bn3.bias: gradient norm = 325.353119
+resnet.layer4.2.conv1.weight: gradient norm = 1601.803223
+resnet.layer4.2.bn1.weight: gradient norm = 306.606506
+resnet.layer4.2.bn1.bias: gradient norm = 216.133408
+resnet.layer4.2.conv2.weight: gradient norm = 2545.529541
+resnet.layer4.2.bn2.weight: gradient norm = 182.802536
+resnet.layer4.2.bn2.bias: gradient norm = 152.100525
+resnet.layer4.2.conv3.weight: gradient norm = 1953.472778
+resnet.layer4.2.bn3.weight: gradient norm = 290.827362
+resnet.layer4.2.bn3.bias: gradient norm = 486.070496
+Warning: No gradient for resnet.fc.weight
+Warning: No gradient for resnet.fc.bias
+adjust1.weight: gradient norm = 160524.125000
+adjust1.bias: gradient norm = 46340.949219
+adjust2.weight: gradient norm = 46795.097656
+adjust2.bias: gradient norm = 16384.000000
+adjust3.weight: gradient norm = 7931.419434
+adjust3.bias: gradient norm = 5792.618652
+adjust4.weight: gradient norm = 30131.992188
+adjust4.bias: gradient norm = 1448.154663
+Gradient flow test passed for ResNetFeatureExtractor
+
+Testing gradient flow for LatentTokenEncoder
+LatentTokenEncoder input shape: torch.Size([1, 3, 256, 256])
+After initial conv and activation: torch.Size([1, 64, 256, 256])
+After res_block 1: torch.Size([1, 128, 128, 128])
+After res_block 2: torch.Size([1, 256, 64, 64])
+After res_block 3: torch.Size([1, 512, 32, 32])
+After res_block 4: torch.Size([1, 512, 16, 16])
+After res_block 5: torch.Size([1, 512, 8, 8])
+After res_block 6: torch.Size([1, 512, 4, 4])
+After equalconv: torch.Size([1, 512, 4, 4])
+After global average pooling: torch.Size([1, 512])
+After linear layer 1: torch.Size([1, 512])
+After linear layer 2: torch.Size([1, 512])
+After linear layer 3: torch.Size([1, 512])
+After linear layer 4: torch.Size([1, 512])
+Final output: torch.Size([1, 32])
+conv1.weight: gradient norm = 77.054909
+conv1.bias: gradient norm = 8.658038
+res_blocks.0.conv1.conv.weight: gradient norm = 249.353256
+res_blocks.0.conv1.conv.bias: gradient norm = 0.000042
+res_blocks.0.conv1.bn.weight: gradient norm = 4.572390
+res_blocks.0.conv1.bn.bias: gradient norm = 3.931448
+res_blocks.0.conv2.conv.weight: gradient norm = 289.736969
+res_blocks.0.conv2.conv.bias: gradient norm = 0.000025
+res_blocks.0.conv2.bn.weight: gradient norm = 4.126070
+res_blocks.0.conv2.bn.bias: gradient norm = 3.600731
+res_blocks.0.skip_conv.conv.weight: gradient norm = 206.902969
+res_blocks.0.skip_conv.conv.bias: gradient norm = 0.000041
+res_blocks.0.skip_conv.bn.weight: gradient norm = 3.918743
+res_blocks.0.skip_conv.bn.bias: gradient norm = 3.363138
+res_blocks.1.conv1.conv.weight: gradient norm = 260.609314
+res_blocks.1.conv1.conv.bias: gradient norm = 0.000007
+res_blocks.1.conv1.bn.weight: gradient norm = 3.573316
+res_blocks.1.conv1.bn.bias: gradient norm = 2.891786
+res_blocks.1.conv2.conv.weight: gradient norm = 304.705994
+res_blocks.1.conv2.conv.bias: gradient norm = 0.000008
+res_blocks.1.conv2.bn.weight: gradient norm = 3.090661 +res_blocks.1.conv2.bn.bias: gradient norm = 2.567820 +res_blocks.1.skip_conv.conv.weight: gradient norm = 213.828308 +res_blocks.1.skip_conv.conv.bias: gradient norm = 0.000006 +res_blocks.1.skip_conv.bn.weight: gradient norm = 2.974518 +res_blocks.1.skip_conv.bn.bias: gradient norm = 2.654161 +res_blocks.2.conv1.conv.weight: gradient norm = 275.692627 +res_blocks.2.conv1.conv.bias: gradient norm = 0.000003 +res_blocks.2.conv1.bn.weight: gradient norm = 2.680028 +res_blocks.2.conv1.bn.bias: gradient norm = 2.330478 +res_blocks.2.conv2.conv.weight: gradient norm = 322.229462 +res_blocks.2.conv2.conv.bias: gradient norm = 0.000003 +res_blocks.2.conv2.bn.weight: gradient norm = 2.274737 +res_blocks.2.conv2.bn.bias: gradient norm = 1.934244 +res_blocks.2.skip_conv.conv.weight: gradient norm = 226.909836 +res_blocks.2.skip_conv.conv.bias: gradient norm = 0.000002 +res_blocks.2.skip_conv.bn.weight: gradient norm = 2.229406 +res_blocks.2.skip_conv.bn.bias: gradient norm = 1.926265 +res_blocks.3.conv1.conv.weight: gradient norm = 294.384094 +res_blocks.3.conv1.conv.bias: gradient norm = 0.000001 +res_blocks.3.conv1.bn.weight: gradient norm = 2.085438 +res_blocks.3.conv1.bn.bias: gradient norm = 1.841472 +res_blocks.3.conv2.conv.weight: gradient norm = 244.549484 +res_blocks.3.conv2.conv.bias: gradient norm = 0.000001 +res_blocks.3.conv2.bn.weight: gradient norm = 1.783224 +res_blocks.3.conv2.bn.bias: gradient norm = 1.472559 +res_blocks.3.skip_conv.conv.weight: gradient norm = 239.638962 +res_blocks.3.skip_conv.conv.bias: gradient norm = 0.000001 +res_blocks.3.skip_conv.bn.weight: gradient norm = 1.741726 +res_blocks.3.skip_conv.bn.bias: gradient norm = 1.479385 +res_blocks.4.conv1.conv.weight: gradient norm = 222.197418 +res_blocks.4.conv1.conv.bias: gradient norm = 0.000001 +res_blocks.4.conv1.bn.weight: gradient norm = 1.537410 +res_blocks.4.conv1.bn.bias: gradient norm = 1.350556 +res_blocks.4.conv2.conv.weight: gradient norm = 186.724823 +res_blocks.4.conv2.conv.bias: gradient norm = 0.000001 +res_blocks.4.conv2.bn.weight: gradient norm = 1.329779 +res_blocks.4.conv2.bn.bias: gradient norm = 1.189980 +res_blocks.4.skip_conv.conv.weight: gradient norm = 185.488098 +res_blocks.4.skip_conv.conv.bias: gradient norm = 0.000000 +res_blocks.4.skip_conv.bn.weight: gradient norm = 1.376067 +res_blocks.4.skip_conv.bn.bias: gradient norm = 1.273393 +res_blocks.5.conv1.conv.weight: gradient norm = 175.495209 +res_blocks.5.conv1.conv.bias: gradient norm = 0.000000 +res_blocks.5.conv1.bn.weight: gradient norm = 1.415485 +res_blocks.5.conv1.bn.bias: gradient norm = 1.291371 +res_blocks.5.conv2.conv.weight: gradient norm = 155.555298 +res_blocks.5.conv2.conv.bias: gradient norm = 0.000001 +res_blocks.5.conv2.bn.weight: gradient norm = 3.850402 +res_blocks.5.conv2.bn.bias: gradient norm = 4.637940 +res_blocks.5.skip_conv.conv.weight: gradient norm = 159.404724 +res_blocks.5.skip_conv.conv.bias: gradient norm = 0.000001 +res_blocks.5.skip_conv.bn.weight: gradient norm = 3.816106 +res_blocks.5.skip_conv.bn.bias: gradient norm = 4.661281 +equalconv.conv.bias: gradient norm = 9.044192 +equalconv.conv.weight_orig: gradient norm = 7.597575 +linear_layers.0.linear.bias: gradient norm = 6.444403 +linear_layers.0.linear.weight_orig: gradient norm = 7.694224 +linear_layers.1.linear.bias: gradient norm = 6.159342 +linear_layers.1.linear.weight_orig: gradient norm = 7.564744 +linear_layers.2.linear.bias: gradient norm = 5.789473 +linear_layers.2.linear.weight_orig: 
gradient norm = 6.603564 +linear_layers.3.linear.bias: gradient norm = 5.712759 +linear_layers.3.linear.weight_orig: gradient norm = 6.747198 +final_linear.linear.bias: gradient norm = 5.656854 +final_linear.linear.weight_orig: gradient norm = 6.717711 +Gradient flow test passed for LatentTokenEncoder + +Testing gradient flow for LatentTokenDecoder +const: gradient norm = 779.447388 +style_conv_layers.0.conv.weight: gradient norm = 539.260986 +style_conv_layers.0.style.weight: gradient norm = 7917.743652 +style_conv_layers.0.style.bias: gradient norm = 1298.246460 +style_conv_layers.1.conv.weight: gradient norm = 784.501160 +style_conv_layers.1.style.weight: gradient norm = 7587.056641 +style_conv_layers.1.style.bias: gradient norm = 1244.025024 +style_conv_layers.2.conv.weight: gradient norm = 749.168640 +style_conv_layers.2.style.weight: gradient norm = 8494.576172 +style_conv_layers.2.style.bias: gradient norm = 1392.828125 +style_conv_layers.3.conv.weight: gradient norm = 773.820923 +style_conv_layers.3.style.weight: gradient norm = 7538.504395 +style_conv_layers.3.style.bias: gradient norm = 1236.064209 +style_conv_layers.4.conv.weight: gradient norm = 670.328613 +style_conv_layers.4.style.weight: gradient norm = 7605.530762 +style_conv_layers.4.style.bias: gradient norm = 1247.054565 +style_conv_layers.5.conv.weight: gradient norm = 586.553040 +style_conv_layers.5.style.weight: gradient norm = 7297.344238 +style_conv_layers.5.style.bias: gradient norm = 1196.521729 +style_conv_layers.6.conv.weight: gradient norm = 484.850189 +style_conv_layers.6.style.weight: gradient norm = 5575.540527 +style_conv_layers.6.style.bias: gradient norm = 914.203064 +style_conv_layers.7.conv.weight: gradient norm = 370.123871 +style_conv_layers.7.style.weight: gradient norm = 4175.150391 +style_conv_layers.7.style.bias: gradient norm = 684.585693 +style_conv_layers.8.conv.weight: gradient norm = 404.413025 +style_conv_layers.8.style.weight: gradient norm = 4310.747559 +style_conv_layers.8.style.bias: gradient norm = 706.819519 +style_conv_layers.9.conv.weight: gradient norm = 412.654205 +style_conv_layers.9.style.weight: gradient norm = 4258.424805 +style_conv_layers.9.style.bias: gradient norm = 698.239868 +style_conv_layers.10.conv.weight: gradient norm = 245.772919 +style_conv_layers.10.style.weight: gradient norm = 3062.014404 +style_conv_layers.10.style.bias: gradient norm = 502.068481 +style_conv_layers.11.conv.weight: gradient norm = 292.453766 +style_conv_layers.11.style.weight: gradient norm = 3329.650391 +style_conv_layers.11.style.bias: gradient norm = 545.951904 +style_conv_layers.12.conv.weight: gradient norm = 232.569595 +style_conv_layers.12.style.weight: gradient norm = 2870.214111 +style_conv_layers.12.style.bias: gradient norm = 470.619781 +Gradient flow test passed for LatentTokenDecoder + +Testing ImplicitMotionAlignment modules +Testing ImplicitMotionAlignment for dim=128, spatial_size=64 +cross_attention.to_q.weight: gradient norm = 830.171265 +cross_attention.to_q.bias: gradient norm = 2252.421143 +cross_attention.to_k.weight: gradient norm = 2901.171387 +cross_attention.to_k.bias: gradient norm = 0.000049 +cross_attention.to_v.weight: gradient norm = 75491.593750 +cross_attention.to_v.bias: gradient norm = 379057.437500 +cross_attention.to_out.weight: gradient norm = 420455.531250 +cross_attention.to_out.bias: gradient norm = 699836.250000 +blocks.0.attention.in_proj_weight: gradient norm = 657236.187500 +blocks.0.attention.in_proj_bias: gradient norm = 58207.203125 
+blocks.0.attention.out_proj.weight: gradient norm = 940164.500000 +blocks.0.attention.out_proj.bias: gradient norm = 99538.054688 +blocks.0.mlp.0.weight: gradient norm = 377064.687500 +blocks.0.mlp.0.bias: gradient norm = 33329.140625 +blocks.0.mlp.2.weight: gradient norm = 718948.000000 +blocks.0.mlp.2.bias: gradient norm = 89722.859375 +blocks.0.norm1.weight: gradient norm = 40633.195312 +blocks.0.norm1.bias: gradient norm = 40290.886719 +blocks.0.norm2.weight: gradient norm = 18535.234375 +blocks.0.norm2.bias: gradient norm = 19941.255859 +blocks.1.attention.in_proj_weight: gradient norm = 455065.500000 +blocks.1.attention.in_proj_bias: gradient norm = 40223.574219 +blocks.1.attention.out_proj.weight: gradient norm = 515086.906250 +blocks.1.attention.out_proj.bias: gradient norm = 71307.281250 +blocks.1.mlp.0.weight: gradient norm = 262347.656250 +blocks.1.mlp.0.bias: gradient norm = 23188.937500 +blocks.1.mlp.2.weight: gradient norm = 489467.781250 +blocks.1.mlp.2.bias: gradient norm = 65084.519531 +blocks.1.norm1.weight: gradient norm = 28883.685547 +blocks.1.norm1.bias: gradient norm = 29371.042969 +blocks.1.norm2.weight: gradient norm = 14049.950195 +blocks.1.norm2.bias: gradient norm = 15103.478516 +blocks.2.attention.in_proj_weight: gradient norm = 363645.718750 +blocks.2.attention.in_proj_bias: gradient norm = 32142.544922 +blocks.2.attention.out_proj.weight: gradient norm = 424813.187500 +blocks.2.attention.out_proj.bias: gradient norm = 55891.292969 +blocks.2.mlp.0.weight: gradient norm = 221795.640625 +blocks.2.mlp.0.bias: gradient norm = 19604.423828 +blocks.2.mlp.2.weight: gradient norm = 441987.250000 +blocks.2.mlp.2.bias: gradient norm = 54326.875000 +blocks.2.norm1.weight: gradient norm = 20070.251953 +blocks.2.norm1.bias: gradient norm = 21361.015625 +blocks.2.norm2.weight: gradient norm = 10427.683594 +blocks.2.norm2.bias: gradient norm = 11241.589844 +blocks.3.attention.in_proj_weight: gradient norm = 293086.843750 +blocks.3.attention.in_proj_bias: gradient norm = 25905.779297 +blocks.3.attention.out_proj.weight: gradient norm = 374198.656250 +blocks.3.attention.out_proj.bias: gradient norm = 47259.253906 +blocks.3.mlp.0.weight: gradient norm = 190921.375000 +blocks.3.mlp.0.bias: gradient norm = 16875.392578 +blocks.3.mlp.2.weight: gradient norm = 367157.281250 +blocks.3.mlp.2.bias: gradient norm = 46340.949219 +blocks.3.norm1.weight: gradient norm = 22470.152344 +blocks.3.norm1.bias: gradient norm = 21038.578125 +blocks.3.norm2.weight: gradient norm = 8983.610352 +blocks.3.norm2.bias: gradient norm = 8902.168945 +Gradient flow test passed for ImplicitMotionAlignment with dim=128 +Testing ImplicitMotionAlignment for dim=256, spatial_size=32 +cross_attention.to_q.weight: gradient norm = 1242.162476 +cross_attention.to_q.bias: gradient norm = 1921.085205 +cross_attention.to_k.weight: gradient norm = 3077.728271 +cross_attention.to_k.bias: gradient norm = 0.000031 +cross_attention.to_v.weight: gradient norm = 90357.109375 +cross_attention.to_v.bias: gradient norm = 187364.156250 +cross_attention.to_out.weight: gradient norm = 212677.296875 +cross_attention.to_out.bias: gradient norm = 340579.593750 +blocks.0.attention.in_proj_weight: gradient norm = 330534.437500 +blocks.0.attention.in_proj_bias: gradient norm = 20790.921875 +blocks.0.attention.out_proj.weight: gradient norm = 420269.468750 +blocks.0.attention.out_proj.bias: gradient norm = 36665.308594 +blocks.0.mlp.0.weight: gradient norm = 189268.734375 +blocks.0.mlp.0.bias: gradient norm = 11830.186523 
+blocks.0.mlp.2.weight: gradient norm = 358124.156250 +blocks.0.mlp.2.bias: gradient norm = 32148.673828 +blocks.0.norm1.weight: gradient norm = 14946.377930 +blocks.0.norm1.bias: gradient norm = 14309.821289 +blocks.0.norm2.weight: gradient norm = 6713.513184 +blocks.0.norm2.bias: gradient norm = 6847.319824 +blocks.1.attention.in_proj_weight: gradient norm = 220829.062500 +blocks.1.attention.in_proj_bias: gradient norm = 13802.628906 +blocks.1.attention.out_proj.weight: gradient norm = 272784.562500 +blocks.1.attention.out_proj.bias: gradient norm = 23890.263672 +blocks.1.mlp.0.weight: gradient norm = 130541.664062 +blocks.1.mlp.0.bias: gradient norm = 8159.210449 +blocks.1.mlp.2.weight: gradient norm = 265748.250000 +blocks.1.mlp.2.bias: gradient norm = 23132.369141 +blocks.1.norm1.weight: gradient norm = 10471.356445 +blocks.1.norm1.bias: gradient norm = 10546.423828 +blocks.1.norm2.weight: gradient norm = 4646.239746 +blocks.1.norm2.bias: gradient norm = 4485.238281 +blocks.2.attention.in_proj_weight: gradient norm = 181167.000000 +blocks.2.attention.in_proj_bias: gradient norm = 11323.332031 +blocks.2.attention.out_proj.weight: gradient norm = 224253.875000 +blocks.2.attention.out_proj.bias: gradient norm = 19564.628906 +blocks.2.mlp.0.weight: gradient norm = 110451.617188 +blocks.2.mlp.0.bias: gradient norm = 6903.474609 +blocks.2.mlp.2.weight: gradient norm = 211880.218750 +blocks.2.mlp.2.bias: gradient norm = 18862.037109 +blocks.2.norm1.weight: gradient norm = 7891.067383 +blocks.2.norm1.bias: gradient norm = 7728.061523 +blocks.2.norm2.weight: gradient norm = 4577.380859 +blocks.2.norm2.bias: gradient norm = 4303.497070 +blocks.3.attention.in_proj_weight: gradient norm = 150842.578125 +blocks.3.attention.in_proj_bias: gradient norm = 9427.898438 +blocks.3.attention.out_proj.weight: gradient norm = 185588.765625 +blocks.3.attention.out_proj.bias: gradient norm = 16886.332031 +blocks.3.mlp.0.weight: gradient norm = 91945.140625 +blocks.3.mlp.0.bias: gradient norm = 5746.728027 +blocks.3.mlp.2.weight: gradient norm = 195401.046875 +blocks.3.mlp.2.bias: gradient norm = 16384.000000 +blocks.3.norm1.weight: gradient norm = 6441.775391 +blocks.3.norm1.bias: gradient norm = 6424.875000 +blocks.3.norm2.weight: gradient norm = 3628.647949 +blocks.3.norm2.bias: gradient norm = 3575.992432 +Gradient flow test passed for ImplicitMotionAlignment with dim=256 +Testing ImplicitMotionAlignment for dim=512, spatial_size=16 +cross_attention.to_q.weight: gradient norm = 2377.416504 +cross_attention.to_q.bias: gradient norm = 1611.652100 +cross_attention.to_k.weight: gradient norm = 3261.309814 +cross_attention.to_k.bias: gradient norm = 0.000029 +cross_attention.to_v.weight: gradient norm = 116167.125000 +cross_attention.to_v.bias: gradient norm = 80549.664062 +cross_attention.to_out.weight: gradient norm = 151539.265625 +cross_attention.to_out.bias: gradient norm = 144888.562500 +blocks.0.attention.in_proj_weight: gradient norm = 167920.296875 +blocks.0.attention.in_proj_bias: gradient norm = 7592.502930 +blocks.0.attention.out_proj.weight: gradient norm = 206169.390625 +blocks.0.attention.out_proj.bias: gradient norm = 13067.125977 +blocks.0.mlp.0.weight: gradient norm = 97404.921875 +blocks.0.mlp.0.bias: gradient norm = 4305.809082 +blocks.0.mlp.2.weight: gradient norm = 177578.906250 +blocks.0.mlp.2.bias: gradient norm = 11957.936523 +blocks.0.norm1.weight: gradient norm = 5181.068848 +blocks.0.norm1.bias: gradient norm = 5329.598145 +blocks.0.norm2.weight: gradient norm = 2524.885010 
+blocks.0.norm2.bias: gradient norm = 2552.831055 +blocks.1.attention.in_proj_weight: gradient norm = 116449.585938 +blocks.1.attention.in_proj_bias: gradient norm = 5147.482910 +blocks.1.attention.out_proj.weight: gradient norm = 144568.343750 +blocks.1.attention.out_proj.bias: gradient norm = 8931.921875 +blocks.1.mlp.0.weight: gradient norm = 65922.906250 +blocks.1.mlp.0.bias: gradient norm = 2913.866455 +blocks.1.mlp.2.weight: gradient norm = 135285.375000 +blocks.1.mlp.2.bias: gradient norm = 8303.961914 +blocks.1.norm1.weight: gradient norm = 3756.583984 +blocks.1.norm1.bias: gradient norm = 3501.456055 +blocks.1.norm2.weight: gradient norm = 1656.396484 +blocks.1.norm2.bias: gradient norm = 1739.926636 +blocks.2.attention.in_proj_weight: gradient norm = 92338.960938 +blocks.2.attention.in_proj_bias: gradient norm = 4081.290039 +blocks.2.attention.out_proj.weight: gradient norm = 114180.890625 +blocks.2.attention.out_proj.bias: gradient norm = 7064.245117 +blocks.2.mlp.0.weight: gradient norm = 54397.640625 +blocks.2.mlp.0.bias: gradient norm = 2404.385986 +blocks.2.mlp.2.weight: gradient norm = 107140.007812 +blocks.2.mlp.2.bias: gradient norm = 6784.312500 +blocks.2.norm1.weight: gradient norm = 2845.508057 +blocks.2.norm1.bias: gradient norm = 2693.060791 +blocks.2.norm2.weight: gradient norm = 1350.523926 +blocks.2.norm2.bias: gradient norm = 1371.093750 +blocks.3.attention.in_proj_weight: gradient norm = 75547.195312 +blocks.3.attention.in_proj_bias: gradient norm = 3339.017334 +blocks.3.attention.out_proj.weight: gradient norm = 96877.164062 +blocks.3.attention.out_proj.bias: gradient norm = 5954.127441 +blocks.3.mlp.0.weight: gradient norm = 47230.671875 +blocks.3.mlp.0.bias: gradient norm = 2087.548096 +blocks.3.mlp.2.weight: gradient norm = 87877.218750 +blocks.3.mlp.2.bias: gradient norm = 5792.618652 +blocks.3.norm1.weight: gradient norm = 2362.469482 +blocks.3.norm1.bias: gradient norm = 2427.172607 +blocks.3.norm2.weight: gradient norm = 1180.518188 +blocks.3.norm2.bias: gradient norm = 1207.247070 +Gradient flow test passed for ImplicitMotionAlignment with dim=512 +Testing ImplicitMotionAlignment for dim=512, spatial_size=8 +cross_attention.to_q.weight: gradient norm = 1483.171021 +cross_attention.to_q.bias: gradient norm = 503.627716 +cross_attention.to_k.weight: gradient norm = 1681.210327 +cross_attention.to_k.bias: gradient norm = 0.000013 +cross_attention.to_v.weight: gradient norm = 36307.039062 +cross_attention.to_v.bias: gradient norm = 13140.595703 +cross_attention.to_out.weight: gradient norm = 36972.554688 +cross_attention.to_out.bias: gradient norm = 22684.265625 +blocks.0.attention.in_proj_weight: gradient norm = 36526.507812 +blocks.0.attention.in_proj_bias: gradient norm = 1675.603149 +blocks.0.attention.out_proj.weight: gradient norm = 48570.421875 +blocks.0.attention.out_proj.bias: gradient norm = 2935.032471 +blocks.0.mlp.0.weight: gradient norm = 21888.810547 +blocks.0.mlp.0.bias: gradient norm = 967.929993 +blocks.0.mlp.2.weight: gradient norm = 42141.007812 +blocks.0.mlp.2.bias: gradient norm = 2692.202881 +blocks.0.norm1.weight: gradient norm = 1168.022461 +blocks.0.norm1.bias: gradient norm = 1184.776978 +blocks.0.norm2.weight: gradient norm = 545.710815 +blocks.0.norm2.bias: gradient norm = 545.764465 +blocks.1.attention.in_proj_weight: gradient norm = 26525.279297 +blocks.1.attention.in_proj_bias: gradient norm = 1172.923828 +blocks.1.attention.out_proj.weight: gradient norm = 35669.054688 +blocks.1.attention.out_proj.bias: gradient norm = 
2056.190430 +blocks.1.mlp.0.weight: gradient norm = 15664.407227 +blocks.1.mlp.0.bias: gradient norm = 692.517212 +blocks.1.mlp.2.weight: gradient norm = 30074.457031 +blocks.1.mlp.2.bias: gradient norm = 1949.438843 +blocks.1.norm1.weight: gradient norm = 761.614319 +blocks.1.norm1.bias: gradient norm = 794.153076 +blocks.1.norm2.weight: gradient norm = 374.636383 +blocks.1.norm2.bias: gradient norm = 387.756561 +blocks.2.attention.in_proj_weight: gradient norm = 21088.666016 +blocks.2.attention.in_proj_bias: gradient norm = 932.268433 +blocks.2.attention.out_proj.weight: gradient norm = 26868.851562 +blocks.2.attention.out_proj.bias: gradient norm = 1685.688232 +blocks.2.mlp.0.weight: gradient norm = 12992.422852 +blocks.2.mlp.0.bias: gradient norm = 574.340698 +blocks.2.mlp.2.weight: gradient norm = 26733.482422 +blocks.2.mlp.2.bias: gradient norm = 1651.312622 +blocks.2.norm1.weight: gradient norm = 637.058960 +blocks.2.norm1.bias: gradient norm = 656.740479 +blocks.2.norm2.weight: gradient norm = 316.993774 +blocks.2.norm2.bias: gradient norm = 329.310150 +blocks.3.attention.in_proj_weight: gradient norm = 19976.527344 +blocks.3.attention.in_proj_bias: gradient norm = 883.036926 +blocks.3.attention.out_proj.weight: gradient norm = 23130.628906 +blocks.3.attention.out_proj.bias: gradient norm = 1480.547974 +blocks.3.mlp.0.weight: gradient norm = 11456.120117 +blocks.3.mlp.0.bias: gradient norm = 506.405090 +blocks.3.mlp.2.weight: gradient norm = 23175.083984 +blocks.3.mlp.2.bias: gradient norm = 1448.154663 +blocks.3.norm1.weight: gradient norm = 642.471375 +blocks.3.norm1.bias: gradient norm = 617.867859 +blocks.3.norm2.weight: gradient norm = 265.104340 +blocks.3.norm2.bias: gradient norm = 281.729309 +Gradient flow test passed for ImplicitMotionAlignment with dim=512 diff --git a/model.py b/model.py index 45ccf38..5c06618 100644 --- a/model.py +++ b/model.py @@ -16,7 +16,7 @@ # from common import DownConvResBlock,UpConvResBlock import colored_traceback.auto # makes terminal show color coded output when crash from framedecoder import EnhancedFrameDecoder -DEBUG = True +DEBUG = False def debug_print(*args, **kwargs): if DEBUG: print(*args, **kwargs) diff --git a/test_gradient_flow.py b/test_gradient_flow.py index a9e80fb..66e1852 100644 --- a/test_gradient_flow.py +++ b/test_gradient_flow.py @@ -145,8 +145,9 @@ def run_all_gradient_flow_tests(): test_implicit_motion_alignment_modules() # Test FrameDecoder - frame_decoder = FrameDecoder() - test_gradient_flow(frame_decoder, (1, 512, 32, 32)) + # frame_decoder = FrameDecoder() + # # [torch.Size([1, 512, 32, 32]), torch.Size([1, 512, 32, 32]), torch.Size([1, 512, 32, 32]), torch.Size([1, 512, 32, 32])] + # test_gradient_flow(frame_decoder, (1, 512, 32, 32)) if __name__ == "__main__": run_all_gradient_flow_tests() \ No newline at end of file From 00fe54fe648164bdb0e17739c7739af4a7003026 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 16:12:06 +1000 Subject: [PATCH 081/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 10985 -> 10985 bytes config.yaml | 1 + model.py | 5 +++-- train.py | 19 +++++++++++++++++-- vit.py | 19 +++++++++++++++++++ 5 files changed, 40 insertions(+), 4 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 1cdb2c91880296496f05a536bd79439e6da0cc48..38d2052a4ae686b73579dbc440dc9572e1437168 100644 GIT binary patch delta 35 
pcmaDE`ZAPvIWI340}$K}-jVixBQGl>quXX-Mh#6y-pMvvYXQv23TFTS delta 35 pcmaDE`ZAPvIWI340}$8+??`*Mk(ZT`F=VqaqlP9U&tw~|wE(}{3Dy7r diff --git a/config.yaml b/config.yaml index 11be71e..b5a09b6 100644 --- a/config.yaml +++ b/config.yaml @@ -6,6 +6,7 @@ model: use_resnet_feature: False use_mlgffn: False use_enhanced_generator: False + use_skip: False # Training parameters training: initial_video_repeat: 5 diff --git a/model.py b/model.py index 5c06618..085c18e 100644 --- a/model.py +++ b/model.py @@ -397,7 +397,7 @@ def forward(self, token, condition): For each scale, aligns the reference features to the current frame using the ImplicitMotionAlignment module. ''' class IMFModel(nn.Module): - def __init__(self,use_resnet_feature=False,use_mlgffn=False,use_enhanced_generator=False, latent_dim=32, base_channels=64, num_layers=4, noise_level=0.1, style_mix_prob=0.5): + def __init__(self,use_resnet_feature=False,use_mlgffn=False,use_skip=False,use_enhanced_generator=False, latent_dim=32, base_channels=64, num_layers=4, noise_level=0.1, style_mix_prob=0.5): super().__init__() self.encoder_dims = [64, 128, 256, 512] @@ -412,10 +412,11 @@ def __init__(self,use_resnet_feature=False,use_mlgffn=False,use_enhanced_generat FeatureExtractor = ResNetFeatureExtractor if use_resnet_feature else DenseFeatureEncoder self.dense_feature_encoder = FeatureExtractor(output_channels=self.motion_dims) + IMF = ImplicitMotionAlignmentWithSkip if use_skip else ImplicitMotionAlignment self.implicit_motion_alignment = nn.ModuleList() for i in range(num_layers): dim = self.motion_dims[i] - model = ImplicitMotionAlignment( + model = IMF( feature_dim=dim, motion_dim=dim, depth=4, diff --git a/train.py b/train.py index 03747f9..196c894 100644 --- a/train.py +++ b/train.py @@ -51,13 +51,26 @@ def get_noise_magnitude(epoch, max_epochs, initial_magnitude=0.1, final_magnitud """ return max(final_magnitude, initial_magnitude - (initial_magnitude - final_magnitude) * (epoch / max_epochs)) +def get_layer_wise_learning_rates(model): + params = [] + params.append({'params': model.dense_feature_encoder.parameters(), 'lr': 1e-4}) + params.append({'params': model.latent_token_encoder.parameters(), 'lr': 5e-4}) + params.append({'params': model.latent_token_decoder.parameters(), 'lr': 5e-4}) + params.append({'params': model.implicit_motion_alignment.parameters(), 'lr': 2e-4}) + params.append({'params': model.frame_decoder.parameters(), 'lr': 1e-4}) + params.append({'params': model.mapping_network.parameters(), 'lr': 1e-5}) + return params + def train(config, model, discriminator, train_dataloader, val_loader, accelerator): + # layerwise params + layer_wise_params = get_layer_wise_learning_rates(model) + # Generator optimizer optimizer_g = AdamW( - model.parameters(), + layer_wise_params, lr=config.training.learning_rate_g, betas=(config.optimizer.beta1, config.optimizer.beta2), weight_decay=config.training.weight_decay @@ -368,8 +381,10 @@ def main(): num_layers=config.model.num_layers, use_resnet_feature=config.model.use_resnet_feature, use_mlgffn=config.model.use_mlgffn, - use_enhanced_generator=config.model.use_enhanced_generator + use_enhanced_generator=config.model.use_enhanced_generator, + use_skip=config.model.use_skip ) + add_gradient_hooks(model) # discriminator = PatchDiscriminator(ndf=config.discriminator.ndf) Original diff --git a/vit.py b/vit.py index 328d31d..20ca281 100644 --- a/vit.py +++ b/vit.py @@ -55,6 +55,25 @@ def forward(self, x): output = x_reshaped.permute(1, 2, 0).view(B, C, H, W) return output + +class 
ImplicitMotionAlignmentWithSkip(nn.Module): + def __init__(self, feature_dim, motion_dim, depth, num_heads, window_size, mlp_ratio): + super().__init__() + self.cross_attention = CrossAttention(feature_dim, motion_dim) + self.blocks = nn.ModuleList([ + TransformerBlock(motion_dim, num_heads, window_size, mlp_ratio) + for _ in range(depth) + ]) + self.skip_connections = nn.ModuleList([ + nn.Linear(motion_dim, motion_dim) for _ in range(depth) + ]) + + def forward(self, m_c, m_r, f_r): + x = self.cross_attention(m_c, m_r, f_r) + for block, skip in zip(self.blocks, self.skip_connections): + x = block(x) + skip(x) + return x + class ImplicitMotionAlignment(nn.Module): def __init__(self, feature_dim, motion_dim, depth=2, num_heads=8, window_size=4, mlp_ratio=4, use_mlgffn=False): super().__init__() From 23e50c149a4602aa561aa33387b3d75024ecdda3 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 16:13:39 +1000 Subject: [PATCH 082/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model.py | 22 +++++++++++----------- train.py | 35 ++++++++++++++++++++++++++++++----- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/model.py b/model.py index 085c18e..44c0859 100644 --- a/model.py +++ b/model.py @@ -319,30 +319,30 @@ def forward(self, features): reshaped_feat = feat reshaped_features.append(reshaped_feat) - print(f"Reshaped features: {[f.shape for f in reshaped_features]}") + debug_print(f"Reshaped features: {[f.shape for f in reshaped_features]}") x = reshaped_features[-1] # Start with the smallest feature map - print(f" Initial x shape: {x.shape}") + debug_print(f" Initial x shape: {x.shape}") for i in range(len(self.upconv_blocks)): - print(f"\n Processing upconv_block {i+1}") + debug_print(f"\n Processing upconv_block {i+1}") x = self.upconv_blocks[i](x) - print(f" After upconv_block {i+1}: {x.shape}") + debug_print(f" After upconv_block {i+1}: {x.shape}") if i < len(self.feat_blocks): - print(f" Processing feat_block {i+1}") + debug_print(f" Processing feat_block {i+1}") feat_input = reshaped_features[-(i+2)] - print(f" feat_block {i+1} input shape: {feat_input.shape}") + debug_print(f" feat_block {i+1} input shape: {feat_input.shape}") feat = self.feat_blocks[i](feat_input) - print(f" feat_block {i+1} output shape: {feat.shape}") + debug_print(f" feat_block {i+1} output shape: {feat.shape}") - print(f" Concatenating: x {x.shape} and feat {feat.shape}") + debug_print(f" Concatenating: x {x.shape} and feat {feat.shape}") x = torch.cat([x, feat], dim=1) - print(f" After concatenation: {x.shape}") + debug_print(f" After concatenation: {x.shape}") - print("\n Applying final convolution") + debug_print("\n Applying final convolution") x = self.final_conv(x) - print(f" FrameDecoder final output shape: {x.shape}") + debug_print(f" FrameDecoder final output shape: {x.shape}") return x ''' diff --git a/train.py b/train.py index 196c894..0df6a7f 100644 --- a/train.py +++ b/train.py @@ -323,8 +323,9 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato scheduler_d.step(avg_d_loss) # Logging if accelerator.is_main_process: - wandb.log({ - "ema":current_decay, + # Existing logs + log_dict = { + "ema": current_decay, "noise_magnitude": noise_magnitude, "batch_g_loss": g_loss.item(), "batch_d_loss": d_loss.item(), @@ -332,9 +333,33 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato "perceptual_loss": l_v.item(), "gan_loss": g_loss_gan.item(), 
"batch": batch_idx + epoch * len(train_dataloader), - "lr_g": optimizer_g.param_groups[0]['lr'], - "lr_d": optimizer_d.param_groups[0]['lr'] - }) + } + + # Add layer-wise learning rates + for i, param_group in enumerate(optimizer_g.param_groups): + log_dict[f"lr_g_group_{i}"] = param_group['lr'] + log_dict["lr_d"] = optimizer_d.param_groups[0]['lr'] + + # Add gradient norms for each component of the generator + components = [ + 'dense_feature_encoder', + 'latent_token_encoder', + 'latent_token_decoder', + 'implicit_motion_alignment', + 'frame_decoder', + 'mapping_network' + ] + for component in components: + params = getattr(model, component).parameters() + grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in params if p.grad is not None])) + log_dict[f"grad_norm_{component}"] = grad_norm.item() + + # Add gradient norm for the discriminator + disc_grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in discriminator.parameters() if p.grad is not None])) + log_dict["grad_norm_discriminator"] = disc_grad_norm.item() + + # Log to wandb + wandb.log(log_dict) # Log gradient flow for generator and discriminator criterion = [perceptual_loss_fn,pixel_loss_fn] From 92d7f84ddb6218e006a2c8205860dcd72cbedbbf Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 16:14:30 +1000 Subject: [PATCH 083/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model.py b/model.py index 44c0859..67fc08c 100644 --- a/model.py +++ b/model.py @@ -307,7 +307,7 @@ def __init__(self): def forward(self, features): debug_print(f"🎒 FrameDecoder input shapes") for f in features: - print(f"f:{f.shape}") + debug_print(f"f:{f.shape}") # Reshape features reshaped_features = [] for feat in features: From 664150ff8b9376f78f469db9d86b41113bc57cd4 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 16:18:58 +1000 Subject: [PATCH 084/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 0df6a7f..08493e4 100644 --- a/train.py +++ b/train.py @@ -351,12 +351,20 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato ] for component in components: params = getattr(model, component).parameters() - grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in params if p.grad is not None])) - log_dict[f"grad_norm_{component}"] = grad_norm.item() + grad_norms = [torch.norm(p.grad.detach()) for p in params if p.grad is not None] + if grad_norms: + grad_norm = torch.norm(torch.stack(grad_norms)) + log_dict[f"grad_norm_{component}"] = grad_norm.item() + else: + log_dict[f"grad_norm_{component}"] = 0.0 # Add gradient norm for the discriminator - disc_grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in discriminator.parameters() if p.grad is not None])) - log_dict["grad_norm_discriminator"] = disc_grad_norm.item() + disc_grad_norms = [torch.norm(p.grad.detach()) for p in discriminator.parameters() if p.grad is not None] + if disc_grad_norms: + disc_grad_norm = torch.norm(torch.stack(disc_grad_norms)) + log_dict["grad_norm_discriminator"] = disc_grad_norm.item() + else: + log_dict["grad_norm_discriminator"] = 0.0 # Log to wandb wandb.log(log_dict) From b819e8e6c92f0a6469987ddd2048c9c2c9b2aef0 Mon Sep 
17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 16:20:52 +1000 Subject: [PATCH 085/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/train.py b/train.py index 08493e4..92775c4 100644 --- a/train.py +++ b/train.py @@ -336,12 +336,7 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato } # Add layer-wise learning rates - for i, param_group in enumerate(optimizer_g.param_groups): - log_dict[f"lr_g_group_{i}"] = param_group['lr'] - log_dict["lr_d"] = optimizer_d.param_groups[0]['lr'] - - # Add gradient norms for each component of the generator - components = [ + component_names = [ 'dense_feature_encoder', 'latent_token_encoder', 'latent_token_decoder', @@ -349,7 +344,12 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato 'frame_decoder', 'mapping_network' ] - for component in components: + for i, param_group in enumerate(optimizer_g.param_groups): + log_dict[f"lr_g_{component_names[i]}"] = param_group['lr'] + log_dict["lr_d"] = optimizer_d.param_groups[0]['lr'] + + # Add gradient norms for each component of the generator + for component in component_names: params = getattr(model, component).parameters() grad_norms = [torch.norm(p.grad.detach()) for p in params if p.grad is not None] if grad_norms: From 21f043baa3eb3f944f330469bc6752ed096d4199 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 16:37:57 +1000 Subject: [PATCH 086/152] ok --- __pycache__/helper.cpython-311.pyc | Bin 25384 -> 25442 bytes __pycache__/stylegan.cpython-311.pyc | Bin 8276 -> 6627 bytes helper.py | 3 +- stylegan.py | 45 --------------------------- train.py | 12 +++---- 5 files changed, 8 insertions(+), 52 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 20a2d7a0012cc4e888515fba7399011efc0695a2..aca80ef083b03b26e66ab7d2dbd15be2c630ff70 100644 GIT binary patch delta 1133 zcmZ9KZAep57{~X_ZO+X*x2C1;<=#2XRzU@BHnGIYvK9n0oJN^e_G0GZ?E0a|-X)A6 z5Z9 z5H1LULB${2x?5WxToZye=vA6+>j5OZYcEpE7QAOKR38R>Yj+4%a{hbT3}|$uD@Lq# z&G9-38%`_(3lIA)}!nEp3B`_GVrzZ|x-6f?S<>52tcI_IT%EZFUIrzz9WWuI;a zm0UX2u1eLGn-se;<*6+{HVIBgGr1oABO>@YlR=Q4!*WU8LqvM1vhYWu+lh4 zrOpsU91K>#0KN}a>S82(!qU3c8j)w<$=U+!t+Q(9^PozJ)(vIYUvi5_$MVSO<}xcL z%%QLZ_k{eg8mqp40aw}-g<5CT{Ft(vc6li@DA&0P%O6en=n|8N{

)< zXsU$sxVY)4Ho)h*m9eJn0IzX=^K5f6VK&w7B4lD)^FW|QCZ*OglW|P`JYuOti=(nc@ueK_wq&w0*y&wH0=VQ3ck zL7vxfY#pt7*4i@8SD8Up8qKF5EKcU}gC$uD+$EJ)CUa)5LCbK2#9&p?U`vMC%IhYrNB~~-U zVvf~_E^Eeyb>EsC(sL}%Ta)1$(RgQYkN=* zyIo+(*_#$qGGVvP5mGIQRWoV2qvSj`jw7N;B|FhjW5g)Yz&nF7Et^3>Ow+c+_o=;+ zatXRc2OnJ3h67@1h)2XX;&DEgDJ}Qj32ZF1r$qict@GtLYTpZ2m9O@E06S)+<|}Qf z#h`o4vRv%*4C7eZ^Z2rPZTW(>d`?@wXcp(RCbpIC^w@Cy2;)ej1mig1$kz>$NieQ_ za@6bKEtcgzfNA`ZdjdY=;kth6@KClXOH<6IrIdt8F=VL zMae+tIehPQXx7P`$C!fKFpLiihF}V73&+%}%sr#%izd|Y1gH0u>3itZTHA$qy}R7DUL9ymXtRdx>lF2cCt*b zGFi7T97OzmI>i=6^?8D5f`cTb^;~SeaNaF@Bp>doFNV9=S>L6uWbrNq8rlHfVMk-8 z!Jm*x1-t|r^fit+=y&ixt3+d#Xp|DoO`;+AyOO8}sS-Pz1gKSRHd!IA6u|V23&IMV HZSMOE#d`@6 diff --git a/__pycache__/stylegan.cpython-311.pyc b/__pycache__/stylegan.cpython-311.pyc index c76b3cd6881c87f97d36c19be32f108d45bc04a0..241966a3a52ef3472f931d134b04a0c43b1e6d9f 100644 GIT binary patch delta 391 zcmccO@YtAdIWI340}vEN?nu+&n#d=?xNf4l10(Omj2N+HObiUGffxd!xKbE{88rDe zG0O8#mS*If>?n4e%?qT}dorK6GNb)uBk`wkeUdr4TcxdJ7NgD3aNN^D*% z5yNNzR_qVb5db25zj7k>( delta 1704 zcma)6Z)jUp6utf=A~`gEZJVt zR?DmDU}7c#Ew`{LL;TbU6Pfze5Bs2ueQ}Khj1LME-3Py=v;{wj=f2l$YsbX%xW9Yu zIrsd|yXV~Z{&=oFC|_@D6A{Rrmh!#5hmQe1 zNf-Z<&)%X?*)@9;& zL&pXT;gz|`GlfDfs#=z4MAm5a4KbTfrpr^Qd_I#ioQ2s^GTmCtWHc})i`j*Y_B!1w zJz|L(j@`G58lqa5&lgkYFXS=?KQUX(s0Ihl0TZK!-3l9MDN*eOU@OW12Q4XE(~?H- zw|TGy_k)|(mrQ2!*-|q3B8@`OqraEXju#2?6?@hGtHb|`!(Vd*>W;v=BXG~92w}CvTP3l~_%Q)s@(W z;$D$f-&}UBxGXhZQ@ZO)_okf-$lC~@ak0<=QOT{(yo zBjB+hp~ppfkD=yPy^q0Sq5WaK zCkonpN>f!Z7Y_q0p@!`KWcuUj)y|q6uFK&H+cJ2r6siSSy4)2}lgh9eZtlp?kaiqY z?Ip@Uzo|hh(J(bqpx)v;0MN$c>E37)AaQ6auYoX&R~}JG|@ij*t6Co zGab3)=pu}A+crj*DRNio!8x*}+{a_&xAp-%PMq%JHVvG`IC;b8|6X;U!~N#w42_te zFis&w!A)TjAi`NirN03!OPZdjZ64HO)g&40cnycgsB0gI8K0((DRRBD-}V8GnKEN` zIy-P+UyQlajF}=mzFzzxIpxbe^DYEM9NHmD(@TN2$@(EoQ|&y30)>m@cb`Wbr7DFP z@{dnzqgn)OUOGu+(SHJuk{|pb=NuSlyNZV}`O|-=FYqjeB0mnFGAqI+wQoOkj$!gm H*Yp1Z6}^`8 diff --git a/helper.py b/helper.py index 7a0ee32..36ffd64 100644 --- a/helper.py +++ b/helper.py @@ -124,8 +124,9 @@ def log_grad_flow(named_parameters, global_step): # Normalize gradients max_grad = max(grads) if max_grad == 0: - print("☠☠☠ Warning: All gradients are zero. ☠☠☠") + print("👿👿👿 Warning: All gradients are zero. 👿👿👿") normalized_grads = grads # Use unnormalized grads if max is zero + raise ValueError(f"👿👿👿 Warning: All gradients are zero. 
👿👿👿") else: normalized_grads = [g / max_grad for g in grads] diff --git a/stylegan.py b/stylegan.py index 29571e4..e90ddd3 100644 --- a/stylegan.py +++ b/stylegan.py @@ -56,51 +56,6 @@ def __init__(self, in_dim, out_dim): def forward(self, input): return self.linear(input) -class ConvBlock(nn.Module): - def __init__( - self, - in_channel, - out_channel, - kernel_size, - padding, - kernel_size2=None, - padding2=None, - downsample=False, - fused=False, - ): - super().__init__() - - pad1 = padding - pad2 = padding - if padding2 is not None: - pad2 = padding2 - - kernel1 = kernel_size - kernel2 = kernel_size - if kernel_size2 is not None: - kernel2 = kernel_size2 - - self.conv1 = nn.Sequential( - EqualConv2d(in_channel, out_channel, kernel1, padding=pad1), - nn.LeakyReLU(0.2), - ) - - if downsample: - self.conv2 = nn.Sequential( - EqualConv2d(out_channel, out_channel, kernel2, padding=pad2), - nn.AvgPool2d(2), - nn.LeakyReLU(0.2), - ) - else: - self.conv2 = nn.Sequential( - EqualConv2d(out_channel, out_channel, kernel2, padding=pad2), - nn.LeakyReLU(0.2), - ) - - def forward(self, input): - out = self.conv1(input) - out = self.conv2(out) - return out diff --git a/train.py b/train.py index 92775c4..b068238 100644 --- a/train.py +++ b/train.py @@ -53,12 +53,12 @@ def get_noise_magnitude(epoch, max_epochs, initial_magnitude=0.1, final_magnitud def get_layer_wise_learning_rates(model): params = [] - params.append({'params': model.dense_feature_encoder.parameters(), 'lr': 1e-4}) - params.append({'params': model.latent_token_encoder.parameters(), 'lr': 5e-4}) - params.append({'params': model.latent_token_decoder.parameters(), 'lr': 5e-4}) - params.append({'params': model.implicit_motion_alignment.parameters(), 'lr': 2e-4}) - params.append({'params': model.frame_decoder.parameters(), 'lr': 1e-4}) - params.append({'params': model.mapping_network.parameters(), 'lr': 1e-5}) + params.append({'params': model.dense_feature_encoder.parameters(), 'lr': 1e-5}) + params.append({'params': model.latent_token_encoder.parameters(), 'lr': 5e-5}) + params.append({'params': model.latent_token_decoder.parameters(), 'lr': 5e-5}) + params.append({'params': model.implicit_motion_alignment.parameters(), 'lr': 2e-5}) + params.append({'params': model.frame_decoder.parameters(), 'lr': 1e-5}) + params.append({'params': model.mapping_network.parameters(), 'lr': 1e-6}) return params From f17a027c05a2c41cd6f5846c8eaaa998d750abbb Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 16:51:56 +1000 Subject: [PATCH 087/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index b068238..76b22c5 100644 --- a/train.py +++ b/train.py @@ -63,14 +63,16 @@ def get_layer_wise_learning_rates(model): + + def train(config, model, discriminator, train_dataloader, val_loader, accelerator): # layerwise params - layer_wise_params = get_layer_wise_learning_rates(model) + # layer_wise_params = get_layer_wise_learning_rates(model) # Generator optimizer optimizer_g = AdamW( - layer_wise_params, + model.parameters(), lr=config.training.learning_rate_g, betas=(config.optimizer.beta1, config.optimizer.beta2), weight_decay=config.training.weight_decay From 178daac502908bdf9ea6766c8f50bf1e6420eeaf Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 16:56:02 +1000 Subject: [PATCH 088/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 25442 -> 25442 bytes __pycache__/resblock.cpython-311.pyc | Bin 27853 -> 27853 bytes __pycache__/stylegan.cpython-311.pyc | Bin 6627 -> 6627 bytes train.py | 16 ++++++++-------- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index aca80ef083b03b26e66ab7d2dbd15be2c630ff70..8aac144e47eaf695356fc78b036ff973522f0525 100644 GIT binary patch delta 21 bcmaEKjPcPiMy}<&yj%=G5D~wTDu?SLQu+qP delta 21 bcmX?mlkx0LMy}<&yj%=G@WN{&*WnxhRq+Qv diff --git a/__pycache__/stylegan.cpython-311.pyc b/__pycache__/stylegan.cpython-311.pyc index 241966a3a52ef3472f931d134b04a0c43b1e6d9f..1feb2386a31566607a1ec7f9d1bb46185fb718fa 100644 GIT binary patch delta 19 ZcmaEC{MeXlIWI340}w>SZ{)fu2>?3C1!e#M delta 19 ZcmaEC{MeXlIWI340}vENZsfWt2>?4r1$qDg diff --git a/train.py b/train.py index 76b22c5..0a69308 100644 --- a/train.py +++ b/train.py @@ -53,12 +53,12 @@ def get_noise_magnitude(epoch, max_epochs, initial_magnitude=0.1, final_magnitud def get_layer_wise_learning_rates(model): params = [] - params.append({'params': model.dense_feature_encoder.parameters(), 'lr': 1e-5}) - params.append({'params': model.latent_token_encoder.parameters(), 'lr': 5e-5}) - params.append({'params': model.latent_token_decoder.parameters(), 'lr': 5e-5}) - params.append({'params': model.implicit_motion_alignment.parameters(), 'lr': 2e-5}) - params.append({'params': model.frame_decoder.parameters(), 'lr': 1e-5}) - params.append({'params': model.mapping_network.parameters(), 'lr': 1e-6}) + params.append({'params': model.dense_feature_encoder.parameters(), 'lr': 1e-4}) + params.append({'params': model.latent_token_encoder.parameters(), 'lr': 1e-4}) + params.append({'params': model.latent_token_decoder.parameters(), 'lr': 1e-4}) + params.append({'params': model.implicit_motion_alignment.parameters(), 'lr': 1e-4}) + params.append({'params': model.frame_decoder.parameters(), 'lr': 1e-4}) + params.append({'params': model.mapping_network.parameters(), 'lr': 1e-4}) return params @@ -68,11 +68,11 @@ def get_layer_wise_learning_rates(model): def train(config, model, discriminator, train_dataloader, val_loader, accelerator): # layerwise params - # layer_wise_params = get_layer_wise_learning_rates(model) + layer_wise_params = get_layer_wise_learning_rates(model) # Generator optimizer optimizer_g = AdamW( - model.parameters(), + layer_wise_params, lr=config.training.learning_rate_g, betas=(config.optimizer.beta1, config.optimizer.beta2), weight_decay=config.training.weight_decay From 9cde9613170dca17e40bf6c4caae3bf56da5f7f2 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 16:58:48 +1000 Subject: [PATCH 089/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 0a69308..3e4faa0 100644 --- a/train.py +++ b/train.py @@ -59,6 +59,7 @@ def get_layer_wise_learning_rates(model): params.append({'params': model.implicit_motion_alignment.parameters(), 'lr': 1e-4}) params.append({'params': model.frame_decoder.parameters(), 'lr': 1e-4}) params.append({'params': model.mapping_network.parameters(), 'lr': 1e-4}) + params.append({'params': model.noise_injection.parameters(), 'lr': 1e-4}) return params @@ -73,7 +74,6 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato # Generator optimizer 
optimizer_g = AdamW( layer_wise_params, - lr=config.training.learning_rate_g, betas=(config.optimizer.beta1, config.optimizer.beta2), weight_decay=config.training.weight_decay ) From 849137ec7424e8e62d864295c35e0da77b0d5b39 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 17:01:53 +1000 Subject: [PATCH 090/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 3e4faa0..50adaa6 100644 --- a/train.py +++ b/train.py @@ -59,7 +59,7 @@ def get_layer_wise_learning_rates(model): params.append({'params': model.implicit_motion_alignment.parameters(), 'lr': 1e-4}) params.append({'params': model.frame_decoder.parameters(), 'lr': 1e-4}) params.append({'params': model.mapping_network.parameters(), 'lr': 1e-4}) - params.append({'params': model.noise_injection.parameters(), 'lr': 1e-4}) + # params.append({'params': model.noise_injection.parameters(), 'lr': 1e-4}) return params From f0ebc3173387119c716d81a8c478c2293f1a9158 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 17:02:45 +1000 Subject: [PATCH 091/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/train.py b/train.py index 50adaa6..e9929a9 100644 --- a/train.py +++ b/train.py @@ -53,13 +53,12 @@ def get_noise_magnitude(epoch, max_epochs, initial_magnitude=0.1, final_magnitud def get_layer_wise_learning_rates(model): params = [] - params.append({'params': model.dense_feature_encoder.parameters(), 'lr': 1e-4}) - params.append({'params': model.latent_token_encoder.parameters(), 'lr': 1e-4}) - params.append({'params': model.latent_token_decoder.parameters(), 'lr': 1e-4}) - params.append({'params': model.implicit_motion_alignment.parameters(), 'lr': 1e-4}) - params.append({'params': model.frame_decoder.parameters(), 'lr': 1e-4}) - params.append({'params': model.mapping_network.parameters(), 'lr': 1e-4}) - # params.append({'params': model.noise_injection.parameters(), 'lr': 1e-4}) + params.append({'params': model.dense_feature_encoder.parameters(), 'lr': 1e-5}) + params.append({'params': model.latent_token_encoder.parameters(), 'lr': 1e-5}) + params.append({'params': model.latent_token_decoder.parameters(), 'lr': 1e-5}) + params.append({'params': model.implicit_motion_alignment.parameters(), 'lr': 1e-5}) + params.append({'params': model.frame_decoder.parameters(), 'lr': 1e-5}) + params.append({'params': model.mapping_network.parameters(), 'lr': 1e-5}) return params From c6ebb835bcc630e9758f9eb515309002b29e197f Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 17:06:01 +1000 Subject: [PATCH 092/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/train.py b/train.py index e9929a9..a509704 100644 --- a/train.py +++ b/train.py @@ -68,14 +68,10 @@ def get_layer_wise_learning_rates(model): def train(config, model, discriminator, train_dataloader, val_loader, accelerator): # layerwise params - layer_wise_params = get_layer_wise_learning_rates(model) + # layer_wise_params = get_layer_wise_learning_rates(model) # Generator optimizer - optimizer_g = AdamW( - layer_wise_params, - betas=(config.optimizer.beta1, config.optimizer.beta2), - 
weight_decay=config.training.weight_decay - ) + optimizer_g = Adam( layer_wise_params ) # Discriminator optimizer optimizer_d = AdamW( @@ -106,8 +102,6 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato # Use the unified gan_loss_fn gan_loss_type = config.loss.type - # perceptual_loss_fn = VGGPerceptualLoss().to(accelerator.device) - # perceptual_loss_fn = LPIPSPerceptualLoss().to(accelerator.device) perceptual_loss_fn = lpips.LPIPS(net='alex').to(accelerator.device) pixel_loss_fn = nn.L1Loss() From e0549ee5f2e076ca071535ee7d9fcd41ff66610f Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 17:06:28 +1000 Subject: [PATCH 093/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index a509704..69f7d64 100644 --- a/train.py +++ b/train.py @@ -71,7 +71,7 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato # layer_wise_params = get_layer_wise_learning_rates(model) # Generator optimizer - optimizer_g = Adam( layer_wise_params ) + optimizer_g = AdamW( layer_wise_params ) # Discriminator optimizer optimizer_d = AdamW( From 1f3f6dffc5f37008454516f6e83f14355fa22f31 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 17:06:54 +1000 Subject: [PATCH 094/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 69f7d64..b695431 100644 --- a/train.py +++ b/train.py @@ -68,7 +68,7 @@ def get_layer_wise_learning_rates(model): def train(config, model, discriminator, train_dataloader, val_loader, accelerator): # layerwise params - # layer_wise_params = get_layer_wise_learning_rates(model) + layer_wise_params = get_layer_wise_learning_rates(model) # Generator optimizer optimizer_g = AdamW( layer_wise_params ) From 38151e352fd7ea46f279f042d287017d1547a1f5 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 17:11:29 +1000 Subject: [PATCH 095/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index b695431..4eab33e 100644 --- a/train.py +++ b/train.py @@ -62,7 +62,12 @@ def get_layer_wise_learning_rates(model): return params - +def warmup_learning_rate(optimizer, current_step, warmup_steps, base_lr): + if current_step < warmup_steps: + lr = base_lr * current_step / warmup_steps + for param_group in optimizer.param_groups: + param_group['lr'] = lr + return optimizer def train(config, model, discriminator, train_dataloader, val_loader, accelerator): @@ -126,7 +131,8 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato total_g_loss = 0 total_d_loss = 0 - + warmup_steps = 100 + base_lr = 1e-4 current_decay = get_ema_decay(epoch, config.training.num_epochs) if ema: ema.decay = current_decay @@ -134,6 +140,10 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato for batch_idx, batch in enumerate(train_dataloader): # Repeat the current video for the specified number of times for _ in range(int(video_repeat)): + + + optimizer_g = warmup_learning_rate(optimizer_g, global_step, warmup_steps, base_lr) + source_frames = batch['frames'] batch_size, 
num_frames, channels, height, width = source_frames.shape From ac027b1a1897ebb3756f03e83f3bee1ee94ba613 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 17:11:59 +1000 Subject: [PATCH 096/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 4eab33e..a67128d 100644 --- a/train.py +++ b/train.py @@ -132,7 +132,7 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato total_d_loss = 0 warmup_steps = 100 - base_lr = 1e-4 + base_lr = 1e-5 current_decay = get_ema_decay(epoch, config.training.num_epochs) if ema: ema.decay = current_decay From 6e11dc73ed35c99be9949e04d516b814ccb616d8 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 17:15:12 +1000 Subject: [PATCH 097/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index a67128d..7d48bdd 100644 --- a/train.py +++ b/train.py @@ -131,7 +131,7 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato total_g_loss = 0 total_d_loss = 0 - warmup_steps = 100 + warmup_steps = 1000 base_lr = 1e-5 current_decay = get_ema_decay(epoch, config.training.num_epochs) if ema: From 8b1d17597e1ac22551e7d7642741688b755ce2a2 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 17:19:47 +1000 Subject: [PATCH 098/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 7d48bdd..8701188 100644 --- a/train.py +++ b/train.py @@ -73,10 +73,13 @@ def warmup_learning_rate(optimizer, current_step, warmup_steps, base_lr): def train(config, model, discriminator, train_dataloader, val_loader, accelerator): # layerwise params - layer_wise_params = get_layer_wise_learning_rates(model) + # layer_wise_params = get_layer_wise_learning_rates(model) # Generator optimizer - optimizer_g = AdamW( layer_wise_params ) + optimizer_g = AdamW( model.parameters(), + lr=config.training.learning_rate_g, + betas=(config.optimizer.beta1, config.optimizer.beta2), + weight_decay=config.training.weight_decay ) # Discriminator optimizer optimizer_d = AdamW( From 43cad56ad9bd517417eb9fdb52fa24c6f6beada6 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 17:24:35 +1000 Subject: [PATCH 099/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 8701188..3cecdf8 100644 --- a/train.py +++ b/train.py @@ -72,7 +72,7 @@ def warmup_learning_rate(optimizer, current_step, warmup_steps, base_lr): def train(config, model, discriminator, train_dataloader, val_loader, accelerator): - # layerwise params + # layerwise params - # layer_wise_params = get_layer_wise_learning_rates(model) # Generator optimizer @@ -134,8 +134,8 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato total_g_loss = 0 total_d_loss = 0 - warmup_steps = 1000 - base_lr = 1e-5 + # warmup_steps = 1000 + # base_lr = 1e-5 current_decay = get_ema_decay(epoch, config.training.num_epochs) if ema: ema.decay = current_decay @@ -145,7 +145,7 
@@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato for _ in range(int(video_repeat)): - optimizer_g = warmup_learning_rate(optimizer_g, global_step, warmup_steps, base_lr) + # optimizer_g = warmup_learning_rate(optimizer_g, global_step, warmup_steps, base_lr) source_frames = batch['frames'] batch_size, num_frames, channels, height, width = source_frames.shape From 49a034c743178cbec6a36e39be682ca01b843175 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 17:30:25 +1000 Subject: [PATCH 100/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 3 +- train.py | 82 +++++++++-------------------------------------------- 2 files changed, 14 insertions(+), 71 deletions(-) diff --git a/config.yaml b/config.yaml index b5a09b6..3acedac 100644 --- a/config.yaml +++ b/config.yaml @@ -35,7 +35,6 @@ training: r1_gamma: 10 r1_interval: 16 label_smoothing: 0.1 - min_learning_rate_d: 1.0e-6 max_learning_rate_d: 1.0e-3 d_lr_adjust_frequency: 100 # Adjust D learning rate every 100 steps @@ -43,11 +42,11 @@ training: target_d_loss_ratio: 0.6 # Target ratio of D loss to G loss every_xref_frames: 16 use_many_xrefs: False - scales: [1, 0.5, 0.25, 0.125] enable_xformers_memory_efficient_attention: True learning_rate_d: 1.0e-4 weight_decay: 0.01 + lte_learning_rate: 1.0e-5 # Dataset parameters dataset: # celeb-hq torrent https://github.com/johndpope/MegaPortrait-hack/tree/main/junk diff --git a/train.py b/train.py index 3cecdf8..03747f9 100644 --- a/train.py +++ b/train.py @@ -51,35 +51,17 @@ def get_noise_magnitude(epoch, max_epochs, initial_magnitude=0.1, final_magnitud """ return max(final_magnitude, initial_magnitude - (initial_magnitude - final_magnitude) * (epoch / max_epochs)) -def get_layer_wise_learning_rates(model): - params = [] - params.append({'params': model.dense_feature_encoder.parameters(), 'lr': 1e-5}) - params.append({'params': model.latent_token_encoder.parameters(), 'lr': 1e-5}) - params.append({'params': model.latent_token_decoder.parameters(), 'lr': 1e-5}) - params.append({'params': model.implicit_motion_alignment.parameters(), 'lr': 1e-5}) - params.append({'params': model.frame_decoder.parameters(), 'lr': 1e-5}) - params.append({'params': model.mapping_network.parameters(), 'lr': 1e-5}) - return params - - -def warmup_learning_rate(optimizer, current_step, warmup_steps, base_lr): - if current_step < warmup_steps: - lr = base_lr * current_step / warmup_steps - for param_group in optimizer.param_groups: - param_group['lr'] = lr - return optimizer def train(config, model, discriminator, train_dataloader, val_loader, accelerator): - # layerwise params - - # layer_wise_params = get_layer_wise_learning_rates(model) - # Generator optimizer - optimizer_g = AdamW( model.parameters(), + optimizer_g = AdamW( + model.parameters(), lr=config.training.learning_rate_g, betas=(config.optimizer.beta1, config.optimizer.beta2), - weight_decay=config.training.weight_decay ) + weight_decay=config.training.weight_decay + ) # Discriminator optimizer optimizer_d = AdamW( @@ -110,6 +92,8 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato # Use the unified gan_loss_fn gan_loss_type = config.loss.type + # perceptual_loss_fn = VGGPerceptualLoss().to(accelerator.device) + # perceptual_loss_fn = LPIPSPerceptualLoss().to(accelerator.device) perceptual_loss_fn = lpips.LPIPS(net='alex').to(accelerator.device) pixel_loss_fn = nn.L1Loss() @@ -134,8 +118,7 @@ def 
train(config, model, discriminator, train_dataloader, val_loader, accelerato total_g_loss = 0 total_d_loss = 0 - # warmup_steps = 1000 - # base_lr = 1e-5 + current_decay = get_ema_decay(epoch, config.training.num_epochs) if ema: ema.decay = current_decay @@ -143,10 +126,6 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato for batch_idx, batch in enumerate(train_dataloader): # Repeat the current video for the specified number of times for _ in range(int(video_repeat)): - - - # optimizer_g = warmup_learning_rate(optimizer_g, global_step, warmup_steps, base_lr) - source_frames = batch['frames'] batch_size, num_frames, channels, height, width = source_frames.shape @@ -331,9 +310,8 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato scheduler_d.step(avg_d_loss) # Logging if accelerator.is_main_process: - # Existing logs - log_dict = { - "ema": current_decay, + wandb.log({ + "ema":current_decay, "noise_magnitude": noise_magnitude, "batch_g_loss": g_loss.item(), "batch_d_loss": d_loss.item(), @@ -341,41 +319,9 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato "perceptual_loss": l_v.item(), "gan_loss": g_loss_gan.item(), "batch": batch_idx + epoch * len(train_dataloader), - } - - # Add layer-wise learning rates - component_names = [ - 'dense_feature_encoder', - 'latent_token_encoder', - 'latent_token_decoder', - 'implicit_motion_alignment', - 'frame_decoder', - 'mapping_network' - ] - for i, param_group in enumerate(optimizer_g.param_groups): - log_dict[f"lr_g_{component_names[i]}"] = param_group['lr'] - log_dict["lr_d"] = optimizer_d.param_groups[0]['lr'] - - # Add gradient norms for each component of the generator - for component in component_names: - params = getattr(model, component).parameters() - grad_norms = [torch.norm(p.grad.detach()) for p in params if p.grad is not None] - if grad_norms: - grad_norm = torch.norm(torch.stack(grad_norms)) - log_dict[f"grad_norm_{component}"] = grad_norm.item() - else: - log_dict[f"grad_norm_{component}"] = 0.0 - - # Add gradient norm for the discriminator - disc_grad_norms = [torch.norm(p.grad.detach()) for p in discriminator.parameters() if p.grad is not None] - if disc_grad_norms: - disc_grad_norm = torch.norm(torch.stack(disc_grad_norms)) - log_dict["grad_norm_discriminator"] = disc_grad_norm.item() - else: - log_dict["grad_norm_discriminator"] = 0.0 - - # Log to wandb - wandb.log(log_dict) + "lr_g": optimizer_g.param_groups[0]['lr'], + "lr_d": optimizer_d.param_groups[0]['lr'] + }) # Log gradient flow for generator and discriminator criterion = [perceptual_loss_fn,pixel_loss_fn] @@ -422,10 +368,8 @@ def main(): num_layers=config.model.num_layers, use_resnet_feature=config.model.use_resnet_feature, use_mlgffn=config.model.use_mlgffn, - use_enhanced_generator=config.model.use_enhanced_generator, - use_skip=config.model.use_skip + use_enhanced_generator=config.model.use_enhanced_generator ) - add_gradient_hooks(model) # discriminator = PatchDiscriminator(ndf=config.discriminator.ndf) Original From 500bc7ee07afc151dbad0ecc018279a6a50c6e05 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 19:02:28 +1000 Subject: [PATCH 101/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- resblock.py | 52 +++++++++++++++++++++++++++++++++++++++------------- 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/config.yaml b/config.yaml index 3acedac..628e62a 
100644 --- a/config.yaml +++ b/config.yaml @@ -45,7 +45,7 @@ training: scales: [1, 0.5, 0.25, 0.125] enable_xformers_memory_efficient_attention: True learning_rate_d: 1.0e-4 - weight_decay: 0.01 + weight_decay: 1e-5 lte_learning_rate: 1.0e-5 # Dataset parameters dataset: diff --git a/resblock.py b/resblock.py index b0c66e8..c24973f 100644 --- a/resblock.py +++ b/resblock.py @@ -95,24 +95,50 @@ def forward(self, x): class ResBlock(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2 if downsample else 1, padding=1) + self.bn1 = nn.BatchNorm2d(out_channels) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.bn2 = nn.BatchNorm2d(out_channels) + self.relu2 = nn.ReLU(inplace=True) - # Main path - self.conv1 = ConvLayer(in_channels, out_channels, downsample=True) - self.conv2 = ConvLayer(out_channels, out_channels) - - # Skip connection path - self.skip_conv = ConvLayer(in_channels, out_channels, downsample=True) + if downsample or in_channels != out_channels: + self.shortcut = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2 if downsample else 1, padding=0), + nn.BatchNorm2d(out_channels) + ) + else: + self.shortcut = nn.Identity() + + self.downsample = downsample + self.in_channels = in_channels + self.out_channels = out_channels def forward(self, x): - # Main path - main = self.conv1(x) - main = self.conv2(main) + debug_print(f"ResBlock input shape: {x.shape}") + debug_print(f"ResBlock parameters: in_channels={self.in_channels}, out_channels={self.out_channels}, downsample={self.downsample}") + + residual = self.shortcut(x) + debug_print(f"After shortcut: {residual.shape}") - # Skip connection path - skip = self.skip_conv(x) + out = self.conv1(x) + debug_print(f"After conv1: {out.shape}") + out = self.bn1(out) + out = self.relu1(out) + debug_print(f"After bn1 and relu1: {out.shape}") - # Combine paths - return main + skip + out = self.conv2(out) + debug_print(f"After conv2: {out.shape}") + out = self.bn2(out) + debug_print(f"After bn2: {out.shape}") + + out += residual + debug_print(f"After adding residual: {out.shape}") + + out = self.relu2(out) + debug_print(f"ResBlock output shape: {out.shape}") + + return out class ModulatedConv2d(nn.Module): From 9c8f167e566445960143ec67bd07e61156d778d1 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 19:04:04 +1000 Subject: [PATCH 102/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/resblock.cpython-311.pyc | Bin 27853 -> 29855 bytes resblock.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index 3f786ca350c603f38ab0f89ebabe298d70c2f57c..2a47484b4d7a77167afdecd700adb362ce49ac10 100644 GIT binary patch delta 3294 zcma)8drVu`8NbIbu5BL1#$Xo{3?UB_9s#jHND0mISRT-Xkg5&J3Kx6>2K(AMHZ%|i zqE>4mV{4j27TRVC+oF`LQKQUK({<4lby`O!ZN2WCD)GjeN~@+y)f6=L$Nt;*UE>EC zns#;X?|#p7e&_Yw`-g|*>enRqotzv42bXjFgOg>gw{s&!;?lL#ggLZZHRPv6l)AajP@aQoOU0ywhAOud84n@FXr^o|}{<+o~S&*s12i zrp$oMKESR?CISPl+mzB&pyjikTq}VlatkJ~Gnbhb!;3l)N)Vg~9SCIzT)hD6PWCYJ^IFAkuR>o#&WuUAd~3 zFjDN;W-XdIK6_@q?VfP|xo-K$N%=?z-r`9)a7wn)sFlX7bW$^IShQGYcFt8rEwwRA zEnBj#W952#_7g_i#8*b^V@{{9oGqTu{>#npfn#?o$*)4P7d-?Z|bjgQ= zg}Q}0xkr$Pyt2s`HThyDAI9X)hNGscn5k+>t1ii1<^U%5Ot(+A|I5IYlwLXXX8V

nP`q(c6XR_U9GmfNmz&_>?bHY(?Mp2wD7c4}O}z}(%Up0%D%%q&|!}cYS@P@CG7FxA~s%{Z&9tv*M?PzX4zD;4g6+y zt070L3#(VGOE1`&Tv)&=ju*l^U-4$8yxMIx_Gw`vhfi5YJJMEt(&`AS!>X|yjytPf zMZk~@VDIfurx}yhiLknpb7em}R(Mnle%tHs8*+N2z;MVJ91sFxvvaI0We*5c@QERj z2Ad(KyKg{{B+(me8Oy73`iDaqhMX#=+y9aj6np`%*fM6?brQ0m#l+TxMspIDKEE_E zQU|&b2@CK9#ms83YDUSBDmcisUI?e9`~@|MMsJ-kjfZ|E1P01 zjBAyVB~L{y91G>NJrBObkyB3F8V8LYgb=n+x?y4%gy3b^J?FSZ-Xa%loA0>!JiOqJ z@{KXRQRW-L@tEVRhg2VEVom!KVglvk-;Es~TxA(OyUcN8 zP1lLx1lkcUFy7TgPOv-9eUTe3JE_Nv{yuWOA>&J-;Z5~M!D3#c9 z zcnu(k1=Q)J_)kHn*oMYx@;U2l9I~EA!}k%U5q^Mhk^QN0i}r00qY-seORCv@J8Hg_ z3U%>Rhrt<~P#r{g9pM53E)KMJW5*`#k3sw*a(_pYx(I1v$?6P^dc|a!1Kt&_!PY0* z!zIjU6x(y@5_30$t}_UA2+E{hVxPBGYOjL+O+@<)(b$x(OjXV{n0F~vQlYOe4_a8= z9@j*AK-p^zD6a8(f}uo*3YSvsYmHI5^l4Uc11_@^>UMj433BH!>k>dbS3DgM`$D3- zTPcEB^hfBr%pUF8ti1-}zau$&|ENu_~}+1;MbqBIGEC+LxaAwlXB=?L2>L^daSfrNOrH0(=#Pw7vw4P_qAp>&Cn-hS;v z5Wf@g^KQCLfrYNeD3ipQVgVWoJ&=v4&VNEBCs;snu$tYl!j v?p+jLWhVw~hHpRwid=AiWPHFzs%r?;-v5S4Iz76C_T{v*mqb4D7}fs)#@2Dr delta 1946 zcmaJ?eQZ-z6o0pU?`7?PPS)Cu4b}aAjTF-PeM2 z1%-%qb9{`u=|YHtMm_=&<;8%+sL`0HQQXK9;3?`~CgYDejmAIqoZEfC1l#m?&;8wV z?)lw&?!A3$oP2kaEI2)XzFC0JbKc9_y;UD9(CoyieQs|Osm|bWhY-ShA zZA{CpCocBM!a7#H*iIxiQMh$cf+P&>;lk5&PIZc;k~5wAZS+R5H^tRgK(b8_qdl7$EJ zhI`JqKHD^6YaF#Tj@cTMV!!#8ZOL%r%+^s`)tIdc32KAXq-kmqk?WmLZpvCLrVd4vC9_-=^~0_ zPZcb8?Eq4(1$asL#gchA-WMOv@7>eCCn?^T>Mk3#!v6tH=hLy0(gy9D!XP27 z?C}+qWS7>xV%)e)m(&hlMA5~a0c98Uf~#(dh1BktqQn*2ilN(BqNJ60*{za#vR7NW z@)R*QBag5X6x!MCRdu9UD=GEj$2L@coxcsoTM*g-bSWB(1QmYHXfd(@R#w#`y#eM~?PAqUqr4B< z{cK=uUGB?h9zb{nKs5k39MptFl1;8HCLgmUHJ!FzB>E6uL+D2sU?WbPT$e zw5v5w%uM}2T}k{xxl6JrT2bs*ECs zi&x@~T;oqy2WJ;F?_LYi#b7{qF!DPIwkd2Pm6p3 z-2M=>#!V9v+0QDQ{nIkvZhp9SMkX9n3Xcsxvh)-xJB@G#VXj8t;0ztbA;!#GR!wir zx-BcE4}o2%MYm)bNSk(W+xOW{JKQ)&gW`*BRMaPNIc4+;jKC{+3B9=M z6?U-0W1kgf4yb{U8uf+z3XQS9JG8=7W(d&dg<@S^e}^v=Qi3WyiF5F+nJU_~!1K~| zAd9uNJBGySn_!Mtpwka%34~vGD-!xN%ZK0@;Ovj+#{VyFF-thhHiBN5#c7~z(72rs z!JAbbew1{w?rDC{~B!rqSzcB3vn- sMe_hFj%1l{g9#D^b-mUU$s$Eo0`KpCAxZmp|Bn2hvx6ntV0XIlU+1*`sQ>@~ diff --git a/resblock.py b/resblock.py index c24973f..abbf41e 100644 --- a/resblock.py +++ b/resblock.py @@ -93,7 +93,7 @@ def forward(self, x): return self.relu(self.bn(self.conv(x))) class ResBlock(nn.Module): - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels, downsample=False): super().__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2 if downsample else 1, padding=1) self.bn1 = nn.BatchNorm2d(out_channels) From 44e04df85fe849ff0faf317a064b48b6c1ac4bed Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 19:04:11 +1000 Subject: [PATCH 103/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/resblock.cpython-311.pyc | Bin 29855 -> 29829 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index 2a47484b4d7a77167afdecd700adb362ce49ac10..be2e58ad4d9e723f0eef3a3ce4d7a145956258fe 100644 GIT binary patch delta 2274 zcmZvdYiv_h9Kd(_xLLc^jj^(hl|BZtb*_8C0xKOG6bghoVFRNqTU**oR@UBIPCJLt zphRARL*`!|O9&zcaE?Iaejy?lzxaWe%%F*}5#*C*2`UL6{HC7&x$8QX-K4+%&pH3s zIsfzcWQrV}BBtvmld*t(?;g6dXT_H$Pa$zBn+wlqNFDrQ{DU;Zw?$^+R(>wBXl->& zkG_JaXUG;ushpID(m@(di_{I1i!I8_CG8q(JE|;4)FNCAE-mHOO>k(zBIfn%f-f8% z9Be_XVn`8&PHMn*8wUi-`bnE|zwDT%6pgcmk%6!ziSZP5D?7~FG@E^>=R|BkY(((5 zekP6Urb9A0TS&!JEIJg9(@xe*=@=aBQ^PgG(~4_?nu&73V$s;g^{Q=tvc{-9ETzJU zWSmV^nOwY{SYKmKXcwXpeQLPRoJY5{Yw0-g&p*Ea8XqubL`t*VtSpa6bIY6a-B< 
zS+b<{ux5L8+3e1YmbFUYFMAknyi{JwtHR5mRt5GumX&U29`bvKZo}OI_Z=PN2l&@< zj68v(PA8MEJ2#Lg%2Q{x&OfvEl-s_64tTAXqBJ)M_HAL9TUuqu^1EbN4-U_})pp_t zBFIbOtF^%EwQjPLmBX(*sO=)lUflI`5gJq=EQkogt$6F~I%hAFv%0rp;xOHfqG3o% z?Hdxs1EQ)Hf=q*3m_7%$8r&wXrH_gUIZ8WWs=-Q{6`}DashT&{J`B|dlTH1kPibfl z5)y&aEtMAD&-dXHj4A|Tk{G5`hHqM`i)LjHPg^PlHDB=7RVm9?oF`inXi^lD*a=R` zaXiQ}hT-lGGRZaj6bYyxrMb%7g$AMZK#4lx8tHCH4sjU9tX?#Z-1&Igcf;!xC3T6MSm_qUH!8- zltDBhcv;S|)a%G)<@-%~orQ-ft#<@?QD()N+UkwxCsI7$M6$YgER|LR@f_?JW1iW$ z4Eb!bMz#LIXe^;7br`MA!DygdIFI55mv2Rnr~95ST?)pQhoM#CIS49DBzR6NtGfus*geKE@$TBx~G_*K|+{ zCq&u=RlWAyB6RjPk~>PW*Qz0+GSW9`Y`%*|Uo#w7xNZ+SWDdU=PUkNA7J843uqW~> z)Z(i5U}T`DB7|+8r7@O@NvU*Liiqqe{WDNw=Wk>?D!33e2p@vKf*YVcmZ;6AS}YAE zf-}$OXQ;=Qak_vHV~wN}G<#Q(tm4@_rt|*FG^SRg+GCUo2)_x+vExy$mk+V~f%pD~ zYE~S20wuD={3D7O$EBen`|EO5(-3&&l_UZCWIOo|M&+tad=fnM0P5pCO#2Z0&KN`K lBH|L_GU6)YHsTH6pw>DX(4|q{^Xmw{{dQ58g&2w delta 2459 zcmZ`)Z){Ul6yNRI(se6r%tB$MuN`x1_ovVbVb-xZ#tP~-aT^i`u9Wt{I=a_+-|Nsx zz%T&?QQ$BI2_YC1fy6jo0tf_SqKPEx5)w3fh<+kUN)v;Lq!OZGMWkLBpKnv2oV8qMZ$wHZf*ZlZz>`n8wGVcb%wb+fOU^l; zLFH2j^$cXf;8|LL;{XQ;&F&=u_4@2Rx?(ieOp(5@B#CjEdev_08r|xZ$a5jIA_Nh5 z+*M3EtA8Al$(o6Zax9t*$7wSg%6bM)_NBu4;r`sujHyi2qc)pP@EB6InZ=r;iEWY` z9vFWx41M@ zrzo$8Cn8%T`AnVua8q#n@FHwn5GheiX$5nmi|QnZJ8Z-NE~6E0uPcB{#l`yVd7&9N zK}_(YV*zaR7r+6B1$rHoaM3X*n^u^i`HIqxvqQ+n?PR-dD^G~FD~2>DFR$MMB|=3} zHW`))Rowd)!8LDYnkw0*n9YJy!n}+&+n(jEhZ{oJ^AfX7)wUd^g$NZ0yy(2jsiH%N zv!ZmeP_&wHczt2NvxD3M&AEr%g|N%T%AOGDO9JM3gqexYthn>V5AQ z|C8zl42N)#AIF~-w~`Ie?+&UxzInut$!z)f7=AaJJZ4&kY z1fTlLt8=VSE0C2p;$kQo8=wL@NUc+S;J-kgi=&Ien8dDQTuE#CcuZDOfp`wQ_&HYMDad4_nNv09r-di=7G^feK$|JPJ}5>MF&g4| z%%C5l>LWPR<|e1rD{VIo>G~Y*tTm>xgKs;x)=%4zQ)tG^%#Rtb_4_zCJvZ3W^>iWM zBc919obM_#J`VarS1%b;Yq}pA*scpbk#xLwda61n<7ozkhGP;xU-S%y_#EL}u3&7A z6(3NAAvN!VJs`9JjX5SSwy2@afzdx^h77N=eVxvy((CQEZnj!Ci-mLei+lskP@LIAn^LC0sDRw3})F^Ezt zEQ}4fGpQCy$$`-1EBy&lvPuV+V~a>L{25zDCRAttG5yZljByWNe|iU{Ji`0@37q&f zJ Date: Sun, 11 Aug 2024 19:06:57 +1000 Subject: [PATCH 104/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index 628e62a..0940197 100644 --- a/config.yaml +++ b/config.yaml @@ -13,7 +13,7 @@ training: final_video_repeat: 2 use_ema: False use_r1_reg: True - batch_size: 2 # need to redo emodataset to remove the cache npz - smaller numbers won't work... + batch_size: 1 # need to redo emodataset to remove the cache npz - smaller numbers won't work... 
num_epochs: 1000 save_steps: 250 learning_rate_g: 1.0e-4 # Reduced learning rate for generator From 4b5b9316e7891c70f7a5f915c802e984e08575c4 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 19:13:15 +1000 Subject: [PATCH 105/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- resblock.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/resblock.py b/resblock.py index abbf41e..c464e68 100644 --- a/resblock.py +++ b/resblock.py @@ -102,41 +102,27 @@ def __init__(self, in_channels, out_channels, downsample=False): self.bn2 = nn.BatchNorm2d(out_channels) self.relu2 = nn.ReLU(inplace=True) - if downsample or in_channels != out_channels: - self.shortcut = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2 if downsample else 1, padding=0), - nn.BatchNorm2d(out_channels) - ) - else: - self.shortcut = nn.Identity() - + self.skip_conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2 if downsample else 1, bias=False), + nn.BatchNorm2d(out_channels) + ) + self.downsample = downsample self.in_channels = in_channels self.out_channels = out_channels def forward(self, x): - debug_print(f"ResBlock input shape: {x.shape}") - debug_print(f"ResBlock parameters: in_channels={self.in_channels}, out_channels={self.out_channels}, downsample={self.downsample}") - - residual = self.shortcut(x) - debug_print(f"After shortcut: {residual.shape}") + residual = self.skip_conv(x) out = self.conv1(x) - debug_print(f"After conv1: {out.shape}") out = self.bn1(out) out = self.relu1(out) - debug_print(f"After bn1 and relu1: {out.shape}") out = self.conv2(out) - debug_print(f"After conv2: {out.shape}") out = self.bn2(out) - debug_print(f"After bn2: {out.shape}") out += residual - debug_print(f"After adding residual: {out.shape}") - out = self.relu2(out) - debug_print(f"ResBlock output shape: {out.shape}") return out From 26b07c855e495856d2420467139336adbba6947e Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 19:13:37 +1000 Subject: [PATCH 106/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/resblock.cpython-311.pyc | Bin 29829 -> 28750 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index be2e58ad4d9e723f0eef3a3ce4d7a145956258fe..53bb9d00414143551a6500bc3e282b44b5dea1d9 100644 GIT binary patch delta 1142 zcmZ`&UrbY17{6cdy_MGXQg#t^rR%M!SubIz7#B(yjDg@bLUiR}G8!n*t}WWzo?A6I z5{+gHkY;0kF`_XUW|@DV%yFATW6aF5xhx6`586%Fm>5I0ESUJThh^uKut^rr!|!~* z@Avz@bMCq4e7FF&3!t2JI_&}*UmUs;8)>|x?DRk6-=LA&hZ*g4ixUvsRd`4Hbo--q?o*!o4kh>Ir-Hsm3`+VBQgUB8xBDp9&1dq@1#77u*~0M_&!@sIbAS zSmoM|f7)DFsE}}`;t}k^zc#$>A#7jf2Mi7gch$<{O-GxieRHaRUiFWa6!3bPuR1xA zoKnY)oG~UAlovjEhTdf~J`tZZ$M@v+Fcs5lE3xo$+iNaaUlg7Pey6N13eN``30wsj zKu>khdW7$;`>5MPN2%xT~#c>kL$w4 zjcDlDQVyr7Z_7h~b0}?{glXQ3Uf6m>gc+<2-df$qOZqv)(9&=WzNAwP!xFv{EW`5V z&)_D#+$jronPT1{g70>O;Q{@sBPMApwY;=DVP;8J(Z<1xEv=Op z>P*7#^m=E;x?+ijP@7D4F&mDDPFpXr;BSF$>ON@o{o6G!XGtWHA~D|V9EbB9E)?57 z?n%jCGx`$kd2g)+hNv07W8XZ(4ObZ)P+MZzDAC%Zk?cEB^Nroc;BXze#uRq#zu~T; z^%|P-kacNA`cpgv@;#;nsXuv4I`|6<7wWn9K92-o#S$s!iC>qqMQJ(vfZK{c=r^wO z$F`s`=+=RF*{oa0&uk)}FU3P_>eAN}zI?tUYmmWsZ7M}PoT^y^>y0M(4)+@y*`+vZ wtY6I!iSTbG_i)|_!bMcl9{7XSqJ@t35LZOdBsRdk28%GHjfvp=HZzPHl==!R!~>-f2NiiAApt^4=q#^-j{j#SLntVszoly+A46ke-2c*;~|qEN^9 
zEj(C2y-ZB>9Lva-bS=D>ZQPz)G|#VZB{0TCo&8-J*q~W);s; zOYGwAWmZyPbxATQNXQA|1X<@CZOdu00y~@}dkIU~_Qc9vbaZPV?DiYIp-4Pt zj1IaYw8NM%E%6Z-;|kI-nu$7)=Ae7f6$;UTXltUvY7EC?MTFdH^n{OwqOM>hK({6o zTMnTtR5AaFprZ0T;|_<0hZ+&C%;SMjqtO-e7#TVc$MmvnXgz|mf(nDGJaBnD-q2x` z5%qfFt^f`!HH?}nI*Fx>!9Pe=#Gi~=i+sjPNRZu$-M#P+bN6r@xzP*gvp7Y}QEE=J zacbxE9<-pDrkXQUGe zfC9)O9@(6_>O$*SYxjgQt**_eYf}`zee=Yz$&;L>8O=0h&ro)bvggWGZ~D&q#+7G+ zBf%8O>*^=kC%ZUJJ(_8%AwxBAR0Hg*T{WDqGSxnfqW@V%eI`5-{+}wdXt%+$gQL-+ z-PTRiPFgt4IyBRiEkoHj%9f*Ksbgm<9#v|Ya!-3XO-mDc(v%}ZIXKEuXuNpKg0NTz zKdS9$Rf{b1#N$kzR)ntL@rJdnz6O9BUb7jYbWOGR9KO}vV>NpA`8D5>W@bOaS>ceE zKFSPWR}v3JkHl#@K{IZI87FM4bAkalR`(njXJ_j!18G0zu^ELa>psvA7j<>4eBB06 zgEJ1leG!jaZ5An0gOxg74*PBWd>c*RDm!oc9e~f_?S@hCG3#j@C&5klchfDI6H8s# z?1$+M=fOAZ;f;PF_zB*2Tm@OEikM4seHkx$SFGQIv}F2jG+K{opV5bo(Kp zu{4$y2E5T2vz?b=V}OdbDroNXfkpQ9&LiTIEhIZSWLZhn?~NQpzvvL$wc|bUXDF)- zu!}o}#ipg&Sy7B}c|-gQWu|c553#wz7ysmykZ1-$D_grqD*!aRsqY7=_!<^`36-8| z)pg`rNahR7@bq=K$Ssy^oAVz|d&fZZV^q{?oS`d(QM?{AL+GCxhE&vhZDh m3z&vK`}N=+`_OL!RvSRy-M>wW899C*^9#$t0d{M!jQj^{dPBhg From bbd29ebb18da0403be51ded7512cfbe52446c21a Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 19:33:50 +1000 Subject: [PATCH 107/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- resblock.py | 232 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 143 insertions(+), 89 deletions(-) diff --git a/resblock.py b/resblock.py index c464e68..6f854c5 100644 --- a/resblock.py +++ b/resblock.py @@ -1,7 +1,10 @@ import torch import torch.nn as nn import torch.nn.functional as F +import torchvision.transforms as transforms import matplotlib.pyplot as plt +from PIL import Image +import numpy as np DEBUG = False @@ -177,10 +180,9 @@ def forward(self, x, latent): - -def test_upconvresblock(block, input_shape): +def test_upconvresblock(block, input_tensor): print("\nTesting UpConvResBlock") - x = torch.randn(input_shape) + x = input_tensor output = block(x) print(f"Input shape: {x.shape}, Output shape: {output.shape}") assert output.shape[2:] == tuple(2*x for x in x.shape[2:]), "UpConvResBlock should double spatial dimensions" @@ -192,9 +194,9 @@ def test_upconvresblock(block, input_shape): assert param.grad is not None, f"No gradient for {name}" print(f"{name} gradient shape: {param.grad.shape}") -def test_downconvresblock(block, input_shape): +def test_downconvresblock(block, input_tensor): print("\nTesting DownConvResBlock") - x = torch.randn(input_shape) + x = input_tensor output = block(x) print(f"Input shape: {x.shape}, Output shape: {output.shape}") assert output.shape[2:] == tuple(x//2 for x in x.shape[2:]), "DownConvResBlock should halve spatial dimensions" @@ -206,9 +208,9 @@ def test_downconvresblock(block, input_shape): assert param.grad is not None, f"No gradient for {name}" print(f"{name} gradient shape: {param.grad.shape}") -def test_featresblock(block, input_shape): +def test_featresblock(block, input_tensor): print("\nTesting FeatResBlock") - x = torch.randn(input_shape) + x = input_tensor output = block(x) print(f"Input shape: {x.shape}, Output shape: {output.shape}") assert output.shape == x.shape, "FeatResBlock should maintain input shape" @@ -219,10 +221,9 @@ def test_featresblock(block, input_shape): assert param.grad is not None, f"No gradient for {name}" print(f"{name} gradient shape: {param.grad.shape}") -def test_modulatedconv2d(conv, input_shape, style_dim): +def test_modulatedconv2d(conv, input_tensor, style): print("\nTesting ModulatedConv2d") - x = 
torch.randn(input_shape) - style = torch.randn(input_shape[0], style_dim) + x = input_tensor output = conv(x, style) print(f"Input shape: {x.shape}, Style shape: {style.shape}, Output shape: {output.shape}") assert output.shape[1] == conv.weight.shape[0], "Output channels should match conv's out_channels" @@ -232,13 +233,12 @@ def test_modulatedconv2d(conv, input_shape, style_dim): assert conv.weight.grad is not None, "No gradient for weight" print(f"Weight gradient shape: {conv.weight.grad.shape}") -def test_styledconv(conv, input_shape, latent_dim): +def test_styledconv(conv, input_tensor, latent): print("\nTesting StyledConv") - x = torch.randn(input_shape) - latent = torch.randn(input_shape[0], latent_dim) + x = input_tensor output = conv(x, latent) print(f"Input shape: {x.shape}, Latent shape: {latent.shape}, Output shape: {output.shape}") - expected_shape = list(input_shape) + expected_shape = list(x.shape) expected_shape[1] = conv.conv.weight.shape[0] if conv.upsample: expected_shape[2] *= 2 @@ -251,41 +251,9 @@ def test_styledconv(conv, input_shape, latent_dim): assert param.grad is not None, f"No gradient for {name}" print(f"{name} gradient shape: {param.grad.shape}") -def test_resblock(resblock, input_shape): - print("\nTesting ResBlock") - x = torch.randn(input_shape) - output = resblock(x) - print(f"Input shape: {x.shape}, Output shape: {output.shape}") - - # Check output shape - expected_output_shape = list(input_shape) - expected_output_shape[1] = resblock.conv2.conv.out_channels # Use conv2's out_channels - expected_output_shape[2] //= 2 # Always downsample - expected_output_shape[3] //= 2 - assert tuple(output.shape) == tuple(expected_output_shape), f"Expected shape {expected_output_shape}, got {output.shape}" - - # Test gradient flow - output.sum().backward() - for name, param in resblock.named_parameters(): - assert param.grad is not None, f"No gradient for {name}" - print(f"{name} gradient shape: {param.grad.shape}") - - # Test residual connection - resblock.eval() - with torch.no_grad(): - residual_output = resblock(x) - main_path = resblock.conv2(resblock.conv1(x)) - skip_path = resblock.skip_conv(x) - direct_output = main_path + skip_path - assert torch.allclose(residual_output, direct_output, atol=1e-6), "Residual connection not working correctly" - - print("ResBlock test passed successfully!") - - - -def test_block_with_dropout(block, input_shape, block_name): +def test_block_with_dropout(block, input_tensor, block_name): print(f"\nTesting {block_name}") - x = torch.randn(input_shape) + x = input_tensor # Test in training mode block.train() @@ -310,10 +278,15 @@ def test_block_with_dropout(block, input_shape, block_name): print(f"{block_name} test passed successfully!") -def visualize_feature_maps(block, input_shape, num_channels=4, latent_dim=None): - x = torch.randn(input_shape) + +def visualize_feature_maps(block, input_data, num_channels=4, latent_dim=None): + if isinstance(input_data, torch.Tensor): + x = input_data + else: # Assume it's a shape tuple + x = torch.randn(input_data) + if isinstance(block, StyledConv): - latent = torch.randn(input_shape[0], latent_dim) + latent = torch.randn(x.shape[0], latent_dim) output = block(x, latent) else: output = block(x) @@ -342,68 +315,149 @@ def visualize_feature_maps(block, input_shape, num_channels=4, latent_dim=None): output ] titles = ['After Conv1', 'After Conv2', 'Final Output'] + elif isinstance(block, StyledConv): + intermediate_outputs = [output] + titles = ['Output'] else: intermediate_outputs = [output] 
titles = ['Output'] - fig, axs = plt.subplots(len(intermediate_outputs), num_channels, figsize=(20, 5 * len(intermediate_outputs))) - if len(intermediate_outputs) == 1: - axs = [axs] # Make it 2D for consistency + num_outputs = len(intermediate_outputs) + fig, axs = plt.subplots(num_outputs, min(num_channels, output.shape[1]), figsize=(20, 5 * num_outputs)) + if num_outputs == 1 and min(num_channels, output.shape[1]) == 1: + axs = np.array([[axs]]) + elif num_outputs == 1 or min(num_channels, output.shape[1]) == 1: + axs = np.array([axs]) for i, out in enumerate(intermediate_outputs): - for j in range(num_channels): - axs[i][j].imshow(out[0, j].detach().cpu().numpy(), cmap='viridis') - axs[i][j].axis('off') + for j in range(min(num_channels, out.shape[1])): + ax = axs[i, j] if num_outputs > 1 and min(num_channels, output.shape[1]) > 1 else axs[i] + feature_map = out[0, j].detach().cpu().numpy() + + # Normalize the feature map + feature_map = (feature_map - feature_map.min()) / (feature_map.max() - feature_map.min() + 1e-8) + + ax.imshow(feature_map, cmap='viridis') + ax.axis('off') if j == 0: - axs[i][j].set_title(f'{titles[i]}\nChannel {j}') + ax.set_title(f'{titles[i]}\nChannel {j}') else: - axs[i][j].set_title(f'Channel {j}') + ax.set_title(f'Channel {j}') plt.tight_layout() plt.show() +def load_and_preprocess_image(image_path, target_size=(224, 224)): + # Load the image + img = Image.open(image_path).convert('RGB') + + # Define the preprocessing steps + preprocess = transforms.Compose([ + transforms.Resize(target_size), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + # Preprocess the image + img_tensor = preprocess(img) + + # Add batch dimension + img_tensor = img_tensor.unsqueeze(0) + + return img_tensor + +def test_resblock_with_image(resblock, image_tensor): + print("\nTesting ResBlock with Image") + input_shape = image_tensor.shape + print(f"Input shape: {input_shape}") + + # Pass the image through the ResBlock + output = resblock(image_tensor) + print(f"Output shape: {output.shape}") + + # Check output shape + expected_output_shape = list(input_shape) + expected_output_shape[1] = resblock.out_channels + if resblock.downsample: + expected_output_shape[2] //= 2 + expected_output_shape[3] //= 2 + assert tuple(output.shape) == tuple(expected_output_shape), f"Expected shape {expected_output_shape}, got {output.shape}" + + # Test gradient flow + output.sum().backward() + for name, param in resblock.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + print(f"{name} gradient shape: {param.grad.shape}") + + # Visualize input and output + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) + + # Display input image + input_img = image_tensor.squeeze(0).permute(1, 2, 0).cpu().detach().numpy() + input_img = (input_img - input_img.min()) / (input_img.max() - input_img.min()) + ax1.imshow(input_img) + ax1.set_title("Input Image") + ax1.axis('off') + + # Display output feature map (first channel) + output_img = output.squeeze(0)[0].cpu().detach().numpy() + output_img = (output_img - output_img.min()) / (output_img.max() - output_img.min()) + ax2.imshow(output_img, cmap='viridis') + ax2.set_title("Output Feature Map (First Channel)") + ax2.axis('off') + + plt.tight_layout() + plt.show() + + print("ResBlock test with image passed successfully!") if __name__ == "__main__": - # Run all tests - upconv = UpConvResBlock(64, 128) - test_upconvresblock(upconv, (1, 64, 56, 56)) - visualize_feature_maps(upconv, 
(1, 64, 56, 56)) + # Load the image + image_path = "/media/oem/12TB/Downloads/CelebV-HQ/celebvhq/35666/0H5fm71cs4A_11/000000.png" + image_tensor = load_and_preprocess_image(image_path) + + # Run all tests with the image tensor + upconv = UpConvResBlock(3, 64) + test_upconvresblock(upconv, image_tensor) + visualize_feature_maps(upconv, image_tensor) - downconv = DownConvResBlock(128, 256) - test_downconvresblock(downconv, (1, 128, 56, 56)) - visualize_feature_maps(downconv, (1, 128, 56, 56)) + downconv = DownConvResBlock(3, 64) + test_downconvresblock(downconv, image_tensor) + visualize_feature_maps(downconv, image_tensor) - featres = FeatResBlock(256) - test_featresblock(featres, (1, 256, 28, 28)) - visualize_feature_maps(featres, (1, 256, 28, 28)) + featres = FeatResBlock(3) + test_featresblock(featres, image_tensor) + visualize_feature_maps(featres, image_tensor) - modconv = ModulatedConv2d(64, 128, 3) - test_modulatedconv2d(modconv, (1, 64, 56, 56), 64) + modconv = ModulatedConv2d(3, 64, 3) + style = torch.randn(image_tensor.shape[0], 3) + test_modulatedconv2d(modconv, image_tensor, style) - styledconv = StyledConv(64, 128, 3, 32, upsample=True) - test_styledconv(styledconv, (1, 64, 56, 56), 32) - visualize_feature_maps(styledconv, (1, 64, 56, 56), num_channels=4, latent_dim=32) + # Test with dropout + upconv = UpConvResBlock(3, 64) + test_block_with_dropout(upconv, image_tensor, "UpConvResBlock") - resblock = ResBlock(64, 128) - test_resblock(resblock, (1, 64, 56, 56)) - visualize_feature_maps(resblock, (1, 64, 56, 56)) + downconv = DownConvResBlock(3, 64) + test_block_with_dropout(downconv, image_tensor, "DownConvResBlock") - # Usage - resblock = ResBlock(64, 64) - test_resblock(resblock, (1, 64, 56, 56)) + featres = FeatResBlock(3) + test_block_with_dropout(featres, image_tensor, "FeatResBlock") + resblock = ResBlock(3, 64) + test_block_with_dropout(resblock, image_tensor, "ResBlock") - # dropout - upconv = UpConvResBlock(64, 128) - test_block_with_dropout(upconv, (1, 64, 56, 56), "UpConvResBlock") + # Test ResBlock with and without downsampling + resblock = ResBlock(3, 64, downsample=False) + test_resblock_with_image(resblock, image_tensor) - downconv = DownConvResBlock(128, 256) - test_block_with_dropout(downconv, (1, 128, 56, 56), "DownConvResBlock") + resblock_down = ResBlock(3, 64, downsample=True) + test_resblock_with_image(resblock_down, image_tensor) - featres = FeatResBlock(256) - test_block_with_dropout(featres, (1, 256, 28, 28), "FeatResBlock") - resblock = ResBlock(64, 128) - test_block_with_dropout(resblock, (1, 64, 56, 56), "ResBlock") + styledconv = StyledConv(3, 64, 3, 32, upsample=True) + latent = torch.randn(image_tensor.shape[0], 32) + test_styledconv(styledconv, image_tensor, latent) + visualize_feature_maps(styledconv, image_tensor, num_channels=4, latent_dim=32) + \ No newline at end of file From ee710f28923ed043ba2145905743b68ccf9b9305 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 21:00:08 +1000 Subject: [PATCH 108/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- resblock.py | 224 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 173 insertions(+), 51 deletions(-) diff --git a/resblock.py b/resblock.py index 6f854c5..06e36d8 100644 --- a/resblock.py +++ b/resblock.py @@ -12,26 +12,155 @@ def debug_print(*args, **kwargs): if DEBUG: print(*args, **kwargs) +# class UpConvResBlock(nn.Module): +# def __init__(self, in_channels, out_channels): +# 
super().__init__() +# self.upsample = nn.Upsample(scale_factor=2, mode='nearest') +# self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) +# self.bn1 = nn.BatchNorm2d(out_channels) +# self.relu = nn.ReLU(inplace=True) +# self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) +# self.feat_res_block1 = FeatResBlock(out_channels) +# self.feat_res_block2 = FeatResBlock(out_channels) + +# def forward(self, x): +# x = self.upsample(x) +# x = self.conv1(x) +# x = self.bn1(x) +# x = self.relu(x) +# x = self.conv2(x) +# x = self.feat_res_block1(x) +# x = self.feat_res_block2(x) +# return x + + +# class DownConvResBlock(nn.Module): +# def __init__(self, in_channels, out_channels, dropout_rate=0.1): +# super().__init__() +# self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) +# self.bn1 = nn.BatchNorm2d(out_channels) +# self.relu = nn.ReLU(inplace=True) +# self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2) +# self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) +# # self.bn2 = nn.BatchNorm2d(out_channels) # 🤷 works with not without +# # self.dropout = nn.Dropout2d(dropout_rate) # 🤷 +# self.feat_res_block1 = FeatResBlock(out_channels) +# self.feat_res_block2 = FeatResBlock(out_channels) + +# def forward(self, x): +# out = self.conv1(x) +# out = self.bn1(out) +# out = self.relu(out) +# out = self.avgpool(out) +# out = self.conv2(out) +# # out = self.bn2(out) # 🤷 +# # out = self.relu(out) # 🤷 +# # out = self.dropout(out) # 🤷 +# out = self.feat_res_block1(out) +# out = self.feat_res_block2(out) +# return out + + +# class FeatResBlock(nn.Module): +# def __init__(self, channels): +# super().__init__() +# self.bn1 = nn.BatchNorm2d(channels) +# self.relu1 = nn.ReLU(inplace=True) +# self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) +# self.bn2 = nn.BatchNorm2d(channels) +# self.relu2 = nn.ReLU(inplace=True) +# self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) +# self.relu3 = nn.ReLU(inplace=True) + +# def forward(self, x): +# residual = x +# out = self.bn1(x) +# out = self.relu1(out) +# out = self.conv1(out) +# out = self.bn2(out) +# out = self.relu2(out) +# out = self.conv2(out) +# out += residual +# out = self.relu3(out) +# return out + + +# class ConvLayer(nn.Module): +# def __init__(self, in_channels, out_channels, downsample=False): +# super().__init__() +# self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2 if downsample else 1, padding=1) +# self.bn = nn.BatchNorm2d(out_channels) +# self.relu = nn.ReLU(inplace=True) + +# def forward(self, x): +# return self.relu(self.bn(self.conv(x))) + +# class ResBlock(nn.Module): +# def __init__(self, in_channels, out_channels, downsample=False): +# super().__init__() +# self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2 if downsample else 1, padding=1) +# self.bn1 = nn.BatchNorm2d(out_channels) +# self.relu1 = nn.ReLU(inplace=True) +# self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) +# self.bn2 = nn.BatchNorm2d(out_channels) +# self.relu2 = nn.ReLU(inplace=True) + +# self.skip_conv = nn.Sequential( +# nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2 if downsample else 1, bias=False), +# nn.BatchNorm2d(out_channels) +# ) + +# self.downsample = downsample +# self.in_channels = in_channels +# self.out_channels = out_channels + +# def forward(self, x): +# residual = self.skip_conv(x) + 
+# out = self.conv1(x) +# out = self.bn1(out) +# out = self.relu1(out) + +# out = self.conv2(out) +# out = self.bn2(out) + +# out += residual +# out = self.relu2(out) + +# return out + + class UpConvResBlock(nn.Module): - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels, dropout_rate=0.1): super().__init__() self.upsample = nn.Upsample(scale_factor=2, mode='nearest') self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.bn2 = nn.BatchNorm2d(out_channels) + self.dropout = nn.Dropout2d(dropout_rate) + self.feat_res_block1 = FeatResBlock(out_channels) self.feat_res_block2 = FeatResBlock(out_channels) + + self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) def forward(self, x): x = self.upsample(x) - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.conv2(x) - x = self.feat_res_block1(x) - x = self.feat_res_block2(x) - return x + residual = self.residual_conv(x) + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = out + residual + out = self.relu(out) + out = self.dropout(out) + out = self.feat_res_block1(out) + out = self.feat_res_block2(out) + return out class DownConvResBlock(nn.Module): @@ -42,8 +171,8 @@ def __init__(self, in_channels, out_channels, dropout_rate=0.1): self.relu = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - # self.bn2 = nn.BatchNorm2d(out_channels) - # self.dropout = nn.Dropout2d(dropout_rate) + self.bn2 = nn.BatchNorm2d(out_channels) + self.dropout = nn.Dropout2d(dropout_rate) self.feat_res_block1 = FeatResBlock(out_channels) self.feat_res_block2 = FeatResBlock(out_channels) @@ -53,83 +182,76 @@ def forward(self, x): out = self.relu(out) out = self.avgpool(out) out = self.conv2(out) - # out = self.bn2(out) - # out = self.relu(out) - # out = self.dropout(out) + out = self.bn2(out) + out = self.relu(out) + out = self.dropout(out) out = self.feat_res_block1(out) out = self.feat_res_block2(out) return out class FeatResBlock(nn.Module): - def __init__(self, channels): + def __init__(self, channels, dropout_rate=0.1): super().__init__() + self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) self.bn1 = nn.BatchNorm2d(channels) self.relu1 = nn.ReLU(inplace=True) - self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) self.bn2 = nn.BatchNorm2d(channels) + self.dropout = nn.Dropout2d(dropout_rate) self.relu2 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) - self.relu3 = nn.ReLU(inplace=True) def forward(self, x): residual = x - out = self.bn1(x) + out = self.conv1(x) + out = self.bn1(out) out = self.relu1(out) - out = self.conv1(out) - out = self.bn2(out) - out = self.relu2(out) out = self.conv2(out) + out = self.bn2(out) + out = self.dropout(out) out += residual - out = self.relu3(out) + out = self.relu2(out) return out - - -class ConvLayer(nn.Module): - def __init__(self, in_channels, out_channels, downsample=False): - super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2 if downsample else 
1, padding=1) - self.bn = nn.BatchNorm2d(out_channels) - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - return self.relu(self.bn(self.conv(x))) + class ResBlock(nn.Module): - def __init__(self, in_channels, out_channels, downsample=False): + def __init__(self, in_channels, out_channels, downsample=False, dropout_rate=0.1): super().__init__() - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2 if downsample else 1, padding=1) + self.in_channels = in_channels + self.out_channels = out_channels # Add this line + self.downsample = downsample + + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, + stride=2 if downsample else 1, padding=1) self.bn1 = nn.BatchNorm2d(out_channels) - self.relu1 = nn.ReLU(inplace=True) + self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) self.bn2 = nn.BatchNorm2d(out_channels) - self.relu2 = nn.ReLU(inplace=True) + self.dropout = nn.Dropout2d(dropout_rate) - self.skip_conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2 if downsample else 1, bias=False), - nn.BatchNorm2d(out_channels) - ) - - self.downsample = downsample - self.in_channels = in_channels - self.out_channels = out_channels + if downsample or in_channels != out_channels: + self.shortcut = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, + stride=2 if downsample else 1, padding=1), + nn.BatchNorm2d(out_channels) + ) + else: + self.shortcut = nn.Identity() def forward(self, x): - residual = self.skip_conv(x) + residual = self.shortcut(x) out = self.conv1(x) out = self.bn1(out) - out = self.relu1(out) - + out = self.relu(out) out = self.conv2(out) out = self.bn2(out) - + out = self.dropout(out) out += residual - out = self.relu2(out) + out = self.relu(out) return out - class ModulatedConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, demodulate=True): super().__init__() From 7f9f8977b1917b028c1932e56640c897df2cf3e0 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 21:12:07 +1000 Subject: [PATCH 109/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/loss.cpython-311.pyc | Bin 10304 -> 10304 bytes __pycache__/resblock.cpython-311.pyc | Bin 28750 -> 31274 bytes __pycache__/vggloss.cpython-311.pyc | Bin 7076 -> 7076 bytes config.yaml | 4 ++-- 4 files changed, 2 insertions(+), 2 deletions(-) diff --git a/__pycache__/loss.cpython-311.pyc b/__pycache__/loss.cpython-311.pyc index 2649d8ba9f9875c96e05ad040a489eb453058386..17b857ff8dbeef85fdd9b24b2898ca4934fe49ce 100644 GIT binary patch delta 20 acmX>Qa3FwtIWI340}$No*s+n@R09A+QU)mi delta 20 acmX>Qa3FwtIWI340}z<=tl!9OssR8z&jkVi diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index 53bb9d00414143551a6500bc3e282b44b5dea1d9..af52b1be70953f8337452835abb48ce075fe57f4 100644 GIT binary patch literal 31274 zcmeHwYj7J!df)&Uyg?G+n-m}5n-D2cpe5=>i@Zyc5<6!hSyhQaHWeXSFSzRrLL&6Q?=#$ zarb>az+eUdDLU)jWUDe9e$&(abx%LP@9VF>ZoHM1Wnti&|L($-u5%3YKky}S8RL%6 z?y(H>3d1p+#?SOhPtCBlSF4fSx?ycUxQE$ZecZ<|Z0t40-KJr4uZ6m`e(SKU*B1BB z>doeKy*V8FBGYTvGKUyW|02T~eA*#LA_H&g80H=Lm#*Giie&^W(;8TL6w3@)mNl^Q zDV7zmY-?Z@P^>J#%3cGjkYeQkmVFJZBF+INE#`6omT-9hOSyc28@K|1Wn3Y^a;}am zeo@m~!PRr6@T}xEfUnb6#g$#tVlB*Q_?Nlz-fAu5tMpaBse$@|f9dM20pDD%f~$NJ zaNnW$7}8q{{>8ppUya&d!+?LqD*kyH(y0Ud5?|da>16iL=PFWiSr4g{`s!CnC9{7a 
z<_1zy_lE{Zi5Y5)7oYLf47n0+jwwkv8HlH@kBty-gRj9?OP26U4Bs7K78>bE?C7F_sf!sq2K6oPLF? zN{iD><1}ZCQzW$vR}JxM-qb_CgMaDj-30O0(>~Sey!38{cxAp#&_0{hxsUf)2;qvj z>U*_s0$q9s{-q1(ism%^cQt&9UOg^NOn+q9JK#ImAIG5&%wZ1vf3^j{D@>4jN`TJ} z%Jha|Ld+l&(!8PhrUq#iqz(UXYkFMTn126(eP<8F^ke+sND$sxFFz278Lr*LX8?zu z)A>bZ%dn3d^tO!phFjdNXZE!m=|0%P`vQIb(f(_j$8N{0oUddpvp4T(Ic;wG3J6Vj>)6~r)QjA2gra>je6v^bQVB8KWK$sB6e$RRdA zy;O16%AqY|jI4xjNSBts>_k|wGU0-bo?do)BJ?qMGP73YIayw85&b>0dLTpH<|fj| zLFKaae2Rl6W*qT(VT9n2vIY9RexK)xw?8<_$JpUf&i5&Hk592zy{(H`ule{9pWhQ0 zobbgAfgnG~`C`T~FUJjz44ipe6EhBujQPF&K9`0shA6JQm_9H*=Hp`~k7saXFzE5b zv?C)i)7i0rcX-V2iyB~SxMV6`G!=`cQer9_^ zNs2qe>~z+Hf}+{4MV_5IcHb+UI4L}RQ7E`17F;3)mmt*gAg^F{PqbOgYaw|p;9+`@ zUpRXxx?jk5iTN&)?*c!wltgIm;{EHw$y36`OG3eAvEVW(xSTnOqLSHA^q5%WCPi+* zuzgb5GVpCIKY5D^OnH;<#JANyb8 z!hS}H+<%A8_6qQUXd*=Q3~AKEYf=N48HrvRrWP=jUzF2L!_)z$OcrG{aHJ0INyL;1 zT$z zb|uNKr16%rs~5AYqu|v*vKumc<>ZDfbS%BC;ZeSN+km8Y#5?Q*Qfu*efP{|&Mfb#P z9?x~4Z&C=QO8joXk7dOvSHI8i_jn+cZCE!rpq`Ql7&0#&!8b7A6aXMriN6(Xmj0w> zg;^6;Wb5-*n6&^c0sTlnq%U8IA^7D)D*OJ?n_F7#rEZ~%(3_MYTCFCamv6M?itfWzRUtuqB%fx0RdyRa%C2%p!Cv} zLG1ZwBv4Vo$kf9ZKxX+u1VsSO@Eg#HVwX!x1sfixx~^Ql9Q~2S^C;Qz)d(mpt3_8G zf_el<)newoHwI3Oj`~qh@!l911E+%0ZA3gu?wZjh(>xw$Oc|w$lGKha%N1{7lw|;` z%N1XN1c4&Kyk?18{P6 z@gWnYJXoL=%*m8bnW#KAO_|U<7BjQZkp6XTLKHBk z3IK=T8;Y@^mI7QlDh2R65NQ{H09JN_`BSzwQJgk5;;jGQBmUa}fc1Y_ww}tMsZg*Q zpU-{yt7PwdJ7k%E62U$!fr3HrM^6NAV!$x~Dh6%Q?^q{5h3yK$$Xg;lJ)ar^)0!)dSIlmR!O_E!&8G#Te=iUWMRe;PPKLEtQj{rP|L*AMzUbz0TLpXC* zC^#n;oFfJ2lJnz6F|Uc_HGzi-+1~l+8KGdKSg?^4Y@|q($F4ht)ERcBB*T`txSx6I zyi4MJ8p>CoaOqcSN2U|U%=|A!k3}&v^Oa1Pa;H}2;WspW#z`@BOPXVOm_8;S^0d^b z;6+^!NzoH28$trhZwH{Bb4qL{U)9;-Ea&HlJOuz)PF{Z4^mx9LGJi~YO__VDxsu5m zzI7)jeHtoH6xZM+lsqH&8Ri%Zs^TOjebWlF8pyXGCs_+foCK5ahkyA>bd+>XqNjRu z9o3t&D!n;OaudqEgVO9yrh||!Qc6tqvu{Fo{VMng&}GIWcYT=sFrGxNK_dijN$w$2 zS5U0WT~C?^hIBm41V=O+tODF)(^Ta$l?^J>kVUBs)H=x#lrdzgPZjbow3DlqkW(2` z^_65!rI}sk9NH3Xri!~(4rSh;iXp4RLu}f76@>$Og>(S}*4s5&)BF)t3Q0w|T%Ipi z->TGBEkWA6SXR6vuT!NVtJdFy+`khq1FmLyRB>F!IDbUn)Cs05Pr)Qx5#b2{F$)L! 
z3`t$Sj#;LHTqPN&<+ShmxNjsl==H};fvcl@uzx%lGacd3FL>K!kOXS^oBS#KdXlYT z{{XhCW~`3H5R{o}4A25Qpv$ewRF^l@qY}7yxvUYNCCjC?z$tS+NzO)>o3;t{?NM*k zyWqa>g$1|5BbrAH!p|v0#`?)`Lc#seTCld=4^{4TZNLX56UX;c0aJ^&YB}l z(a?Rza{U&eYJ1cXbXP{cYgjRg@X>F?h95vW7ZG#4suZ+xErh76PI5J|&H1V;R zO4=7uAb?a87d@?Bpwt0?5_ZP6(WIZ z5|+p|6O>spup_x!gTSBxF+ry`ok(MSa_0`Qz&7OEhxG84WKjlqlfQy&D6yELRz=3l zC%inY0t8|0RS%1~{a0c3b=SYN8ZU^{^+~UpP3xrcM1fJ4$5lGXdZjznUPZ~nn zDC8^s3h6jahy`_lMlGn}%xXv`mzfs{xfD$o%toLuOd27kOceQmN_{OZ1fDdN~r z)PLsOrPS3rVi?r*qaLZya+=JgfIgxNu{2-$kRc_OCWxyD@oPvHRDl*?PmqP_kPa4L zpkE}7swT}Lb4qWuOj>AfHD}HTPC_1AD=BRTbR{*%@LjMAf#nan@`1fvTJzA+2VE5T z6npS5c1>^jyKmk;u!nlVk_aUuh5>#QtT|j+VDoX+JLZeAHwJw-`AZNXW*#32fCTEB z@JSTpY4p_df$MxQrW+f*8Pf&Ehr#Fq?MgHUV+ORsXys_SE^AWT3n>S0gVl(>4~;`& zmZW(`%y5;GwV3|qAQ!wUeN|5h8&wg}_HFih;>3oZYbGs9$u{ zN3V+x7jd}e4u}qS*znMrGu^SmXw5yE2l*wlT~WhZme(zF?(b*4k+tM>FFM`xF8mXn z?Znx>5EPxAV*Y-Tzdw9nxnM)Y_r0N4hvo|A*>|l=Tb^3n^3;8YxTTY9>AZjPW6j5I zv7(z)bc+Su;e!tyMXzl5_J(DfV`l$s*HV7%Vt#G(q?q48@*6~3!(8e7nRlOF+IC=Z z+kyLDaob_C?eNF7KW|z(+PiqP7wk#IqnF6hOOJHgT+8PSfUsf3#AIc^eDI}%GwvB* zWJoM*oU5O&6Ky+)ZAV!1AUkLJ>df^>{ao$6=7n(~`+%5zfMg#4Jw!#--9gc`A*>0z z!&hc}%h?XPcs?HC?p|A}*t}S=dG5Mc;U*RC`L=~^V)kB=y*I2=h$%X$J7S+alUY`| z06{#0{|Eo+vWi|`x5BIjvhz_^Sqn(A3MStV|MHb!HeXIwv7E*ZOa+yLY7qrl0H*ob z-vjunMHHM#1CcP?)2K-e2os5NtR}p$pfUtfkhE=vI7xPe!imhVRN{{;vZS%Y6amJi z1EWz&76v5i$3O1?+mNI6J&wK4QH4n3Z>Du58 zZ*XvQB(p3NGgA`c;Rc5#`;*5?GKpXawmk?-e0hM+=L3jkdHsU}V2C#gwugKGloeGNgr!CUJ&ZeO zyA2?|Oh8u&V+KSR2~w0ydfQw@8^>vX9X9bZkPy&*AlpxZ*{}J$=7HTYn-%fSdFO|O z6K4hcInjQO*w2OcKD6gke$Xb`tBJijtVhG5BQG79@kRSa29%Lk`RMbGLPHMH#57YaMDjuj7hajeO{Q<+lj7a(ZT8c6d(p7|mg)#FD zZE}Nwa{h&zSd?j};_gqekyhvb8kOb)mZPu{=1s_4QhE0R^z38~Ib-tcFy`b@la%Qd zDVfq!ie?-UYP=|p?X*jc1NFjWo+GI7_A{5ZUxc#ZJSU_dlF9K@7L?Q+V*h}7O{K?f}LvXyDzeijEs#3om4v9<^00l z>^wdmRQOOT&@cB9s0gfORDyoy1klB3D80L6o&>}CkOt-MbIdCm__hVNA-t`n3G}=T zu9$Z77CQ9f8w04Y_l3EAz&GN%HOB9ms1HIOJ>z4PuO#*dw1W-4*B{ukS;lL_L01nS z|DW(T#R#=~;GR1fJsG|J{nOH$64K!`(V{Ll6nJ#p&pEl#@jk!L85jfh zHB!;GXl)aud&f1=+D51w1bw~rcaOes^!Z~i9$R8d7TFS!-9Xq40=r?ksBEcd(_+yk zv8aUQjnHOz*6%Q^Ye6GBcM-VIMJ)B=E8%&8;P4@|*l@`dbOaL?)QSxDm! z=|wc{CZ^p2y;t;!Bx#@f7wY{l)cdQU-X${iPBo$8pGS)E0s>sVS|6bX(UX|~Hl!rf z{EL{zHxT>|040g$!{9N|NT~&`dp(9c?@uMq9YSr#oHu%MPV@aU;7qwC^1LOrUZTO0)9hP|B{qFV! z_d@GJ>$|(*RjlnG04qC4WyhkmLy+#h0nyq)sQb5^+99i+fEQ{zyXkcH&aE;U|?OejmFR3_A~QzkTkOen`(mrQ8vnJAE{iqb+R zNl_*mmHvqVIqZe`9sE0~)B(PrR!j~-mUikhif~S^nuyWq8|0cv zrob7J7>WUCR1XH2-m&W$T+ww5EF6GIbxU%1OyRULZ8E5;8*}EE$HEo=8DwyZhjK;D}&t13Jj6G3<of`a(otXNt#7m5-gYI>tIq~LD$DOejbusB72#z zmj(86k~f07npCvRYe+@wqP10!?yXluYb&AdHODk4kaoj-M6fpC?b<6^8whpNDNTEH z{~KMTb}OmfC0Jla(=M`4681?UaZ3_@#$kE`y8k+O-fw7_Dz!4oA+OfBfIcFl{v_Gu zRYP3Q8QP(3A&$&-cxOn@8EM~kx_4G!h%vfc)&K)%$3Oy=qaSgAA)$!nsNt6M{c$7*^2gR?1P zD$BX5rYEasB2OkiN&Trrc|f;wjH+DKqz8$$w48%82)cYYkrgLGl)F18ia}4}3yd@b zos!}M)HS|-uojTC7BmIIQM=Iaq{l#)aPmAk38VAuF?)M%KmK?-nIh@v() z7O1+qM0(|Wx(;8JQ*04>ffa{d{e2#^(%2@PK3yxl-WCCIPF{*kd zSEi$@r0PK}SGESNFB8;PN-v_qp}A+mEK}!+@-sY4 zqfjLe@{`ZG$*I!BQ?YQO$pt@#v0tNnMS{cu*BKej#R! zztL~3MfQs**_RT#IT*a^Nt6p`^tYe|0{u(bkPOW6R@O3Uwd zz0mc1_lw+PDoCXJx$ot0(<%&42->C`niaIu1n0>L2`DiFpU0~R)Pz? 
z@Z3$`WcY}Nf3Cnvts~aH#O5!u`663L*uo{Ya*?e}GS^?bnjK0V?iDOgNVmvtCG1uq zaR)%)WPY~Gwl7!rK`y(m(D*@#1|a_oc&pAEKw*cnqTjcrit8awP$sA_hn3wa9nPOJ!LUzLYQfa|WmGAuG;E(7|6(qC|W&4pe_l{RXizl63Hw$!{_(oNz>~%J+mD{J+9n zsq?tj8+=z$TZ!DzeEwFKOI_xWVSt@9f+4H+bQ|1B(iN?j!SyAJ#N$ls^* zc4xAF)V-bm6R5@2=n(RovlS`ku!zNXdH7zn|cT!H3+gyT0>^eAgQ+awtv(&60wVd?qm*!WI-Wh z9ncJ<7>Z3=M|LH@8>sx1zFV;RcT?)t8pQ1?@hi3F_rvMFmwIWf*e0`rxT`CErQ|Ma zJvFcXGEIBqWHv1`T|?kqDXFfnJq_2S*p`qrWD8|qQ7An}jOZsn(5rt(%^DvoV3Y3#A$~Aizgn0fyJ$SY5gKk3Q;cAy1%d;92^0~fuzL+tJ{AIsqjApFp&Yhm7m9t==eWI zEQ*RAF@qGH>V`kf2cJ(-t%WCwUKt#~Z6UlD(Tq0+;mpy&00gl8!`?AC{O`&YP;_)k zhjcn8ObM5(4z`B`21f!x??}H-QXd%3NM>snAt`E7OBD?C+&O8xC-wo!IIaO*bqF>g z*pJ{4g3Sn;5p*Fqh5$)>Oy~EFz+M(V?5hrp_l@~SgMpZC7*q_}k+GN_?A5)uV`ksT z_%IX&P7KA}cmr^VD(4G&VJnNSe+*7Fg-|%6SU2pw6*CMDgW2FsIOX`(U;vEPePDDq z7=*nqwjgfb^Z32DL3IIJW&p|M;8Ed$`(aSC!Fhw;m<`fOo>pp+?4POCJ5_sA1vsv< z#qy~MH9a*s=<`VG5I9zuek>5PVloL&9kfME=e-q(X$IrFYKCH3$gCBY)5m$A2bv^+ z!Ut%})t7WGW8sa#031RH&FN9b;Qt1q%xXY~G6t;{qj4Cr`196uU&*Mq=9t8*Ac0mRBv&ZEx*-edpZSA3ptqrxyy|uXwNGV~x1; zDA{>b+|)%jb%`~{NX;>^xSJGri?$wO>w)M?wyH&2RkZG{rq`S1d_Nre!4O0(d9URD zwvTpyu>0fjzkcp7o)a#6#N*G9nJyTsySr1+R<>n67D^o$SxV0huc`$yk9 zdVl<*+aKH(PM%vjd2#XNMd8wA@uY{G^oWO^A%~t3+r6aSD{k&1oBPD-ep1~p+Bjn4 z6bUX>Z(pq5zMy~K`kwXvfseXB=oU_$UpjSh@f0N1w{+>s;-xFXfPZO#UmW0tKu{bQ zCj;Z+r5ohZ4e``Xa_XkI{}$PQOWbjr?6@sfPmt;f(e^B{J)2TN9;l$paAPN+74%xW zXl^9tM%=Tks2}UA*64w^j=X+k&inmiZ$JaOVZBAPwGvxvvi*}tz8_xu!L^0D_nY2p zy6^jF@Pk1(n)he^5B=h<(`46aaq}6n`HWb5meigVi_ekbbE54$v7N`f8Cj|EEKKfG zkhc!FlX=@En(K+V9*z9VmYwyHThPqQxur9uv=6I2KFKW*3O99 z^~ZU07Z=)o(*CjLCmo`zi@3Uka~Fi(r-jRdV(*}Ec1Xy%CgxltIoF;)Jk=S_TQ=s- zco&UDg0bke&Zuqf%02&r``)NnwuhAMxnH$d)+v;AKFGFDUybBN^fNaTCzirqFLQoG zvuG|6%q6d#j+V{cxEEU3d++OFc?T))xbI#p-!GK!mtyK>#-;O7Kgr!7lx-1n-6YqI z=ZJ24gy4<=2Ky5*mcJ8-w9n<;E17rSD-(;_NKu>U*iIbVp{dM(3Lw09xxzKK_tnwJ z=%3Wg`|lt8@E9Cu{NZu2^$ck}Bb>jm*m^-|y)ctEb26ftIWt=_QzBgU3Vr?LG8f*r zTv{ISMjfyj&9zjzaj|rxSlUcVo0m$rES7E&OIt~4>r&~C#nK&O=`K>bE8GS8$+CvI z4L_{>L8a*0MO?eY4Z9YeegEt4eO>H0ff~`%mzPfWEuQWZPjlonCw6ecz%{Ysnpo-& zcRjRa(>)>A=FH(EQcn`FGiASg=%qt52k!Jt_k??3Q^?HTmu`e_EZcKmv3=VXHAJt! 
wAH3`Uc&B6IG4`BK&jyM<&@I=KX7l4d;kCd diff --git a/__pycache__/vggloss.cpython-311.pyc b/__pycache__/vggloss.cpython-311.pyc index 6cd9844011bb4ce774cec8f6e6e9b68d7fcfaf60..c10eb4442faac2a34040b1c2395649aa82c4457a 100644 GIT binary patch delta 20 acmZ2tzQmk+IWI340}$No*s+m&rZfOOVFnfe delta 20 acmZ2tzQmk+IWI340}z<=tl!8zQyKs@-UROe diff --git a/config.yaml b/config.yaml index 0940197..c157cf3 100644 --- a/config.yaml +++ b/config.yaml @@ -6,7 +6,7 @@ model: use_resnet_feature: False use_mlgffn: False use_enhanced_generator: False - use_skip: False + use_skip: True # Training parameters training: initial_video_repeat: 5 @@ -65,7 +65,7 @@ logging: sample_size: 1 # for images on wandb output_dir: "./samples" visualize_every: 100 # Visualize latent tokens every 100 batches - print_model_details: True + print_model_details: False log_every: 100 # Accelerator settings accelerator: From f1f71ea579f14559709c4ab1e66cf2e553818d04 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 21:24:26 +1000 Subject: [PATCH 110/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 73 +++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 59 insertions(+), 14 deletions(-) diff --git a/train.py b/train.py index 03747f9..7020ff5 100644 --- a/train.py +++ b/train.py @@ -51,17 +51,28 @@ def get_noise_magnitude(epoch, max_epochs, initial_magnitude=0.1, final_magnitud """ return max(final_magnitude, initial_magnitude - (initial_magnitude - final_magnitude) * (epoch / max_epochs)) +def get_layer_wise_learning_rates(model): + params = [] + params.append({'params': model.dense_feature_encoder.parameters(), 'lr': 1e-5}) + params.append({'params': model.latent_token_encoder.parameters(), 'lr': 1e-5}) + params.append({'params': model.latent_token_decoder.parameters(), 'lr': 1e-5}) + params.append({'params': model.implicit_motion_alignment.parameters(), 'lr': 1e-5}) + params.append({'params': model.frame_decoder.parameters(), 'lr': 1e-5}) + params.append({'params': model.mapping_network.parameters(), 'lr': 1e-5}) + return params + def train(config, model, discriminator, train_dataloader, val_loader, accelerator): + # layerwise params - + layer_wise_params = get_layer_wise_learning_rates(model) + # Generator optimizer - optimizer_g = AdamW( - model.parameters(), + optimizer_g = AdamW( layer_wise_params, lr=config.training.learning_rate_g, betas=(config.optimizer.beta1, config.optimizer.beta2), - weight_decay=config.training.weight_decay - ) + weight_decay=config.training.weight_decay ) # Discriminator optimizer optimizer_d = AdamW( @@ -92,8 +103,6 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato # Use the unified gan_loss_fn gan_loss_type = config.loss.type - # perceptual_loss_fn = VGGPerceptualLoss().to(accelerator.device) - # perceptual_loss_fn = LPIPSPerceptualLoss().to(accelerator.device) perceptual_loss_fn = lpips.LPIPS(net='alex').to(accelerator.device) pixel_loss_fn = nn.L1Loss() @@ -117,8 +126,7 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato total_g_loss = 0 total_d_loss = 0 - - + current_decay = get_ema_decay(epoch, config.training.num_epochs) if ema: ema.decay = current_decay @@ -126,6 +134,8 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato for batch_idx, batch in enumerate(train_dataloader): # Repeat the current video for the specified number of times for _ in range(int(video_repeat)): + + source_frames = 
batch['frames'] batch_size, num_frames, channels, height, width = source_frames.shape @@ -310,8 +320,9 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato scheduler_d.step(avg_d_loss) # Logging if accelerator.is_main_process: - wandb.log({ - "ema":current_decay, + # Existing logs + log_dict = { + "ema": current_decay, "noise_magnitude": noise_magnitude, "batch_g_loss": g_loss.item(), "batch_d_loss": d_loss.item(), @@ -319,9 +330,41 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato "perceptual_loss": l_v.item(), "gan_loss": g_loss_gan.item(), "batch": batch_idx + epoch * len(train_dataloader), - "lr_g": optimizer_g.param_groups[0]['lr'], - "lr_d": optimizer_d.param_groups[0]['lr'] - }) + } + + # Add layer-wise learning rates + component_names = [ + 'dense_feature_encoder', + 'latent_token_encoder', + 'latent_token_decoder', + 'implicit_motion_alignment', + 'frame_decoder', + 'mapping_network' + ] + for i, param_group in enumerate(optimizer_g.param_groups): + log_dict[f"lr_g_{component_names[i]}"] = param_group['lr'] + log_dict["lr_d"] = optimizer_d.param_groups[0]['lr'] + + # Add gradient norms for each component of the generator + for component in component_names: + params = getattr(model, component).parameters() + grad_norms = [torch.norm(p.grad.detach()) for p in params if p.grad is not None] + if grad_norms: + grad_norm = torch.norm(torch.stack(grad_norms)) + log_dict[f"grad_norm_{component}"] = grad_norm.item() + else: + log_dict[f"grad_norm_{component}"] = 0.0 + + # Add gradient norm for the discriminator + disc_grad_norms = [torch.norm(p.grad.detach()) for p in discriminator.parameters() if p.grad is not None] + if disc_grad_norms: + disc_grad_norm = torch.norm(torch.stack(disc_grad_norms)) + log_dict["grad_norm_discriminator"] = disc_grad_norm.item() + else: + log_dict["grad_norm_discriminator"] = 0.0 + + # Log to wandb + wandb.log(log_dict) # Log gradient flow for generator and discriminator criterion = [perceptual_loss_fn,pixel_loss_fn] @@ -368,8 +411,10 @@ def main(): num_layers=config.model.num_layers, use_resnet_feature=config.model.use_resnet_feature, use_mlgffn=config.model.use_mlgffn, - use_enhanced_generator=config.model.use_enhanced_generator + use_enhanced_generator=config.model.use_enhanced_generator, + use_skip=config.model.use_skip ) + add_gradient_hooks(model) # discriminator = PatchDiscriminator(ndf=config.discriminator.ndf) Original From 8a582ae21900cc71e688380c8f8abd6a655cefb6 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 21:25:25 +1000 Subject: [PATCH 111/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/model.py b/model.py index 67fc08c..bfeda8b 100644 --- a/model.py +++ b/model.py @@ -7,13 +7,12 @@ # from vit_scaled import ImplicitMotionAlignment #- SLOW but reduces memory 2x/3x # from vit_mlgffn import ImplicitMotionAlignment # from vit_xformers import ImplicitMotionAlignment -from vit import ImplicitMotionAlignment +from vit import ImplicitMotionAlignment,ImplicitMotionAlignmentWithSkip from stylegan import EqualConv2d,EqualLinear from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights from resblock import UpConvResBlock,DownConvResBlock,FeatResBlock,StyledConv,ResBlock import math import random -# from common import DownConvResBlock,UpConvResBlock import colored_traceback.auto # makes terminal show color coded 
output when crash from framedecoder import EnhancedFrameDecoder DEBUG = False From 27c0268df43ed8755d3ed10562a29938d0420a04 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 21:27:09 +1000 Subject: [PATCH 112/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vit.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/vit.py b/vit.py index 20ca281..db656c4 100644 --- a/vit.py +++ b/vit.py @@ -57,13 +57,20 @@ def forward(self, x): class ImplicitMotionAlignmentWithSkip(nn.Module): - def __init__(self, feature_dim, motion_dim, depth, num_heads, window_size, mlp_ratio): + def __init__(self, feature_dim, motion_dim, depth, num_heads, window_size, mlp_ratio, use_mlgffn=False): super().__init__() self.cross_attention = CrossAttention(feature_dim, motion_dim) - self.blocks = nn.ModuleList([ - TransformerBlock(motion_dim, num_heads, window_size, mlp_ratio) - for _ in range(depth) - ]) + if use_mlgffn: + self.cross_attention = MLGFFNCrossAttention(feature_dim, motion_dim, num_heads) + self.blocks = nn.ModuleList([ + HSCATB(feature_dim, num_heads, window_size, mlp_ratio) for _ in range(depth) + ]) + else: + self.cross_attention = CrossAttentionModule(dim=feature_dim, heads=num_heads, dim_head=feature_dim // num_heads) + self.blocks = nn.ModuleList([ + TransformerBlock(feature_dim, num_heads, feature_dim // num_heads, feature_dim * mlp_ratio) + for _ in range(depth) + ]) self.skip_connections = nn.ModuleList([ nn.Linear(motion_dim, motion_dim) for _ in range(depth) ]) From 4c32aca044d72f2a743261d7a2e4458a7d03add4 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 21:28:51 +1000 Subject: [PATCH 113/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vit.py b/vit.py index db656c4..43b4d57 100644 --- a/vit.py +++ b/vit.py @@ -59,7 +59,7 @@ def forward(self, x): class ImplicitMotionAlignmentWithSkip(nn.Module): def __init__(self, feature_dim, motion_dim, depth, num_heads, window_size, mlp_ratio, use_mlgffn=False): super().__init__() - self.cross_attention = CrossAttention(feature_dim, motion_dim) + if use_mlgffn: self.cross_attention = MLGFFNCrossAttention(feature_dim, motion_dim, num_heads) self.blocks = nn.ModuleList([ From 69d03ca50d4bbb1d5166a5ff740934f30f49bb14 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 21:29:45 +1000 Subject: [PATCH 114/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index c157cf3..54f0485 100644 --- a/config.yaml +++ b/config.yaml @@ -6,7 +6,7 @@ model: use_resnet_feature: False use_mlgffn: False use_enhanced_generator: False - use_skip: True + use_skip: False # Training parameters training: initial_video_repeat: 5 From c4841ba1b24d5521387491b1c925f34f79381497 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 21:30:29 +1000 Subject: [PATCH 115/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index 54f0485..ab19ad1 100644 --- a/config.yaml +++ b/config.yaml @@ -13,7 +13,7 @@ training: final_video_repeat: 2 use_ema: False 
use_r1_reg: True - batch_size: 1 # need to redo emodataset to remove the cache npz - smaller numbers won't work... + batch_size: 2 # need to redo emodataset to remove the cache npz - smaller numbers won't work... num_epochs: 1000 save_steps: 250 learning_rate_g: 1.0e-4 # Reduced learning rate for generator From 20bc1166bb298b90879c9e21ad544b7da1434ee9 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 21:31:33 +1000 Subject: [PATCH 116/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index ab19ad1..54f0485 100644 --- a/config.yaml +++ b/config.yaml @@ -13,7 +13,7 @@ training: final_video_repeat: 2 use_ema: False use_r1_reg: True - batch_size: 2 # need to redo emodataset to remove the cache npz - smaller numbers won't work... + batch_size: 1 # need to redo emodataset to remove the cache npz - smaller numbers won't work... num_epochs: 1000 save_steps: 250 learning_rate_g: 1.0e-4 # Reduced learning rate for generator From 674c9d26e58e7ee8a3d1bfd357146b7603cd4f24 Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 21:49:44 +1000 Subject: [PATCH 117/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index 7020ff5..d184196 100644 --- a/train.py +++ b/train.py @@ -53,12 +53,12 @@ def get_noise_magnitude(epoch, max_epochs, initial_magnitude=0.1, final_magnitud def get_layer_wise_learning_rates(model): params = [] - params.append({'params': model.dense_feature_encoder.parameters(), 'lr': 1e-5}) - params.append({'params': model.latent_token_encoder.parameters(), 'lr': 1e-5}) - params.append({'params': model.latent_token_decoder.parameters(), 'lr': 1e-5}) - params.append({'params': model.implicit_motion_alignment.parameters(), 'lr': 1e-5}) - params.append({'params': model.frame_decoder.parameters(), 'lr': 1e-5}) - params.append({'params': model.mapping_network.parameters(), 'lr': 1e-5}) + params.append({'params': model.dense_feature_encoder.parameters(), 'lr': 1e-4}) + params.append({'params': model.latent_token_encoder.parameters(), 'lr': 1e-4}) + params.append({'params': model.latent_token_decoder.parameters(), 'lr': 1e-4}) + params.append({'params': model.implicit_motion_alignment.parameters(), 'lr': 1e-4}) + params.append({'params': model.frame_decoder.parameters(), 'lr': 1e-4}) + params.append({'params': model.mapping_network.parameters(), 'lr': 1e-4}) return params From 485e6fdd0f9484698ddfa40740fdfce0a793effe Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 22:03:13 +1000 Subject: [PATCH 118/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- resblock.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/resblock.py b/resblock.py index 06e36d8..54970f5 100644 --- a/resblock.py +++ b/resblock.py @@ -163,6 +163,7 @@ def forward(self, x): return out + class DownConvResBlock(nn.Module): def __init__(self, in_channels, out_channels, dropout_rate=0.1): super().__init__() @@ -171,8 +172,8 @@ def __init__(self, in_channels, out_channels, dropout_rate=0.1): self.relu = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, 
stride=1, padding=1) - self.bn2 = nn.BatchNorm2d(out_channels) - self.dropout = nn.Dropout2d(dropout_rate) + # self.bn2 = nn.BatchNorm2d(out_channels) # 🤷 works with not without + # self.dropout = nn.Dropout2d(dropout_rate) # 🤷 self.feat_res_block1 = FeatResBlock(out_channels) self.feat_res_block2 = FeatResBlock(out_channels) @@ -182,9 +183,9 @@ def forward(self, x): out = self.relu(out) out = self.avgpool(out) out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - out = self.dropout(out) + # out = self.bn2(out) # 🤷 + out = self.relu(out) # 🤷 + # out = self.dropout(out) # 🤷 out = self.feat_res_block1(out) out = self.feat_res_block2(out) return out From 91f23657c728a7c6fc7a4abb997273d36981ba6d Mon Sep 17 00:00:00 2001 From: John Pope Date: Sun, 11 Aug 2024 22:12:13 +1000 Subject: [PATCH 119/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/resblock.cpython-311.pyc | Bin 31274 -> 31001 bytes model.py | 2 ++ 2 files changed, 2 insertions(+) diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index af52b1be70953f8337452835abb48ce075fe57f4..34864511a70a1787effea00817d18036809b9274 100644 GIT binary patch delta 435 zcmZ4Wg>mL5M&9MTyj%=GP_bf1THi+AdVZ!gjFanCV15a|HDm`6r(ibf2s!WI-<3$!$VqtY6s}6m%vl3fnP;O;+UAo17rR zGC5O8Y;u5z-Q;J&A(LylL2Qx9Rw7!H`$RMpn1UIKSb;`pvKJ`>sUlSnp$;P4L4?L+ z6HzbjEkI^*Cy;1hn9L}nG+9T5fATI-8G#3ULj9RtnKPI>^Lp|+m>*15=hL5TDyG7? zeDVer(arf{V$5uZL3%fV+_!m!q$d;O*~y=zH!>aqE7&Qc$ijGg^KAuTX2wsG-zy(r zytR3QiaQe=4xbaR}j z7&F@;kd{p#=WL!X;mO2!X7WR+jf@Au3RXxfvM}D-d`e!JnepS~drAiwZ*HEi?9Rl- zzzEb70@lQ-Hj#yqee-Os^^A=F!MYN3f|=OlKq?Z!Dz51rU}RLCJl$XdlPu%p1(lAGx~0Bv)j$g7`55SsgZ>-ZSxVg rf2@oZn-BXq3o=gG%pUico$ Date: Mon, 12 Aug 2024 11:06:55 +1000 Subject: [PATCH 120/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/loss.cpython-311.pyc | Bin 10304 -> 10304 bytes __pycache__/resblock.cpython-311.pyc | Bin 31001 -> 41500 bytes config.yaml | 6 +- model.py | 108 ++++++------- resblock.py | 218 ++++++++++++++++++++++++++- train.py | 15 +- 6 files changed, 279 insertions(+), 68 deletions(-) diff --git a/__pycache__/loss.cpython-311.pyc b/__pycache__/loss.cpython-311.pyc index 17b857ff8dbeef85fdd9b24b2898ca4934fe49ce..df21e34ead641eaa2066834679512a562f042222 100644 GIT binary patch delta 20 acmX>Qa3FwtIWI340}!l?-MNw5R09A)1qIIl delta 20 acmX>Qa3FwtIWI340}$No*s+n@R09A+QU)mi diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index 34864511a70a1787effea00817d18036809b9274..6015a85bd5424483d3e594d342ce486f1138690a 100644 GIT binary patch delta 8985 zcmcIJZEzFUcDvF_TFLsBE%_5mvau~2+Ze~@BM@Tn!x$U~0|8>7j_ei2ANXS>z?c;> z89EUjcTiYoP*RvDGQ7lnW#*a0X~>(rK1k^c?Q7FrcVF#lURq{4Oeeo;Uh-x#nReQ9 z?^@OehIXd&c2{TbJ@?#m?>+bIJs*1acjWa4it-N)1|0#{nZuu*Ap8XJZ>W;ERLg~L z2k(^klc6#u_Y&E;hAC(A;a$uWTp~K1Od&uQQ^6Dg&JFK1@GfDB0bA-@%Qy!Vh)ING zy^O1~OhNdHePthz1o7waldsNlz~(S+ri3ZIz4ikow83w=tpeISz6xLY2Z|6hU}mU^ zB%r~HI(-1RB)+o}@VUN9)FBm*>DN31Z(D^=74XUPRi*HO@e>2DNoz>rvl(~B+SUP^ zeBZhhHc9F=&%oQUEZiBAsI~$tqQ3_B2jDC3ujNOmwOQ){U*N0umHXCTRpu0|E|Ma& zbbh_k5qi2EEwu=(h@sQR#6nwrO+{$#jBXWryp!p|??k&;gm#h9ho04iQLeiAPI>d6 z+RJAtOdSm6^wm8raE3fUz)wkdJRdJoS5M{ktSA#m6MY0o>%QGV^KE%9e!I=g2lMjL zTS0NW*yG^;b)Q)gRPp<4c@|arE(g{8S6O)$bp}?$e`U+#2alMjAk7iH&}=W{=nNVy zr{$F6FmOv3zm* zbS8{~Q8J3>QH;ZvenN4ZO5dAfAfTWzXbR*cI1@xB?`9&RIQ}2UvJt5{XyGi3Dlb9J 
zk90VA^EMNy=SK%hNCSWKWp`!J8YoH-5JaFjS!UFnwcnhmGb;Yhk&67|OFs|3G!EOzXpmUPvtVM`fwc*?1kt}DDK0y8xa0_F zcshw^)F+hbk`3m-TyhwKA^un)L(Xi3ca-2UU7#)+9Dl)+6Z%~=srkB&S$e8z!rvmGi`s6^zg-5Q?YJyQ`@p?r=4X7I!c*pR7!<{OI4Q6dUZI8$O1;5*FYpA6<1 zQZe~}$qg~Y0m5Tx_vqrP)}fKpfw*et@X(pMxN=8(U0k)_*Y+B`5NZQLZHQ}{+2N64 zpriNu{C&)6@8G!Uz(C)K6KIBfJ%PUAA$Ari$QuZJguomE==f*R(Z+faup{spLDUg( zQpbJ`wQ<_H^Q0Ge46+uee)HdOc}VteP?J`-bNEzu-;j^-6z~yuY3LWI`(GmPIRalG z@Bo3oM&MTn{1pPfM4$(OFA=y0Ag&tsvBQ4$AwnST?B8OnKj7^di0g1;*I?g(kNq`j z{5t?12YVHjc?32A2(g<1#r4;hcblZEAk&bEmwW6yG&FIYo2WRT84~vc($x$O(Y9&kU5~0&JlQIl9E!qo1(T?f9aLb}&P0aCO!TSX_i*6M` zhuMvvzWC9Lcec%KySMSn7r%J%{Ryx7xa%P)l{x+9%ti4Wiy9>0N@}mC~}HFP_&I#~No`w>AsT9q`Pm zX1#N?Tm~rWcS`!50)DVeB4?(#SkrW9h7>K8lBIHqAhqrgh9HV(jtRqp3hJ#>TchOE z_NaG0t60b?{@Rf@J+^4JM@FZ$i$x{ZkIlHQcg>8>7p)hH)<@I`9npT$4KmsFu1tHS#)fc9GkJr3WvG z8#;u8N5l=GBcko~Fuh1;g`2{o;pQpr0-ZZg=SFiy+AYy;fp({~T`^7?Bd0~WNTQ1b zx+tY(fi9S*3!-OY-s|Ut!WZO6q#Gowp-8vL;47#YWzw(%pOtVF>@USp{VR2Fo>VGM)lc>8rwzolIQ)D}yoWRKo^7 zI#4I57!CNbxjPv8wssZoS{>8`kZoDMGT5Zd0<)=bAD1AM&@f7lPBrX7Ek`AA{14@8 zkkie-xi1I5*^V6K9IJym2px3%oqgC>2G3aoemVHiiWKCJI!!7ar;QnqS4ryl;m#bT zHfBWLsM5gEd|hug1hpt-(Kp^V)ax8K_8}LgE8rXQ53^gGvQu#p zwv&Am0kpT+v+yXsL7vVxNGWCa!;4CPfxxE##*Nt5NjMx%f5>&7 zSsvbt92+N?(vd-5z~}N@)&NV~;F5yMTZVwA_yb#{cv2!jNpgWeYM%z2}5iZto>+np+txnOlUb3wh4Rw;CZi!H<92kVjMT0qFi8MvB zr(O>$7OgfYjz*fN%dgd3tr4wm$?9IPmd#tsL~95+4hz&VL+@sAX*PX2U>SUJE@J7 z$FyOsNLNa9r9f9AtbS52hQzrYC!a<{kYHe|Ds}`Fm(?7}Dbg&9k|P*J8jkO<;iMEFd%1)cY}gp8gJAJtgX^kGQAUaX zhvv1AbwgDvj;e8j4rh`moNmhrQ7Pl6Y%FYK%Dj^P1w%cRIZ|;7h8n<2U=1k6mzG6R zKo$wwhStLVq-2)@MtRK;Z9xF7l`G)hq_{HajO5NTGfv50+~=@mqsR+#DPAicys|NjvYleJhy`izb

!C$2d&(&~deFc;iZT_$&@NlZJ`S zM!JW?Q)J5g=v#H-Yhr>$Di^mVOsOvsn)=5ZlekoS*2RW;yKA;Wgz6pizf_YXT8?S` zS77n}Q5f-iFm@I@1YTQb?6(VNy|;^He_XU|kt|!52$Qzp2?7vnKqz-0>WUniE(q^{ z7%bWmBclhW2N&#~dAmonS4sA&1$*thy;ii>Ng;dPf_>w>eWPgKEZH~z0b@0YYU?IX zhfgn-*F+$=#`kD6G#VX^xgabxmkQ=m2!$$YXGYOI>q7S@`{z1^+U-*L%PDQ&v@S-i;YEFqAjyWguK1*K=_E?qP|Viw+Z^T z5F&1wY>E8na{E+!xczebQZa}T3lyL0Yyn6gR114SN0E+vfIx;Tgx#YIe+aAFuMsg$ zt!P*EghCmvP)S0&{OB_Ty?O)b*2}VH{tZ9(d?x^IND9$>_E zrm3RK#Z#eTPuxPPIaw4N|=T7H_)l13*8n-<`RNt+TCOc)6zW@IF3=FRRo_>ap zrNvIx>OR&dw%Lxk!*@I8I`7#Z6h8DmC>HBFq`D5Vc)wJPv>HD7CH4J$>Bg=`ctP-pV*qnrM z!T(V}t+_N6d8J>gR0TBreR~eSw%n8-RD;!3a~X$o5bQI5v%^7Z_@)kfB||aDEW^B# z|Bzx+59*Zs3lvUt2&DJfvwMlm94s`9cGzp4ZU~_DE?@n~OTrhU4HSTLe{oIjelP~Q z8yWw8fJczpZP_+K98NM>jQ+MEeWDrECh!&@*K34auW9_g4t+d_mLnUV)9b8jb`E;s za1S^G!vnq{r`He2p5eh^_QdHyXSbKF!Cf$DVtf5ta7m`E2n@45Cus zhQi7EMj+J)d{Oik2%Yq$4I?x4d(Zf~M!bQO&I*5k1qS<2IWX)*4bA}cNqpWo29rB^ z6zm%KpnGM>J=!h@c{ska1CCXp^WqIC#VPOUz_5pmtG)iv2;`^7t!?s358{qflHa=G1OmQ&2whW-i-waW1pr+a;Y zuAWog5u9a?Yd{Dmy(1xCT<3>|u|QwX0I)^nb8;C*4-EQnGCQt83qR4<8&~xV4*Pvz zcU4%@aqWtz;xxvf%rA}imMqHE(i>rV`0ivS9Men>?baN-|YV?1qyP(yW8A z&4RI7G*(N->WeLqOSQtmk>`qS`f%&}TW`K}>m@O_ZuaD-gC7mv8JZiqH})m>1$X~K z=)ncC@iozLNT9NTkKP1HKCPJLma^Ov&5H(0_@H2Qi-r=(P%^O#okW_qTHkC9pSUtK zIRsW)5o;2SWsQ5y<0C(+eK=-L~VxxN}sh}+j1N`#^Wghk;{s(I| B=G*`P delta 1118 zcmZ`&O>7cT5Z*t3LRb@u6?XamS&%|84PuCgCd9-;W3(yS!+NmN_j;hUB#ZH!uYFjis6XUzAy-DmdFNZaL{}x`-lkZDcMADYparo9Mm#~}DEQqtiX&z}O@ZHU_E!nP0Gt+yBp?V_ znvlrae;xV%D>DXJL>#lslJ`#lZjf>y6$_n#4n>YbDynqz;*2Augc@Y0J$`ei@(jdz za2!YQjJim0%NhB|EYuq7s*dBNbsuRRT{YKW49mJnAL#|zEaImvhMP40kOYRkI*EeV zPSA|Wgcig=Gq-G%W&8$(#yBa1ld|NvptK}R8ij!vX?_u^TY zZ`xem=#Bw`TqxX$wm8TKIux2hbivR>yxZVDy-=5Ti*pHDSpxj=Nf1^zgviDyzhF~- z5A7A$9s)EV$6IQD_vG3_;lf_m#>Bi~S(bIs!6ZOxJEIdZxq2Mpf5^W^>0hF2Go?Yu6e51AgH(fLg6 zexGyJ9@Q6`_V!wj%j#TKJl|pQ0!3b EKM+RUS^xk5 diff --git a/config.yaml b/config.yaml index 54f0485..fa8e567 100644 --- a/config.yaml +++ b/config.yaml @@ -9,6 +9,8 @@ model: use_skip: False # Training parameters training: + + use_multiscale_discriminator: False initial_video_repeat: 5 final_video_repeat: 2 use_ema: False @@ -31,7 +33,7 @@ training: lambda_gp: 10 # Gradient penalty coefficient lambda_mse: 1.0 n_critic: 2 # Number of discriminator updates per generator update - clip_grad_norm: 1.0 # Maximum norm for gradient clipping + clip_grad_norm: 0.5 # Maximum norm for gradient clipping r1_gamma: 10 r1_interval: 16 label_smoothing: 0.1 @@ -84,7 +86,7 @@ optimizer: # Loss function loss: - type: "wasserstein" # Changed to Wasserstein loss for WGAN-GP + type: "vanilla" # Changed to Wasserstein loss for WGAN-GP weights: perceptual: [10, 10, 10, 10, 10] equivariance_shift: 10 diff --git a/model.py b/model.py index 04edc93..4c5d705 100644 --- a/model.py +++ b/model.py @@ -561,62 +561,62 @@ def forward(self, input): Output: The forward method returns a list containing the outputs from both scales. Weight Initialization: A helper function init_weights is provided to initialize the weights of the network, which can be applied using the apply method. 
''' -# class PatchDiscriminator(nn.Module): -# def __init__(self, input_nc=3, ndf=64): -# super(PatchDiscriminator, self).__init__() +class IMFPatchDiscriminator(nn.Module): + def __init__(self, input_nc=3, ndf=64): + super(IMFPatchDiscriminator, self).__init__() -# self.scale1 = nn.Sequential( -# spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1)), -# nn.LeakyReLU(0.2, inplace=True), -# spectral_norm(nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)), -# nn.InstanceNorm2d(ndf * 2), -# nn.LeakyReLU(0.2, inplace=True), -# spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)), -# nn.InstanceNorm2d(ndf * 4), -# nn.LeakyReLU(0.2, inplace=True), -# spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)), -# nn.InstanceNorm2d(ndf * 8), -# nn.LeakyReLU(0.2, inplace=True), -# spectral_norm(nn.Conv2d(ndf * 8, 1, kernel_size=1, stride=1, padding=0)) -# ) - -# self.scale2 = nn.Sequential( -# spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1)), -# nn.LeakyReLU(0.2, inplace=True), -# spectral_norm(nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)), -# nn.InstanceNorm2d(ndf * 2), -# nn.LeakyReLU(0.2, inplace=True), -# spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)), -# nn.InstanceNorm2d(ndf * 4), -# nn.LeakyReLU(0.2, inplace=True), -# spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)), -# nn.InstanceNorm2d(ndf * 8), -# nn.LeakyReLU(0.2, inplace=True), -# spectral_norm(nn.Conv2d(ndf * 8, 1, kernel_size=1, stride=1, padding=0)) -# ) - - -# def forward(self, x): -# debug_print(f"PatchDiscriminator input shape: {x.shape}") - -# # Scale 1 -# output1 = x -# for i, layer in enumerate(self.scale1): -# output1 = layer(output1) -# debug_print(f"Scale 1 - Layer {i} output shape: {output1.shape}") - -# # Scale 2 -# x_downsampled = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) -# debug_print(f"Scale 2 - Downsampled input shape: {x_downsampled.shape}") - -# output2 = x_downsampled -# for i, layer in enumerate(self.scale2): -# output2 = layer(output2) -# debug_print(f"Scale 2 - Layer {i} output shape: {output2.shape}") - -# debug_print(f"PatchDiscriminator final output shapes: {output1.shape}, {output2.shape}") -# return [output1, output2] + self.scale1 = nn.Sequential( + spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1)), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)), + nn.InstanceNorm2d(ndf * 2), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)), + nn.InstanceNorm2d(ndf * 4), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)), + nn.InstanceNorm2d(ndf * 8), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv2d(ndf * 8, 1, kernel_size=1, stride=1, padding=0)) + ) + + self.scale2 = nn.Sequential( + spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1)), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)), + nn.InstanceNorm2d(ndf * 2), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)), + nn.InstanceNorm2d(ndf * 4), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)), + nn.InstanceNorm2d(ndf 
* 8), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv2d(ndf * 8, 1, kernel_size=1, stride=1, padding=0)) + ) + + + def forward(self, x): + debug_print(f"PatchDiscriminator input shape: {x.shape}") + + # Scale 1 + output1 = x + for i, layer in enumerate(self.scale1): + output1 = layer(output1) + debug_print(f"Scale 1 - Layer {i} output shape: {output1.shape}") + + # Scale 2 + x_downsampled = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) + debug_print(f"Scale 2 - Downsampled input shape: {x_downsampled.shape}") + + output2 = x_downsampled + for i, layer in enumerate(self.scale2): + output2 = layer(output2) + debug_print(f"Scale 2 - Layer {i} output shape: {output2.shape}") + + debug_print(f"PatchDiscriminator final output shapes: {output1.shape}, {output2.shape}") + return [output1, output2] class ConvBlock(nn.Module): diff --git a/resblock.py b/resblock.py index 54970f5..3149911 100644 --- a/resblock.py +++ b/resblock.py @@ -534,12 +534,201 @@ def test_resblock_with_image(resblock, image_tensor): print("ResBlock test with image passed successfully!") + +def visualize_resblock_rgb(block, input_tensor): + x = input_tensor + + # Get intermediate outputs + residual = block.shortcut(x) + conv1_out = block.conv1(x) + bn1_out = block.bn1(conv1_out) + relu1_out = block.relu(bn1_out) + conv2_out = block.conv2(relu1_out) + bn2_out = block.bn2(conv2_out) + dropout_out = block.dropout(bn2_out) + + skip_connection = dropout_out + residual + final_output = block.relu(skip_connection) + + outputs = [x, conv1_out, bn1_out, relu1_out, conv2_out, bn2_out, dropout_out, residual, skip_connection, final_output] + titles = ['Input', 'Conv1', 'BN1', 'ReLU1', 'Conv2', 'BN2', 'Dropout', 'Residual', 'Skip Connection', 'Final Output'] + + fig, axs = plt.subplots(len(outputs), 4, figsize=(20, 4 * len(outputs))) + + for i, out in enumerate(outputs): + out_np = out[0].detach().cpu().numpy() + + # Display first three channels (or less if there are fewer channels) + for j in range(min(3, out_np.shape[0])): + channel = out_np[j] + channel = (channel - channel.min()) / (channel.max() - channel.min() + 1e-8) + + # Create RGB image with the channel value in the corresponding color channel + rgb_channel = np.zeros((channel.shape[0], channel.shape[1], 3)) + rgb_channel[:, :, j] = channel + + axs[i, j].imshow(rgb_channel) + axs[i, j].axis('off') + axs[i, j].set_title(f'{titles[i]}\n{"RGB"[j]} Channel') + + # If there are fewer than 3 channels, hide the unused subplots + for j in range(out_np.shape[0], 3): + axs[i, j].axis('off') + + # Display combined representation + if out_np.shape[0] >= 3: + combined = np.stack([ + out_np[0], + out_np[1] if out_np.shape[0] > 1 else np.zeros_like(out_np[0]), + out_np[2] if out_np.shape[0] > 2 else np.zeros_like(out_np[0]) + ], axis=-1) + else: + combined = np.stack([out_np[0]] * 3, axis=-1) + + combined = (combined - combined.min()) / (combined.max() - combined.min() + 1e-8) + axs[i, 3].imshow(combined) + axs[i, 3].axis('off') + axs[i, 3].set_title(f'{titles[i]}\nCombined') + + plt.tight_layout() + plt.show() + +def visualize_block_output(block, input_tensor, block_name, num_channels=4): + print(f"\nVisualizing {block_name}") + print(f"input_tensor: {input_tensor.shape}") + + x = input_tensor + + # Forward pass + output = block(x) + + # If the output is a tuple or list (e.g., for blocks that return multiple tensors) + if isinstance(output, (tuple, list)): + output = output[0] # Visualize only the first tensor + + print(f"Input shape: {x.shape}, Output 
shape: {output.shape}") + + # Visualize input and output + fig, axs = plt.subplots(2, min(num_channels, output.shape[1]), figsize=(15, 8)) + + # Display input channels + for j in range(min(num_channels, x.shape[1])): + ax = axs[0, j] + channel = x[0, j].detach().cpu().numpy() + channel = (channel - channel.min()) / (channel.max() - channel.min() + 1e-8) + ax.imshow(channel, cmap='viridis') + ax.set_title(f"Input Channel {j}") + ax.axis('off') + + # Display output channels + for j in range(min(num_channels, output.shape[1])): + ax = axs[1, j] + channel = output[0, j].detach().cpu().numpy() + channel = (channel - channel.min()) / (channel.max() - channel.min() + 1e-8) + ax.imshow(channel, cmap='viridis') + ax.set_title(f"Output Channel {j}") + ax.axis('off') + + plt.tight_layout() + plt.show() + + print(f"{block_name} visualization complete!") + + +def visualize_block_output_rgb(block, input_tensor, block_name, num_channels=4): + print(f"\nVisualizing {block_name}") + x = input_tensor + + # Forward pass + with torch.no_grad(): + output = block(x) + + # If the output is a tuple or list (e.g., for blocks that return multiple tensors) + if isinstance(output, (tuple, list)): + output = output[0] # Visualize only the first tensor + + print(f"Input shape: {x.shape}, Output shape: {output.shape}") + + # Visualize input and output + fig, axs = plt.subplots(2, min(num_channels, max(x.shape[1], output.shape[1])), figsize=(20, 10)) + + def visualize_channels(tensor, row): + for j in range(min(num_channels, tensor.shape[1])): + ax = axs[row, j] + channel = tensor[0, j].detach().cpu().numpy() + channel = (channel - channel.min()) / (channel.max() - channel.min() + 1e-8) + + # Create RGB image with the channel value in the corresponding color channel + rgb_channel = np.zeros((channel.shape[0], channel.shape[1], 3)) + if j < 3: + rgb_channel[:, :, j] = channel + else: + # For channels beyond the first 3, use grayscale + rgb_channel[:, :, :] = channel[:, :, np.newaxis] + + ax.imshow(rgb_channel) + ax.set_title(f"{'Input' if row == 0 else 'Output'} Channel {j}") + ax.axis('off') + + # Display input channels + visualize_channels(x, 0) + + # Display output channels + visualize_channels(output, 1) + + plt.tight_layout() + plt.show() + + print(f"{block_name} visualization complete!") + +def visualize_latent_token(token, save_path): + """ + Visualize a 1D latent token as a colorful bar. + + Args: + token (torch.Tensor): A 1D tensor representing the latent token. + save_path (str): Path to save the visualization. 
+ """ + # Ensure the token is on CPU and convert to numpy + token_np = token.cpu().detach().numpy() + + # Create a figure and axis + fig, ax = plt.subplots(figsize=(10, 0.5)) + + # Normalize the token values to [0, 1] for colormap + token_normalized = (token_np - token_np.min()) / (token_np.max() - token_np.min()) + + # Create a colorful representation + cmap = plt.get_cmap('viridis') + colors = cmap(token_normalized) + + # Plot the token as a colorful bar + ax.imshow(colors.reshape(1, -1, 4), aspect='auto') + + # Remove axes + ax.set_xticks([]) + ax.set_yticks([]) + + # Add a title + plt.title(f"Latent Token (dim={len(token_np)})") + + # Save the figure + plt.savefig(save_path, bbox_inches='tight', pad_inches=0.1) + plt.close() + if __name__ == "__main__": + # Load the image image_path = "/media/oem/12TB/Downloads/CelebV-HQ/celebvhq/35666/0H5fm71cs4A_11/000000.png" image_tensor = load_and_preprocess_image(image_path) + # Create a ResBlock + resblock = ResBlock(3, 64, downsample=True) + + # Visualize the ResBlock + visualize_resblock(resblock, image_tensor) + # Run all tests with the image tensor upconv = UpConvResBlock(3, 64) test_upconvresblock(upconv, image_tensor) @@ -579,8 +768,27 @@ def test_resblock_with_image(resblock, image_tensor): test_resblock_with_image(resblock_down, image_tensor) - styledconv = StyledConv(3, 64, 3, 32, upsample=True) - latent = torch.randn(image_tensor.shape[0], 32) - test_styledconv(styledconv, image_tensor, latent) - visualize_feature_maps(styledconv, image_tensor, num_channels=4, latent_dim=32) - \ No newline at end of file + +# BROKEN + # styledconv = StyledConv(3, 64, 3, 32, upsample=True) + # latent = torch.randn(image_tensor.shape[0], 32) + # test_styledconv(styledconv, image_tensor, latent) + # visualize_feature_maps(styledconv, image_tensor, num_channels=4, latent_dim=32) + + +# upconv = UpConvResBlock(64, 128) +# test_input = torch.randn(1, 64, 32, 32) +# visualize_block_output(upconv, test_input, "UpConvResBlock") + +# downconv = DownConvResBlock(128, 64) +# test_input = torch.randn(1, 128, 32, 32) +# visualize_block_output(downconv, test_input, "DownConvResBlock") + +# featres = FeatResBlock(64) +# test_input = torch.randn(1, 64, 32, 32) +# visualize_block_output(featres, test_input, "FeatResBlock") + +# styledconv = StyledConv(64, 128, 3, 32) +# test_input = torch.randn(1, 64, 32, 32) +# latent = torch.randn(1, 32) +# visualize_block_output(lambda x: styledconv(x, latent), test_input, "StyledConv") \ No newline at end of file diff --git a/train.py b/train.py index d184196..6ef4a61 100644 --- a/train.py +++ b/train.py @@ -9,7 +9,7 @@ import yaml import os import torch.nn.functional as F -from model import IMFModel, debug_print,MultiScalePatchDiscriminator +from model import IMFModel, debug_print,MultiScalePatchDiscriminator,IMFPatchDiscriminator from VideoDataset import VideoDataset from EMODataset import EMODataset,gpu_padded_collate from torchvision.utils import save_image @@ -19,7 +19,7 @@ import lpips from torch.nn.utils import spectral_norm import torchvision.models as models -from loss import LPIPSPerceptualLoss,VGGPerceptualLoss,wasserstein_loss,hinge_loss,vanilla_gan_loss,gan_loss_fn +from loss import wasserstein_loss,hinge_loss,vanilla_gan_loss,gan_loss_fn from torch.optim.lr_scheduler import ReduceLROnPlateau import random from vggloss import VGGLoss @@ -273,7 +273,7 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato # Clip gradients - torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 
max_norm=1.0) + torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=config.training.clip_grad_norm) accelerator.backward(d_loss) optimizer_d.step() @@ -294,7 +294,7 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato # C. Optimization accelerator.backward(g_loss) # Clip gradients - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.training.clip_grad_norm) optimizer_g.step() @@ -417,9 +417,10 @@ def main(): add_gradient_hooks(model) - # discriminator = PatchDiscriminator(ndf=config.discriminator.ndf) Original - discriminator = MultiScalePatchDiscriminator(input_nc=3, ndf=64, n_layers=3, num_D=3) - + d0 = IMFPatchDiscriminator(ndf=config.discriminator.ndf) + d1 = MultiScalePatchDiscriminator(input_nc=3, ndf=64, n_layers=3, num_D=3) + discriminator = d1 if config.training.use_multiscale_discriminator else do + add_gradient_hooks(discriminator) transform = transforms.Compose([ From af2a39fd5dec236cf8a169f6956de25be3f7d3e8 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 11:07:20 +1000 Subject: [PATCH 121/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 6ef4a61..70595e0 100644 --- a/train.py +++ b/train.py @@ -419,7 +419,7 @@ def main(): d0 = IMFPatchDiscriminator(ndf=config.discriminator.ndf) d1 = MultiScalePatchDiscriminator(input_nc=3, ndf=64, n_layers=3, num_D=3) - discriminator = d1 if config.training.use_multiscale_discriminator else do + discriminator = d1 if config.training.use_multiscale_discriminator else d0 add_gradient_hooks(discriminator) From 26111c40d8029ac19d360f470ec053a16473c334 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 11:42:17 +1000 Subject: [PATCH 122/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- resblock.py | 47 ++++++++--------------------------------------- 2 files changed, 9 insertions(+), 40 deletions(-) diff --git a/config.yaml b/config.yaml index fa8e567..b815a4d 100644 --- a/config.yaml +++ b/config.yaml @@ -33,7 +33,7 @@ training: lambda_gp: 10 # Gradient penalty coefficient lambda_mse: 1.0 n_critic: 2 # Number of discriminator updates per generator update - clip_grad_norm: 0.5 # Maximum norm for gradient clipping + clip_grad_norm: 0.75 # Maximum norm for gradient clipping r1_gamma: 10 r1_interval: 16 label_smoothing: 0.1 diff --git a/resblock.py b/resblock.py index 3149911..5cc66dd 100644 --- a/resblock.py +++ b/resblock.py @@ -12,26 +12,6 @@ def debug_print(*args, **kwargs): if DEBUG: print(*args, **kwargs) -# class UpConvResBlock(nn.Module): -# def __init__(self, in_channels, out_channels): -# super().__init__() -# self.upsample = nn.Upsample(scale_factor=2, mode='nearest') -# self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) -# self.bn1 = nn.BatchNorm2d(out_channels) -# self.relu = nn.ReLU(inplace=True) -# self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) -# self.feat_res_block1 = FeatResBlock(out_channels) -# self.feat_res_block2 = FeatResBlock(out_channels) - -# def forward(self, x): -# x = self.upsample(x) -# x = self.conv1(x) -# x = self.bn1(x) -# x = self.relu(x) -# x = self.conv2(x) -# x = self.feat_res_block1(x) -# x = self.feat_res_block2(x) -# 
return x # class DownConvResBlock(nn.Module): @@ -131,36 +111,25 @@ def debug_print(*args, **kwargs): class UpConvResBlock(nn.Module): - def __init__(self, in_channels, out_channels, dropout_rate=0.1): + def __init__(self, in_channels, out_channels): super().__init__() self.upsample = nn.Upsample(scale_factor=2, mode='nearest') self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.bn2 = nn.BatchNorm2d(out_channels) - self.dropout = nn.Dropout2d(dropout_rate) - self.feat_res_block1 = FeatResBlock(out_channels) self.feat_res_block2 = FeatResBlock(out_channels) - - self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) def forward(self, x): x = self.upsample(x) - residual = self.residual_conv(x) - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - out = self.conv2(out) - out = self.bn2(out) - out = out + residual - out = self.relu(out) - out = self.dropout(out) - out = self.feat_res_block1(out) - out = self.feat_res_block2(out) - return out + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.feat_res_block1(x) + x = self.feat_res_block2(x) + return x From aef90c44f231427bb606030efbe5b89190a73975 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 13:21:16 +1000 Subject: [PATCH 123/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 25442 -> 23843 bytes __pycache__/resblock.cpython-311.pyc | Bin 41500 -> 40926 bytes config.yaml | 23 ++++++++--------- helper.py | 34 ------------------------- resblock.py | 12 ++++++--- sam.py | 36 +++++++++++++++++++++++++++ train.py | 2 +- 7 files changed, 55 insertions(+), 52 deletions(-) create mode 100644 sam.py diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 8aac144e47eaf695356fc78b036ff973522f0525..3c4b811aae70360c07570c13dcb05c9eb06877c9 100644 GIT binary patch delta 678 zcmXYuO=uHQ5XUo(NwhIBM5~qNW0O{76Pke0G?juyO5?#(axm?sS-VXdKQ{8Vf<1^H zdaObiurwk)m|7&I&}>xVrKCvlAbx;?Mi9+a1r>rqF233HvHbSGGygX;?`?Aj9zF;A zn$6a2;47fiGMhsMdoTPb94|pn;gvglShg*|q+7Nx7$!aP5!u1NymGJX;q(MBYd0@W}vRcqT-vt$Ib>Pz;SjJmopS8>zuyw*Fu8M+diA&YOA#@gVP;O0! 
zBVSGMEWvRC5jEm!;-ju$hK{TFL7!9FU87 zCoQia5cU<87B)|QZRj7NgTFu%UnUR?2rdK-uaHb|A&3@)oIsgs#HtyIY2*lAh#@Sp zai70iwTeM+DA=u}@RiS#5ki5h*YWG%xy_;S)AWw_2+k0^Ptby(nvh9vILbeOw!T|U zc}%A(k`hPy{*%^5ZZw>8sYNu&(h%KlhAboyTnYv?S3y}{#{4jvspsM z)*on%@S4NaO1cSu31$rxbX?sHY$?Y&Ld)MFNM@##lYKw~9}+m=k%2~I37)7k@zKB} zV{HL8HRF$YG~ppbR$6)-KA)VESg`G%QoryH;W6Tg`5L(v&>2!tKqf%kX#*v`b1o&Bu8 z7{#c8nY5!U#UCXY#k;UY3i&ieXf+aZX+MSRipx60(r^p6+r1jD*q{B_kZ>jJTDo`n*ChY< z$lzR-`1as;%#lB6$lJB)aLH+d=ZwKKXz-jj!RLn@RVdGYz;xJt+E9PSP(Ns>|Hx1@ mXefH%+QZXM8ye0S8U_swXER-i$%h_h>3*ieQGKEKo#Ovk{ONcA diff --git a/config.yaml b/config.yaml index b815a4d..367ed8c 100644 --- a/config.yaml +++ b/config.yaml @@ -1,3 +1,7 @@ +# Loss function +loss: + type: "vanilla" # Changed to Wasserstein loss for WGAN-GP + # Model parameters model: latent_dim: 32 @@ -5,11 +9,14 @@ model: num_layers: 4 use_resnet_feature: False use_mlgffn: False - use_enhanced_generator: False + use_enhanced_generator: True use_skip: False # Training parameters training: - + ema_decay: 0.999 + style_mixing_prob: 0.9 + initial_noise_magnitude: 0.01 + final_noise_magnitude: 0.0001 use_multiscale_discriminator: False initial_video_repeat: 5 final_video_repeat: 2 @@ -22,10 +29,7 @@ training: initial_learning_rate_d: 1.0e-4 # Set a lower initial learning rate for discriminator # learning_rate_g: 5.0e-4 # Increased learning rate for generator # learning_rate_d: 5.0e-4 # Increased learning rate for discriminator - ema_decay: 0.999 - style_mixing_prob: 0.9 - initial_noise_magnitude: 0.01 - final_noise_magnitude: 0.001 + gradient_accumulation_steps: 1 lambda_pixel: 10 # in paper lambda-pixel = 10 Adjust this value as needed lambda_perceptual: 10 # lambda perceptual = 10 @@ -84,12 +88,5 @@ optimizer: beta1: 0.5 beta2: 0.999 -# Loss function -loss: - type: "vanilla" # Changed to Wasserstein loss for WGAN-GP - weights: - perceptual: [10, 10, 10, 10, 10] - equivariance_shift: 10 - equivariance_affine: 10 diff --git a/helper.py b/helper.py index 36ffd64..5b6f8a6 100644 --- a/helper.py +++ b/helper.py @@ -386,37 +386,3 @@ def add_gradient_hooks(model): -def visualize_latent_token(token, save_path): - """ - Visualize a 1D latent token as a colorful bar. - - Args: - token (torch.Tensor): A 1D tensor representing the latent token. - save_path (str): Path to save the visualization. 
- """ - # Ensure the token is on CPU and convert to numpy - token_np = token.cpu().detach().numpy() - - # Create a figure and axis - fig, ax = plt.subplots(figsize=(10, 0.5)) - - # Normalize the token values to [0, 1] for colormap - token_normalized = (token_np - token_np.min()) / (token_np.max() - token_np.min()) - - # Create a colorful representation - cmap = plt.get_cmap('viridis') - colors = cmap(token_normalized) - - # Plot the token as a colorful bar - ax.imshow(colors.reshape(1, -1, 4), aspect='auto') - - # Remove axes - ax.set_xticks([]) - ax.set_yticks([]) - - # Add a title - plt.title(f"Latent Token (dim={len(token_np)})") - - # Save the figure - plt.savefig(save_path, bbox_inches='tight', pad_inches=0.1) - plt.close() \ No newline at end of file diff --git a/resblock.py b/resblock.py index 5cc66dd..773a7e8 100644 --- a/resblock.py +++ b/resblock.py @@ -696,12 +696,16 @@ def visualize_latent_token(token, save_path): resblock = ResBlock(3, 64, downsample=True) # Visualize the ResBlock - visualize_resblock(resblock, image_tensor) + visualize_resblock_rgb(resblock, image_tensor) + + + upconv = UpConvResBlock(64, 128) + visualize_block_output(upconv, image_tensor, "UpConvResBlock") # Run all tests with the image tensor - upconv = UpConvResBlock(3, 64) - test_upconvresblock(upconv, image_tensor) - visualize_feature_maps(upconv, image_tensor) + # upconv = UpConvResBlock(3, 64) + # test_upconvresblock(upconv, image_tensor) + # visualize_feature_maps(upconv, image_tensor) downconv = DownConvResBlock(3, 64) test_downconvresblock(downconv, image_tensor) diff --git a/sam.py b/sam.py new file mode 100644 index 0000000..031f0bf --- /dev/null +++ b/sam.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +from sam2.modeling.sam2_base import SAM2Base +from sam2.build_sam2 import build_sam2 + +class SAM2FeatureExtractor(nn.Module): + def __init__(self, config_file, ckpt_path=None, device="cuda"): + super().__init__() + self.sam2 = build_sam2(config_file, ckpt_path, device) + self.image_encoder = self.sam2.image_encoder + + def forward(self, x): + debug_print(f"⚾ SAM2FeatureExtractor input shape: {x.shape}") + + # Get the original output + original_output = self.image_encoder(x) + debug_print(f"Original SAM2 output shape: {original_output['vision_features'].shape}") + + # Extract intermediate features + features = [] + for i, feat in enumerate(original_output['backbone_fpn']): + features.append(feat) + debug_print(f"Feature {i+1} shape: {feat.shape}") + + debug_print(f"SAM2FeatureExtractor output shapes: {[f.shape for f in features]}") + + return features + +# Usage example +config_file = "path/to/sam2_config.yaml" +ckpt_path = "path/to/sam2_checkpoint.pt" +feature_extractor = SAM2FeatureExtractor(config_file, ckpt_path) + +# Example input +x = torch.randn(1, 3, 256, 256) +features = feature_extractor(x) \ No newline at end of file diff --git a/train.py b/train.py index 70595e0..cd60152 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,7 @@ from VideoDataset import VideoDataset from EMODataset import EMODataset,gpu_padded_collate from torchvision.utils import save_image -from helper import log_loss_landscape,log_grad_flow,count_model_params,normalize,visualize_latent_token, add_gradient_hooks, sample_recon +from helper import log_loss_landscape,log_grad_flow,count_model_params,normalize, add_gradient_hooks, sample_recon from torch.optim import AdamW from omegaconf import OmegaConf import lpips From 24f9d4edff4aa8de17eb0eb9b0b19d1f1c57e79b Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 
Aug 2024 13:22:44 +1000 Subject: [PATCH 124/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/train.py b/train.py index cd60152..e8b2839 100644 --- a/train.py +++ b/train.py @@ -109,7 +109,6 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato style_mixing_prob = config.training.style_mixing_prob - noise_magnitude = config.training.noise_magnitude r1_gamma = config.training.r1_gamma # R1 regularization strength From d570aa0a14bb6b55a7d500dd5b9abb626190695e Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 16:12:35 +1000 Subject: [PATCH 125/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- resblock.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/resblock.py b/resblock.py index 773a7e8..cf89b20 100644 --- a/resblock.py +++ b/resblock.py @@ -120,16 +120,24 @@ def __init__(self, in_channels, out_channels): self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) self.feat_res_block1 = FeatResBlock(out_channels) self.feat_res_block2 = FeatResBlock(out_channels) - + self.bn2 = nn.BatchNorm2d(out_channels) + self.dropout = nn.Dropout2d(dropout_rate) + self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + def forward(self, x): - x = self.upsample(x) - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.conv2(x) - x = self.feat_res_block1(x) - x = self.feat_res_block2(x) - return x + residual = self.residual_conv(x) + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = out + residual + out = self.relu(out) + out = self.dropout(out) + out = self.feat_res_block1(out) + out = self.feat_res_block2(out) + return out @@ -696,16 +704,12 @@ def visualize_latent_token(token, save_path): resblock = ResBlock(3, 64, downsample=True) # Visualize the ResBlock - visualize_resblock_rgb(resblock, image_tensor) - - - upconv = UpConvResBlock(64, 128) - visualize_block_output(upconv, image_tensor, "UpConvResBlock") + visualize_resblock(resblock, image_tensor) # Run all tests with the image tensor - # upconv = UpConvResBlock(3, 64) - # test_upconvresblock(upconv, image_tensor) - # visualize_feature_maps(upconv, image_tensor) + upconv = UpConvResBlock(3, 64) + test_upconvresblock(upconv, image_tensor) + visualize_feature_maps(upconv, image_tensor) downconv = DownConvResBlock(3, 64) test_downconvresblock(downconv, image_tensor) From 9c7d1381ae7aa98cf704f9918b103a3c6038de01 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 16:15:49 +1000 Subject: [PATCH 126/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/resblock.cpython-311.pyc | Bin 40926 -> 41448 bytes resblock.py | 28 +++++++++++++-------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index 730d9f204f3633e9b1efc5159e88efac3eee2d9e..f9bce0de8f1ef4c9f8234760bc0c90cb86e7589e 100644 GIT binary patch delta 4251 zcma)93vg8B72f~u-hJf-*@Vq&*(DD)o5$vbNt8eg4}tKCU`Gk+vdO(*$YwYGyMdTm z0~SW8hzSQlpkVBbVwI}QwLWUAtNM?)15P_uJ8IAQHwl?EFuk+; z-SeOS-1DE;fBt(vd0YD1DJl7QQj*<(*V*H5Z0w5kCugwGN$+M>KU>Mz!J20(UL_0f 
zZY6t%q1mJ4AkI<>6&Lby@jespd5Rmk`D%gUX_rVfKxhc4I+8 zIr6jBQq`-L!!1Uyg><{rBB+y!g=QCOK5Xsmm zE0$|8z}1x*QXmzY%z1h83EQ(q)Q|u8^`If(7^4-#T7%CHdn|Wb8FwPrmj)x2$EtN> zOSn?EE#i^RNLQ><(G!(88Q{T~s#``!dJj{)Nr>pwxpl{=f)+dnR0Rr^37LBFJ5nA?R}}>F$7O z8P-)cUWS4kLt)_g1hWae1jPiDNA4piLinkPm*Y)m-MV$mA>D|MM@`7~c)*cR=HcY@AvBytNu zJ;_aQrh%9QedG*$lDe2Rz~0Ovcq%d5!V773f!68pN&a{aN{B;=&}FWMx+;C6S+KU= z7Yv8hP;|`pq4ZLALi;YgKw8pB&rC|=a^l^DFq*A#tC7`{cAFZEtx=;3LXp;X|HcU4 z8sy5WusvgM{t|j_N>7BX1TP3rmS!Gg3$#m_=OuGIuJC3~ZGxzyxDW7c&g`s7alf6k zj|P9ftC*eEHoMM9E7ubH#$|auSEfyNnC1lV8%mec(-*dL0lrG1*Izwu_@Wrm)1g@%iD7U_9 zzI*^BvbAGXFBr|+kmTE8bBPU?)OULKxc^kHCvHrOXc>*R-eaB8lAuKO%UP zpbze@EtU@<>C^Vqri`n7dal2CT-0ag77`}=_$GqK2=)>@PS6Km&-KbjkgU;W*Zo0q zZbPysrNpzA65@&87wlAT@u0Pd_1)z{3suboRiwoy$D_ZGrbIGCivoD1-X|YL<|6IW z`p1lH8+0z5!M4LA3mbZ`nP|UeqHYScMPqSgVjT}@oWgFOb;61dkx+a^r1WI9yHjnA zsY*at&-c^x69`c&0=7%KV1A_b@sh#UiO#OreK(Ui(Yv z!-QVS1P>6LB+v*Xf;55+1nC6ivd--MW4!lRJqzQ9CZC4Gk?-*GF=1Jc&Au?w(b5)H z6<;=AOHbw!A3eDB8B%Nv?2s=BTtv{5Qs$M;SiCanCNme z7Hn~M99gZ;^LD=Qji{@9T^Ya2BPA%fmBcOvM}ahEM6Y|I()jL+AX|# zgCs?nh|(Fa3;q_=?R2fWZpsrTICPwJoFy3V8xx#5PZJ*zq{2^jR>;}t%slPF&MFxX zG2}`#wT)o1SE3q5>@=mpcG1H)+BCTje|#kvNCf zrTTa&h$=`&1~<{;0XTi2k}ZHQ59InMZS*ND)UBJ;SfI5d*vZepl7l7eO^6+=bcu?k zlyzCC!1&7?-;lf?F3g&XuwG-P7vtK`k_ z_Mtg5+ep=nScJFUx1}xG776=fJQ$ARt7}J;Uxw+2-MwU-Zfsh%g41hyyfcZMnrWCo zv*wi%r3+sf#1}>bu|;EEO_X@kiV}5z;1B_w)0_@e{xre!1k{;1wPF4h0c|`^1zX7} z%ADNdb7|%?f-eZJBIv1TH+~Xg1OoUqC5|aLd$_>8o+L#6+ewVR9ph)164@H&J7D^e z4Es`4ia$g2Bvc(aRFci`o9BO_Z_WO7SBd}6Ym@V}e;@gVIh=#08GYRYIYTD@u*pAY z@?SB-&fbYx*=2fu&I2|RDLPNt<5=hlqIz? M1I`*VIh`~A1*w4}BLDyZ delta 4021 zcma)<3vg7`8G!ff?%nL>vE)fM$%bqmklZ{r4<15>M-mJqCd7!A5hU3zxfcj*cH`L% zkAO`Cihw) znD74QKmYyj|2XH|o4>tb{=+ZL&gbLeqD}Z)*#7gCTZ8+ZF1F-U{u(wfUvX_?%O)uW zN;+P%l#FerWgf+i%vNM26LmRworKq1B@4BA>SV>U)=aMY7+(TIY}V9lwLr~(4t2+c zp5~azY*G`ITxH7F^~+?`=cqFLz)~a?qMQbm=2Cdek|-77C>^%o=yUVLq+%SI)|8Mh z`l57VHGLEtzi-j5Mg7^svF?#r_YOvi`*z*e_vp?= zyB8V5x?OdB5A1B%-Czvsch&d#cQ)>B9MAZ}SkHU0p239S*xZrWTw(DIm32b=4m(_m zD`(}voGHu;sm{#B3o)J>y%>`5@4DG8RbDrZgJipuKoj2%9*e%X6UtL+wIQcsoWr1?m>hZqROzEMM}aBD08O3wd7hh3*&2GmD*z|vU#R%SrM$z z9rqf6RVaGYtr&HQtAVZx*=n@W-nQBw461=JpGzLSFp^qgJBDKTG%lqaO?Lg4wcJZn zB8gg(sIt5c@q*#R$i&XoaJ$mw5AZ6S3R@AF)9s{i)SbLQ^+y(~;hI3GZLN1@h_Cl^ z0QkDz-%*W4?mtXGe(58fUFk6wMkRbzup_(Of$D~ zE6OU0dL75%nGA*;8QO-d=ge#|t?y~0i8YYHIMK(Cjr@IBn{)I#TGT{OCcGwERGnMP zny^KM&{FKyF6K6|yfLAs(hV*`8euuX3klu~ST;G$_GcW1AEN!OJvI3Q%fv@;tUJ1b z;YYhvwOcjJm%`(6iB0$|gEMkToUoDn2gu3#0Tkd z%KT7p-R;E|1$kwB{G>TOgYcC|`kk^`#z+{T5k>NZSNl8FN)JZ+Mb(Bf;e}StU9QNBEE%oXPg9~ZbcXQ* z@S|#3I)dVC?Nap~i%a;9FZMKw2HoX6Qd`|K8guk$%(^wu9*!8^#e+TMc$`FJ`&}VC zMn*;mDCtghL#NsnQ5By^hXYL?M1<`KI*jlK(-yd-K^)g<4=gxqWxWvJ&}K7&f{hK{ zX3=7+x(*MrrD&!{1w+17+^^_%e<08n2!+*eiK`E-Ih96_UO^|EFX-=3c^Q1(kYzJM z0e54m^a5He)0!KdX4b0tnr_CHpP(CGM06)ruf#R)x%-@~;4k5fY~jbrN|?R`dmfp> zd~oKGIW7^R?IAT74spJWOvFO=gJXGq7QG+1cmoxZR6r1Cc^yf$aR2fxtPkE>UKk^a zD9U*NqFQFND!8|$LHZEQ)@cV@ekifuY3XhMVyq3~)RodFsDD-4pst8wJ#aZNJJl#3 zYGw1)SiD@dA^2(Vc1(Si}xMb1)<3&NAS`urByL=>{5H)P{oZ$Fx#uc`M;0 z;WUB1YIriiM@S)14mykGm+{(d_slomnLHkcSHHra%*3<19ntec9V^;{sv>7_@nU&+ z^u=Q2LhfEA{E8s*5t)gkeoZ=&nW&OT?=0$c2@m*C*hp1m`8GVdDb;P{iUx8j@1uLg z&MYIT5sqyNVuSK!k28L|lXXXIV9QnYt8Uy+f>?Hu^}T=ip{h4;xZcO~!~ zIeM2MQaM1!v=)~V&~?_oV@sKyAy0Z;p#A0v#Zwx4uk zN-m-rT_cYFr^MN}V+cFYBOaLFhaAY4a!~ElnxIO`b)&cn=kRxBp0Q`d@(?@4LbtcV zss0L934iI&8M|%AiBPw%QX{^$4u2z~#YqhRg!v}=79U*tDd|5Oo%Cv-9lgosWei!<`)3X= zAGTJESSyCC71wPhr+Y_hf6bs}I4*A_F0XgNN7i(*nlU(gsJwdES~Fs;8M4+~cbeio T+T6jTEN(3`p@}6;{FU@SL%8wJ diff --git a/resblock.py b/resblock.py index cf89b20..a5d9d73 100644 --- a/resblock.py +++ 
b/resblock.py @@ -111,7 +111,7 @@ def debug_print(*args, **kwargs): class UpConvResBlock(nn.Module): - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels, dropout_rate=0.1): super().__init__() self.upsample = nn.Upsample(scale_factor=2, mode='nearest') self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) @@ -123,7 +123,7 @@ def __init__(self, in_channels, out_channels): self.bn2 = nn.BatchNorm2d(out_channels) self.dropout = nn.Dropout2d(dropout_rate) self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) - + def forward(self, x): residual = self.residual_conv(x) @@ -147,27 +147,27 @@ def __init__(self, in_channels, out_channels, dropout_rate=0.1): self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) - self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - # self.bn2 = nn.BatchNorm2d(out_channels) # 🤷 works with not without - # self.dropout = nn.Dropout2d(dropout_rate) # 🤷 - self.feat_res_block1 = FeatResBlock(out_channels) - self.feat_res_block2 = FeatResBlock(out_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + + # Residual connection + self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() def forward(self, x): + residual = self.residual(x) + out = self.conv1(x) out = self.bn1(out) out = self.relu(out) - out = self.avgpool(out) + out = self.conv2(out) - # out = self.bn2(out) # 🤷 - out = self.relu(out) # 🤷 - # out = self.dropout(out) # 🤷 - out = self.feat_res_block1(out) - out = self.feat_res_block2(out) + out = self.bn2(out) + + out += residual + out = self.relu(out) + return out - class FeatResBlock(nn.Module): def __init__(self, channels, dropout_rate=0.1): super().__init__() From 62770afbd0688fe8d9aa08c92c73b2db92cb4df0 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 16:20:09 +1000 Subject: [PATCH 127/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/resblock.cpython-311.pyc | Bin 41448 -> 41386 bytes resblock.py | 168 ++++++++++++++++++--------- 2 files changed, 111 insertions(+), 57 deletions(-) diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index f9bce0de8f1ef4c9f8234760bc0c90cb86e7589e..7cddedd59fe269124e33e737f7257423885dbc6f 100644 GIT binary patch delta 3043 zcma)8dvH|c72j_+d$W)1F3GwNl58HEfZUKL2?-k`23TYB_90|*kTybHuyL_AQ zN+|e`V8_=6TWP_o@ z-O2Bsd(QVb=XcKezWn>7{?{k<_E&8-iw>XthyU`#X!LjXXN+*bUTd(giNksOeD;v~ znei%v?S-@8$-E5slfwqSbPE)kyetdqO=0s0<6fLDJK>1w(CqXs9*f6%do&wY65+1i zP$bwHi*6k;!WOe%&bVO>@=%YWnS;U3NGOpA1`|eHp#YyfoP6Q=@S6EQFzB_ZzqEYK zDl17>OfZ*V9)V0yMo>X8pP&?BB#kdnM>EoyfpG_n*-E5*GBp6RyI3uF>^{~C3+-jF z*8a4QJJ8mdhH&#qN5O!9S5GiFRN4w2$rjCmD{eOg9QNJ)x?l#W=;_k6>E!fP+=ViC6L=8Xc@Bv_ z1lh*(NEZGVkW@qSVii`d)nROWhU zwB$YPbiJ_n<CnMo0I z3FPLiV80`xMKrjcXRhN9;zF_{c;mOW;*|Xq@+AC3?o!xWlK-K; zuAN3rQ8>Mkl1Iyo?5**u#dLVgT!$@WAvg(@b{8xyYGf@pS%kAinam9rioUR>v^RWM zJf~qcP0L0=o|NubS8oK5N!fIJWOUr}TVr_#ebz!|xtryykIZ?;P$hDqCpD!;Hi(xv zXf9z5=_cx*t^DL?I~o@Ym8|F4&>}x-qm5;2aBNmPV(&NUbKV3`dm^=bKT_{cyU>S3;$ok|GL3Gbeg>M3i7xxSRXQ zJT07g*7kT6xuoZbYx?%PV;RCaMo(Y6fazl4zqEFfLri0_sBx zzcPqfSD0I^tUoK^p#Ls8h?e1{9~-QK8PNnT*w@&`{4jdI2d*||;Q*;=(ul7wn}6=W zNiRLVeW@rm5vkZiRGc`;#T7<}mzEydx-lM$MZ|^1MJtNl7R_7+tJl890`TqHCQQ%+ 
z9nI_mb)e%JBm1YC-T9cFT~>RPzVvBMU5gx*sQq{=9PRc=M@f7GKI(RtzfIy>1b-zs zM({U+Qv|08-XZuq!5Juy{)oK`Tcc~(HT6t%(wgl5^<)$z{23*US@^&3z0*{%Jb8wR z^1Jjud}Jn;+iA8ek9BXrbau%({4QEB=`n5dI>8A7^17Cb0EF)+Nht!@K$3{!I7*sv zr@~{2WJDsE)3Bw_lbZ|)2gn}&Jk8io5Fn_BSNbCUF%sV;7$4=jro(0w+AT zw~FPeZ|5ZE8(=3BaWJcK6lS$lQ-#}m>$C^SJl<)O6A1Awn z%W3IRWm{-__y%jY>OeMWQjZc%#JFbQv28rXBXZ&%0zuZOdSc*X#&Y15{ql+wK4=3$ z#}xbKyNSO@tPF0bX2t5m@PLu!!E=KiNyKJ^BZGC!0iO=~Z!hQ>w$@A=m7ZW{cPP$J zLCsJpdmTbU3nW4F`!F>0L$*kr7|{^ zeF9$}$am3BnxSRIN^U?Y*-poyY&geCEuC!Q7HAtTt^R<%FB4oNa8qcx1jPi42!23@ p34*HxpTqBm6Kn;{JvcgVHgituz7GTZg&m)f{)-z%HTTdN!*@?M&;bAd delta 3197 zcma)84Qx}_6@K^Ge*Q_E9~|TSIT)uv9+3YwAvEifu%xBXgaB!t{r|; zCDqVKSX5F@w}l0!{ivm3blvnci3vKirCUokHrCOKJs~KY#EOu*u(7HY(z-rwk@jTws#tW_dM(}%iFNGgFu{veS@L=r z($cr+o0M-eQ^GGnceTgdD+T7kPNfhQ+p?Go9=3(;D&rp9O3i}TY~Ndw8sf3;SWhCr zg9**BPRQQhgzWxYdmzve3C81rK-`QLd_G@4r*OCzUbo+3FYBWY zi(FPp5w6apFfsiT}BqY`Os|L z)EJVGJX4INpuqBPMe;3&PKQ^yXv@cc=*!MyPr-G^WAH&qhAj7jufqe^ zN-~)Y+gv#$#L?{%F`A(^$-NSkSInoS794+HfrX`RTHYe}${Qq3-l^8v$Jj{Ti?f^# z(Uc6aic~8Q;wA))E14NJ@V3jFa-8IY2>lY&<@#Z}KOK(Zu^fK((6b}1!Pro2!1O`d zyu;x`;StkdbSR2bIN_}5YruE-N`fzw35HOYG+a+Yf?-muMfjv2%_sjRWQS$tg!Flk zonmoKNdFIbsxS|?mo0??cfKNywL!0QUt@i9rH`66ceap~LIP0=i*RIExEARtpHUIZ zNl#^wR8^>06Ew4Rd@UG_YLWPiq%liUDv^7y?(~zMD!F>P*d8Zavx;pE$sR(uSF}?{ z#1_b6%SWLU&p1lMw(VMRzpQ>yJ=ziHwP`f%Rs@}q(Zs+`&7q`GGCF_ry>BNW>m zjR(8BBO1TIrt(vh?IB49{G=eK*08M(Y0*SDu~$`Qmr*P2IbBN%PLY)vy!@x&s$9K+ zR!x&LS0|d27nfi6eSYS(kkuBr;eUp;5HIk}!c69dkE4u9}yz#VmXkeKJBE9)s)Hd1g_Sw6>v(;XEpkH^+zmI7$w*^dH(uy%l^iSrek;- zb%vSxzc?&ISsRXOLC$xeocmy)*t?Vxbw+oiKn#04u{WXxLg6m%BlUgZ+$avS5(>hw zwTC0ZUK7SMj~2Q}>aUwp1T$ocXF*fr2IhnBHmaz6vvI)y;dpB$CTVD}{ZL zVWQMUe-u4R^h~+81;T9~0j%zAo?MXUXat2!=Jjvq#hR;RW8U@vfeP#@{5}IKgiSeo63Kg0~3H z6TD6E4#Dr>Mpr9)7gj`{WS{CUL?=>`U2YpG1skVvWH1N6g3o<6->PJ55_x#@GyLI4 z1)1AtcdUwawd1*l)I9zL>TfCpeunn?6#;qOmkdDo{%aD70NQCJqBw^m!?I7~v3N2f z5e$W8SmKFb#}30Glz~Wihc+1`Dpe1Ek!&0x5LZ?;y<(>%FAv~nkYR0)(g1q*Kjptj ztKKE}Gr=VSy5bClKA582{5`m{e=DZY`d+g`5eeglE#K&KPe&!(76FXI!r|=Ivx!G* zSfbNDouEA~5D4#pM$;a=N7MHS=0WSha+arm{b02v*TYaElenMYd4i-j`lX?HCbu|2 zq2ER-(Y5cw-NShq!n0{=$w${3;+9*`reTe0yMueelN3++uG$PGp+(~{ZkTv%H=p4a zdEz1vj9t(V3}0s~7q%Z!SI;6v9Ah&3a6i@lMM&@q%W(NfB`efZ;VTyAh1L;|A_6o+ zbflU!z$+vEl3Ce5$JU0eQ%eLox`N&OE%V%u*K+y{O4i;v_{V5oNXwy@ZqsP6lF8~g85$<<}dBN PP4YXFO8x5b-<$pmE;0al diff --git a/resblock.py b/resblock.py index a5d9d73..3d0badf 100644 --- a/resblock.py +++ b/resblock.py @@ -108,90 +108,140 @@ def debug_print(*args, **kwargs): # out = self.relu2(out) # return out - -class UpConvResBlock(nn.Module): - def __init__(self, in_channels, out_channels, dropout_rate=0.1): +class NormLayer(nn.Module): + def __init__(self, num_features, norm_type='batch'): super().__init__() - self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.bn1 = nn.BatchNorm2d(out_channels) - self.relu = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.feat_res_block1 = FeatResBlock(out_channels) - self.feat_res_block2 = FeatResBlock(out_channels) - self.bn2 = nn.BatchNorm2d(out_channels) - self.dropout = nn.Dropout2d(dropout_rate) - self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + if norm_type == 'batch': + self.norm = nn.BatchNorm2d(num_features) + elif norm_type == 'instance': + self.norm = nn.InstanceNorm2d(num_features) + elif norm_type == 'layer': + self.norm = nn.GroupNorm(1, num_features) + else: + raise ValueError(f"Unsupported 
normalization type: {norm_type}") def forward(self, x): - residual = self.residual_conv(x) - + return self.norm(x) + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, + activation=nn.ReLU, norm_type='batch'): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False) + self.norm = NormLayer(out_channels, norm_type) + self.activation = activation(inplace=True) if activation else None + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + if self.activation: + x = self.activation(x) + return x + +class FeatResBlock(nn.Module): + def __init__(self, channels, dropout_rate=0.1, activation=nn.ReLU, norm_type='batch'): + super().__init__() + self.conv1 = ConvBlock(channels, channels, activation=activation, norm_type=norm_type) + self.conv2 = ConvBlock(channels, channels, activation=None, norm_type=norm_type) + self.activation = activation(inplace=True) if activation else None + self.dropout = nn.Dropout2d(dropout_rate) + + def forward(self, x): + residual = x out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) out = self.conv2(out) - out = self.bn2(out) - out = out + residual - out = self.relu(out) out = self.dropout(out) - out = self.feat_res_block1(out) - out = self.feat_res_block2(out) + out += residual + if self.activation: + out = self.activation(out) return out - - class DownConvResBlock(nn.Module): - def __init__(self, in_channels, out_channels, dropout_rate=0.1): + def __init__(self, in_channels, out_channels, dropout_rate=0.1, activation=nn.ReLU, + norm_type='batch', use_residual_scaling=False): super().__init__() - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.bn1 = nn.BatchNorm2d(out_channels) - self.relu = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.bn2 = nn.BatchNorm2d(out_channels) + self.conv1 = ConvBlock(in_channels, out_channels, stride=2, activation=activation, norm_type=norm_type) + self.conv2 = ConvBlock(out_channels, out_channels, activation=None, norm_type=norm_type) + self.activation = activation(inplace=True) if activation else None + self.dropout = nn.Dropout2d(dropout_rate) + self.feat_res_block1 = FeatResBlock(out_channels, dropout_rate, activation, norm_type) + self.feat_res_block2 = FeatResBlock(out_channels, dropout_rate, activation, norm_type) - # Residual connection - self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() + self.shortcut = ConvBlock(in_channels, out_channels, kernel_size=1, stride=2, padding=0, + activation=None, norm_type=norm_type) + self.use_residual_scaling = use_residual_scaling + if use_residual_scaling: + self.residual_scaling = nn.Parameter(torch.ones(1)) def forward(self, x): - residual = self.residual(x) + residual = self.shortcut(x) out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - out = self.conv2(out) - out = self.bn2(out) + out = self.dropout(out) + + if self.use_residual_scaling: + out = out * self.residual_scaling out += residual - out = self.relu(out) + if self.activation: + out = self.activation(out) + out = self.feat_res_block1(out) + out = self.feat_res_block2(out) return out -class FeatResBlock(nn.Module): - def __init__(self, channels, dropout_rate=0.1): +class UpConvResBlock(nn.Module): + def __init__(self, in_channels, out_channels, dropout_rate=0.1, activation=nn.ReLU, + 
norm_type='batch', upsample_mode='nearest', use_residual_scaling=False): super().__init__() - self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) - self.bn1 = nn.BatchNorm2d(channels) - self.relu1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) - self.bn2 = nn.BatchNorm2d(channels) + self.upsample = nn.Upsample(scale_factor=2, mode=upsample_mode) + self.conv1 = ConvBlock(in_channels, out_channels, activation=activation, norm_type=norm_type) + self.conv2 = ConvBlock(out_channels, out_channels, activation=None, norm_type=norm_type) + self.activation = activation(inplace=True) if activation else None self.dropout = nn.Dropout2d(dropout_rate) - self.relu2 = nn.ReLU(inplace=True) + self.feat_res_block1 = FeatResBlock(out_channels, dropout_rate, activation, norm_type) + self.feat_res_block2 = FeatResBlock(out_channels, dropout_rate, activation, norm_type) + + self.shortcut = nn.Sequential( + nn.Upsample(scale_factor=2, mode=upsample_mode), + ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, activation=None, norm_type=norm_type) + ) + self.use_residual_scaling = use_residual_scaling + if use_residual_scaling: + self.residual_scaling = nn.Parameter(torch.ones(1)) def forward(self, x): - residual = x - out = self.conv1(x) - out = self.bn1(out) - out = self.relu1(out) + residual = self.shortcut(x) + + out = self.upsample(x) + out = self.conv1(out) out = self.conv2(out) - out = self.bn2(out) out = self.dropout(out) + + if self.use_residual_scaling: + out = out * self.residual_scaling + out += residual - out = self.relu2(out) + if self.activation: + out = self.activation(out) + + out = self.feat_res_block1(out) + out = self.feat_res_block2(out) return out - +def init_weights(m): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + +# Apply weight initialization +def initialize_model(model): + model.apply(init_weights) + class ResBlock(nn.Module): def __init__(self, in_channels, out_channels, downsample=False, dropout_rate=0.1): super().__init__() @@ -704,12 +754,16 @@ def visualize_latent_token(token, save_path): resblock = ResBlock(3, 64, downsample=True) # Visualize the ResBlock - visualize_resblock(resblock, image_tensor) + visualize_resblock_rgb(resblock, image_tensor) + + + upconv = UpConvResBlock(64, 128) + visualize_block_output(upconv, image_tensor, "UpConvResBlock") # Run all tests with the image tensor - upconv = UpConvResBlock(3, 64) - test_upconvresblock(upconv, image_tensor) - visualize_feature_maps(upconv, image_tensor) + # upconv = UpConvResBlock(3, 64) + # test_upconvresblock(upconv, image_tensor) + # visualize_feature_maps(upconv, image_tensor) downconv = DownConvResBlock(3, 64) test_downconvresblock(downconv, image_tensor) From fc9bd5768594afc32e4b70009f81fbc45fd06623 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 16:27:12 +1000 Subject: [PATCH 128/152] clean up --- __pycache__/resblock.cpython-311.pyc | Bin 41386 -> 44929 bytes common.py | 182 ----------------- crossattention.py | 76 ------- implicitmotion.py | 285 --------------------------- loss.py | 89 --------- perceptual.py | 100 ---------- resblock.py | 5 +- vggloss.py | 84 -------- 8 files changed, 4 insertions(+), 817 deletions(-) delete mode 100644 common.py delete mode 100644 crossattention.py delete mode 100644 
implicitmotion.py delete mode 100644 perceptual.py delete mode 100644 vggloss.py diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index 7cddedd59fe269124e33e737f7257423885dbc6f..bb4cdd70251ff9e65c8345a18f7723339d8b2258 100644 GIT binary patch [binary delta payload omitted]
zoA0(&39@>NMUPh;yDgdZWCLHG&6IfRc9zCgf$ zM$k10nl3@161)Z3fQ-)(zC`#7fSh~KAAv8^aA6O8rG|v7_QBWei51ShIE3lD569s* zY$YAzdm}kco6cEH+hJAsuZmpcsps6&@fJ-n2d}Uq!4b&Uodib%s$1}B>9?2z6L(dne9Ul zo!WGYOl6iPGfM}ze#FD5!c*F@j48f4$yZPE)gSS;N#1s 0: - model += [nn.Dropout(p=dropout)] - self.model = nn.Sequential(*model) - - residual_block = [] - residual_block += downsampleLayer(inplanes, planes, downsample='avgpool', use_sn=use_sn) - self.residual_block = nn.Sequential(*residual_block) - - def forward(self, x): - residual = self.residual_block(x) - out = self.model(x) - out += residual - return out - -class UpConvResBlock(nn.Module): - def conv3x3(self, inplanes, out_planes, stride=1,use_sn=True): - if use_sn: - return spectral_norm(nn.Conv2d(inplanes, out_planes, kernel_size=3, stride=stride, padding=1)) - else: - return nn.Conv2d(inplanes, out_planes, kernel_size=3, stride=stride, padding=1) - - - def __init__(self, inplanes, planes, stride=1, dropout=0.0,use_sn=False,norm_layer='batch',num_groups=8): - super(UpConvResBlock, self).__init__() - model = [] - model += upsampleLayer(inplanes , planes , upsample='nearest' , use_sn=use_sn) - if norm_layer != 'none': - model += [get_norm(planes,norm_layer,num_groups)] #[nn.BatchNorm2d(planes)] - model += [nn.ReLU(inplace=True)] - model += [self.conv3x3(planes, planes,stride,use_sn)] - if norm_layer != 'none': - model += [get_norm(planes,norm_layer,num_groups)] #[nn.BatchNorm2d(planes)] - model += [nn.ReLU(inplace=True)] - if dropout > 0: - model += [nn.Dropout(p=dropout)] - self.model = nn.Sequential(*model) - - residual_block = [] - residual_block += upsampleLayer(inplanes , planes , upsample='bilinear' , use_sn=use_sn) - self.residual_block=nn.Sequential(*residual_block) - - def forward(self, x): - residual = self.residual_block(x) - out = self.model(x) - out += residual - return out diff --git a/crossattention.py b/crossattention.py deleted file mode 100644 index 77b13a8..0000000 --- a/crossattention.py +++ /dev/null @@ -1,76 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -class CrossAttentionLayer(nn.Module): - def __init__(self, dim, heads=8, dim_head=64): - super().__init__() - self.dim = dim - self.heads = heads - self.dim_head = dim_head - self.scale = dim_head ** -0.5 - - self.to_q = nn.Linear(dim, heads * dim_head) - self.to_k = nn.Linear(dim, heads * dim_head) - self.to_v = nn.Linear(dim, heads * dim_head) - self.to_out = nn.Linear(heads * dim_head, dim) - - def forward(self, ml_c, ml_r, fl_r): - B, L, C = ml_c.shape - - q = self.to_q(ml_c).view(B, L, self.heads, self.dim_head).transpose(1, 2) - k = self.to_k(ml_r).view(B, L, self.heads, self.dim_head).transpose(1, 2) - v = self.to_v(fl_r).view(B, L, self.heads, self.dim_head).transpose(1, 2) - - attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale - attn = F.softmax(attn, dim=-1) - - out = torch.matmul(attn, v) - out = out.transpose(1, 2).contiguous().view(B, L, self.heads * self.dim_head) - out = self.to_out(out) - - return out - -class CrossAttentionModule(nn.Module): - def __init__(self, dims=[128, 256, 512, 512], heads=8, dim_head=64): - super().__init__() - self.dims = dims - self.attention_layers = nn.ModuleList([ - CrossAttentionLayer(dim, heads, dim_head) for dim in dims - ]) - - def forward(self, ml_c, ml_r, fl_r): - outputs = [] - for i, layer in enumerate(self.attention_layers): - print(f"🌻 Layer {i+1} Input shapes: ml_c: {ml_c[i].shape}, ml_r: {ml_r[i].shape}, fl_r: {fl_r[i].shape}") - - B, C, H, W = fl_r[i].shape - - # Flatten and transpose inputs - 
ml_c_flat = ml_c[i].view(B, C, -1).transpose(1, 2) # (B, H*W, C) - ml_r_flat = ml_r[i].view(B, C, -1).transpose(1, 2) # (B, H*W, C) - fl_r_flat = fl_r[i].view(B, C, -1).transpose(1, 2) # (B, H*W, C) - - print(f"After flattening: ml_c: {ml_c_flat.shape}, ml_r: {ml_r_flat.shape}, fl_r: {fl_r_flat.shape}") - - out = layer(ml_c_flat, ml_r_flat, fl_r_flat) - out = out.view(B, H, W, C).permute(0, 3, 1, 2) # (B, C, H, W) - - print(f"Layer {i+1} output shape: {out.shape}") - outputs.append(out) - - return outputs - -# Test the module -if __name__ == "__main__": - dims = [128, 256, 512, 512] - B = 2 - ml_c = [torch.randn(B, dim, 64 // (2**i), 64 // (2**i)) for i, dim in enumerate(dims)] - ml_r = [torch.randn(B, dim, 64 // (2**i), 64 // (2**i)) for i, dim in enumerate(dims)] - fl_r = [torch.randn(B, dim, 64 // (2**i), 64 // (2**i)) for i, dim in enumerate(dims)] - - module = CrossAttentionModule(dims) - outputs = module(ml_c, ml_r, fl_r) - - for i, out in enumerate(outputs): - print(f"Final output shape for layer {i+1}: {out.shape}") \ No newline at end of file diff --git a/implicitmotion.py b/implicitmotion.py deleted file mode 100644 index 1dac7d3..0000000 --- a/implicitmotion.py +++ /dev/null @@ -1,285 +0,0 @@ -import torch -import torch.nn as nn -from einops import rearrange -import matha -from einops.layers.torch import Rearrange - - -def pair(t): - return t if isinstance(t, tuple) else (t, t) - -def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): - y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") - assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" - omega = torch.arange(dim // 4) / (dim // 4 - 1) - omega = 1.0 / (temperature ** omega) - - y = y.flatten()[:, None] * omega[None, :] - x = x.flatten()[:, None] * omega[None, :] - pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) - return pe.type(dtype) - - -# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/simple_vit.py - -''' -ImplicitMotionAlignment - The module consists of two parts: the cross-attention module and the transformer. - -Cross-Attention Module - The cross-attention module is implemented using the scaled dot-product cross-attention mechanism. - This module takes as input the motion features ml_c and ml_r as the queries (Q) and keys (K), respectively, - and the appearance features fl_r as the values (V). - - Flattening and Positional Embeddings - Before computing the attention weights, the input features are first flattened to a 1D vector. BxCx(HxW) - Then, positional embeddings Pq and Pk are added to the queries and keys, respectively. - This is done to capture the spatial relationships between the features. - - Attention Weights Computation - The attention weights are computed by taking the dot product of the queries with the keys, dividing each by √dk - (where dk is the dimensionality of the keys), and applying a softmax function to obtain the weights on the values. - - Output-Aligned Values - The output-aligned values V' are computed through matrix multiplication of the attention weights with the values. - -Transformer Blocks - The output-aligned values V' are further refined using multi-head self-attention and feedforward network-based - transformer blocks. This is done to capture the complex relationships between the features and to produce - the final appearance features fl_c of the current frame. 
-''' - -class ImplicitMotionAlignment(nn.Module): - def __init__(self, feature_dim, num_layers, num_heads, mlp_dim, image_size=256, patch_size=16): - super().__init__() - - self.image_size = pair(image_size) - self.patch_size = pair(patch_size) - - h, w = self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1] - num_patches = h * w - - self.to_patch_embedding = nn.Sequential( - Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size), - nn.Linear((patch_size ** 2) * feature_dim, feature_dim) - ) - - self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, feature_dim)) - self.cls_token = nn.Parameter(torch.randn(1, 1, feature_dim)) - - self.cross_attention = CrossAttention(feature_dim, num_heads) - self.transformer = TransformerBlock(feature_dim, num_layers, num_heads, feature_dim // num_heads, mlp_dim) - - self.to_latent = nn.Identity() - self.mlp_head = nn.Sequential( - nn.LayerNorm(feature_dim), - nn.Linear(feature_dim, feature_dim) - ) - - def forward(self, m_c, m_r, fl_r): - # Process appearance features - x = self.to_patch_embedding(fl_r) - b, n, _ = x.shape - - cls_tokens = self.cls_token.expand(b, -1, -1) - x = torch.cat((cls_tokens, x), dim=1) - x += self.pos_embedding[:, :(n + 1)] - - # Process motion features - m_c = self.to_patch_embedding(m_c) - m_r = self.to_patch_embedding(m_r) - - # Apply cross-attention - x = self.cross_attention(x, m_c, m_r) - - # Apply transformer - x = self.transformer(x) - x = x[:, 0] - x = self.to_latent(x) - return self.mlp_head(x) - - - -class CrossAttention(nn.Module): - def __init__(self, dim, num_heads=8): - super().__init__() - self.num_heads = num_heads - self.scale = (dim // num_heads) ** -0.5 - - self.to_q = nn.Linear(dim, dim, bias=False) - self.to_k = nn.Linear(dim, dim, bias=False) - self.to_v = nn.Linear(dim, dim, bias=False) - - self.to_out = nn.Linear(dim, dim) - - def forward(self, x, m_c, m_r): - b, n, _, h = *x.shape, self.num_heads - - q = self.to_q(x) - k = self.to_k(m_c) - v = self.to_v(m_r) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) - - dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale - attn = dots.softmax(dim=-1) - - out = torch.einsum('bhij,bhjd->bhid', attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') - return self.to_out(out) - -class TransformerBlock(nn.Module): - def __init__(self, dim, depth, heads, dim_head, mlp_dim): - super().__init__() - self.layers = nn.ModuleList([]) - for _ in range(depth): - self.layers.append(nn.ModuleList([ - PreNorm(dim, MultiHeadAttention(dim, heads=heads, dim_head=dim_head)), - PreNorm(dim, FeedForward(dim, mlp_dim)) - ])) - - def forward(self, x): - for attn, ff in self.layers: - x = attn(x) + x - x = ff(x) + x - return x - -class MultiHeadAttention(nn.Module): - def __init__(self, dim, heads=8, dim_head=64): - super().__init__() - inner_dim = dim_head * heads - self.heads = heads - self.scale = dim_head ** -0.5 - - self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) - self.to_out = nn.Linear(inner_dim, dim) - - def forward(self, x): - b, n, _, h = *x.shape, self.heads - qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) - - dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale - attn = dots.softmax(dim=-1) - - out = torch.einsum('bhij,bhjd->bhid', attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') - return self.to_out(out) - -class FeedForward(nn.Module): - def __init__(self, dim, hidden_dim): - 
super().__init__() - self.net = nn.Sequential( - nn.Linear(dim, hidden_dim), - nn.GELU(), - nn.Linear(hidden_dim, dim) - ) - - def forward(self, x): - return self.net(x) - -class PreNorm(nn.Module): - def __init__(self, dim, fn): - super().__init__() - self.norm = nn.LayerNorm(dim) - self.fn = fn - - def forward(self, x, **kwargs): - return self.fn(self.norm(x), **kwargs) - -# official paper code -# class CrossAttention(nn.Module): -# def __init__(self,feature_dim): -# super().__init__() - -# self.q_pos_embedding = posemb_sincos_2d( -# h = 256 , -# w = 256 , -# dim = feature_dim, -# ) - -# self.k_pos_embedding = posemb_sincos_2d( -# h = 256 , -# w = 256 , -# dim = feature_dim, -# ) - -# # motion current m_c, motion reference m_r, appearance feature f_r = queries, keys, values -# def forward(self, queries, keys, values): -# # (b, dim_qk, h, w) -> (b, dim_qk, dim_spatial) -> (b, dim_spatial, dim_qk) -# q = torch.flatten(queries, start_dim=2).transpose(-1, -2) -# q = q + self.q_pos_embedding # (b, dim_spatial, dim_qk) - -# # (b, dim_qk, h, w) -> (b, dim_qk, dim_spatial) -> (b, dim_spatial, dim_qk) -# k = torch.flatten(keys, start_dim=2).transpose(-1, -2) -# k = k + self.k_pos_embedding # (b, dim_spatial, dim_qk) -# # (b, dim_v, h, w) -> (b, dim_v, dim_spatial) -> (b, dim_spatial, dim_v) -# v = torch.flatten(values, start_dim=2).transpose(-1, -2) - -# # (b, dim_spatial, dim_qk) * (b, dim_qk, dim_spatial) -> (b, dim_spatial, dim_spatial) -# dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale - -# attn = self.attend(dots) # (b, dim_spatial, dim_spatial) softmax - -# # (b, dim_spatial, dim_spatial) * (b, dim_spatial, dim_v) -> (b, dim_spatial, dim_v) -# out = torch.matmul(attn, v) -# return out - - - -# class TransformerBlock(nn.Module): -# def __init__(self, dim, depth, heads, dim_head, mlp_dim): -# super().__init__() -# self.norm = nn.LayerNorm(dim) -# self.layers = nn.ModuleList([]) -# for _ in range(depth): -# self.layers.append(nn.ModuleList([ -# MultiHeadAttention(dim, heads = heads, dim_head = dim_head), -# FeedForward(dim, mlp_dim) -# ])) -# def forward(self, x): -# for attn, ff in self.layers: -# x = attn(x) + x -# x = ff(x) + x -# return self.norm(x) - - -# class MultiHeadAttention(nn.Module): -# def __init__(self, dim, heads = 8, dim_head = 64): -# super().__init__() -# inner_dim = dim_head * heads -# self.heads = heads -# self.scale = dim_head ** -0.5 -# self.norm = nn.LayerNorm(dim) - -# self.attend = nn.Softmax(dim = -1) - -# self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) -# self.to_out = nn.Linear(inner_dim, dim, bias = False) - -# def forward(self, x): -# x = self.norm(x) - -# qkv = self.to_qkv(x).chunk(3, dim = -1) -# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) - -# dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale - -# attn = self.attend(dots) - -# out = torch.matmul(attn, v) -# out = rearrange(out, 'b h n d -> b n (h d)') -# return self.to_out(out) - - -# class FeedForward(nn.Module): -# def __init__(self, dim, hidden_dim=None): -# super().__init__() -# self.net = nn.Sequential( -# nn.LayerNorm(dim), -# nn.Linear(dim, hidden_dim), -# nn.GELU(), -# nn.Linear(hidden_dim, dim), -# ) -# def forward(self, x): -# return self.net(x) \ No newline at end of file diff --git a/loss.py b/loss.py index 4e2ad31..5de9717 100644 --- a/loss.py +++ b/loss.py @@ -5,95 +5,6 @@ from model import debug_print import lpips -class LPIPSPerceptualLoss(nn.Module): - def __init__(self, net='alex', debug=True): - super(LPIPSPerceptualLoss, 
self).__init__() - self.debug = debug - self.lpips_model = lpips.LPIPS(net=net) - - for param in self.parameters(): - param.requires_grad = False - - def forward(self, x, y): - debug_print(f"\nLPIPSPerceptualLoss Forward Pass:") - debug_print(f"Input shapes: x: {x.shape}, y: {y.shape}") - if x.dim() != 4 or y.dim() != 4: - raise ValueError(f"Expected 4D input tensors, got x: {x.dim()}D and y: {y.dim()}D") - - total_loss = 0 - for i in range(4): # Downsample 4 times (i ∈ [0, 3]) - debug_print(f"\n Scale {i+1}:") - if i > 0: - x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) - y = F.interpolate(y, scale_factor=0.5, mode='bilinear', align_corners=False) - - debug_print(f" Current shapes: x: {x.shape}, y: {y.shape}") - - # Compute LPIPS loss - scale_loss = self.lpips_model(x, y).mean() - total_loss += scale_loss - - debug_print(f" Scale {i+1} LPIPS loss: {scale_loss.item():.6f}") - - debug_print(f"\nTotal LPIPS loss: {total_loss.item():.6f}") - return total_loss - -class VGGPerceptualLoss(nn.Module): - def __init__(self, debug=True): - super(VGGPerceptualLoss, self).__init__() - self.debug = debug - vgg = models.vgg19(pretrained=True).features - self.slices = nn.ModuleList([ - nn.Sequential(*list(vgg.children())[:2]), - nn.Sequential(*list(vgg.children())[2:7]), - nn.Sequential(*list(vgg.children())[7:12]), - nn.Sequential(*list(vgg.children())[12:21]), - nn.Sequential(*list(vgg.children())[21:30]) - ]) - - for param in self.parameters(): - param.requires_grad = False - - def forward(self, x, y): - debug_print(f"\nVGGPerceptualLoss Forward Pass:") - debug_print(f"Input shapes: x: {x.shape}, y: {y.shape}") - if x.dim() != 4 or y.dim() != 4: - raise ValueError(f"Expected 4D input tensors, got x: {x.dim()}D and y: {y.dim()}D") - - total_loss = 0 - for i in range(4): # Downsample 4 times (i ∈ [0, 3]) - debug_print(f"\n Scale {i+1}:") - if i > 0: - x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) - y = F.interpolate(y, scale_factor=0.5, mode='bilinear', align_corners=False) - - debug_print(f" Current shapes: x: {x.shape}, y: {y.shape}") - - # Compute VGG features - x_features, y_features = [], [] - x_current, y_current = x, y - for j, slice in enumerate(self.slices): - x_current = slice(x_current) - y_current = slice(y_current) - debug_print(f" After VGG slice {j+1}: x: {x_current.shape}, y: {y_current.shape}") - x_features.append(x_current) - y_features.append(y_current) - - # Compute loss for each VGG layer - scale_loss = 0 - for j in range(len(x_features)): - layer_loss = F.l1_loss(x_features[j], y_features[j]) - scale_loss += layer_loss - - debug_print(f" Layer {j+1} loss: {layer_loss.item():.6f}") - - total_loss += scale_loss - - debug_print(f" Scale {i+1} total loss: {scale_loss.item():.6f}") - - - debug_print(f"\nTotal perceptual loss: {total_loss.item():.6f}") - return total_loss def wasserstein_loss(real_outputs, fake_outputs): diff --git a/perceptual.py b/perceptual.py deleted file mode 100644 index 9e0ec5b..0000000 --- a/perceptual.py +++ /dev/null @@ -1,100 +0,0 @@ -from facenet_pytorch import InceptionResnetV1 -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.models as models -from lpips import LPIPS - -class PerceptualLoss(nn.Module): - def __init__(self, device, weights={'vgg19': 20.0, 'vggface': 5.0, 'gaze': 4.0, 'lpips': 10.0}): - super(PerceptualLoss, self).__init__() - self.device = device - self.weights = weights - - # VGG19 network - vgg19 = models.vgg19(pretrained=True).features - 
self.vgg19 = nn.Sequential(*[vgg19[i] for i in range(30)]).to(device).eval() - self.vgg19_layers = [1, 6, 11, 20, 29] - - # VGGFace network - self.vggface = InceptionResnetV1(pretrained='vggface2').to(device).eval() - self.vggface_layers = [4, 5, 6, 7] - - - # LPips - self.lpips = LPIPS(net='vgg').to(device).eval() - - def forward(self, predicted, target, use_fm_loss=False): - # Normalize input images - predicted = self.normalize_input(predicted) - target = self.normalize_input(target) - - # Compute VGG19 perceptual loss - vgg19_loss = self.compute_vgg19_loss(predicted, target) - - # Compute VGGFace perceptual loss - vggface_loss = self.compute_vggface_loss(predicted, target) - - # Compute gaze loss - # gaze_loss = self.gaze_loss(predicted, target) - - # Compute LPIPS loss - lpips_loss = self.lpips(predicted, target).mean() - - # Compute total perceptual loss - total_loss = ( - self.weights['vgg19'] * vgg19_loss + - self.weights['vggface'] * vggface_loss + - self.weights['lpips'] * lpips_loss + - self.weights['gaze'] * 1 #gaze_loss - ) - - if use_fm_loss: - # Compute feature matching loss - fm_loss = self.compute_feature_matching_loss(predicted, target) - total_loss += fm_loss - - return total_loss - - def compute_vgg19_loss(self, predicted, target): - return self.compute_perceptual_loss(self.vgg19, self.vgg19_layers, predicted, target) - - def compute_vggface_loss(self, predicted, target): - return self.compute_perceptual_loss(self.vggface, self.vggface_layers, predicted, target) - - def compute_feature_matching_loss(self, predicted, target): - return self.compute_perceptual_loss(self.vgg19, self.vgg19_layers, predicted, target, detach=True) - - def compute_perceptual_loss(self, model, layers, predicted, target, detach=False): - loss = 0.0 - predicted_features = predicted - target_features = target - #print(f"predicted_features:{predicted_features.shape}") - #print(f"target_features:{target_features.shape}") - - for i, layer in enumerate(model.children()): - # print(f"i{i}") - if isinstance(layer, nn.Conv2d): - predicted_features = layer(predicted_features) - target_features = layer(target_features) - elif isinstance(layer, nn.Linear): - predicted_features = predicted_features.view(predicted_features.size(0), -1) - target_features = target_features.view(target_features.size(0), -1) - predicted_features = layer(predicted_features) - target_features = layer(target_features) - else: - predicted_features = layer(predicted_features) - target_features = layer(target_features) - - if i in layers: - if detach: - loss += torch.mean(torch.abs(predicted_features - target_features.detach())) - else: - loss += torch.mean(torch.abs(predicted_features - target_features)) - - return loss - - def normalize_input(self, x): - mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1) - std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1) - return (x - mean) / std \ No newline at end of file diff --git a/resblock.py b/resblock.py index 3d0badf..3a1fa2a 100644 --- a/resblock.py +++ b/resblock.py @@ -157,6 +157,8 @@ def forward(self, x): out = self.activation(out) return out + +# this is only used on the densefeatureencoder class DownConvResBlock(nn.Module): def __init__(self, in_channels, out_channels, dropout_rate=0.1, activation=nn.ReLU, norm_type='batch', use_residual_scaling=False): @@ -192,6 +194,7 @@ def forward(self, x): out = self.feat_res_block2(out) return out +# this is used on the framedecoder / enhancedframedecoder class UpConvResBlock(nn.Module): 
def __init__(self, in_channels, out_channels, dropout_rate=0.1, activation=nn.ReLU, norm_type='batch', upsample_mode='nearest', use_residual_scaling=False): @@ -241,7 +244,7 @@ def init_weights(m): # Apply weight initialization def initialize_model(model): model.apply(init_weights) - + F class ResBlock(nn.Module): def __init__(self, in_channels, out_channels, downsample=False, dropout_rate=0.1): super().__init__() diff --git a/vggloss.py b/vggloss.py deleted file mode 100644 index e112c32..0000000 --- a/vggloss.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch -from torch import nn - -import torchvision - -# https://github.com/CompVis/ipoke/blob/84ceecfefd7bf0769cd35334f2eb824c44c4d56b/utils/losses.py#L67 -class VGG(torch.nn.Module): - def __init__(self, requires_grad=False): - super().__init__() - vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features - self.mean = [0.485, 0.456, 0.406] - self.std = [0.229, 0.224, 0.225] - self.slice1 = torch.nn.Sequential() - self.slice2 = torch.nn.Sequential() - self.slice3 = torch.nn.Sequential() - self.slice4 = torch.nn.Sequential() - self.slice5 = torch.nn.Sequential() - for x in range(2): - self.slice1.add_module(str(x), vgg_pretrained_features[x]) - for x in range(2, 7): - self.slice2.add_module(str(x), vgg_pretrained_features[x]) - for x in range(7, 12): - self.slice3.add_module(str(x), vgg_pretrained_features[x]) - for x in range(12, 21): - self.slice4.add_module(str(x), vgg_pretrained_features[x]) - for x in range(21, 30): - self.slice5.add_module(str(x), vgg_pretrained_features[x]) - if not requires_grad: - for param in self.parameters(): - param.requires_grad = False - - def forward(self, X): - # X = self.normalize(X) - h_relu1 = self.slice1(X) - h_relu2 = self.slice2(h_relu1) - h_relu3 = self.slice3(h_relu2) - h_relu4 = self.slice4(h_relu3) - h_relu5 = self.slice5(h_relu4) - out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] - return out - - def normalize(self, x): - x = x.permute(1, 0, 2, 3) - for i in range(3): - x[i] = x[i] * self.std[i] + self.mean[i] - return x.permute(1, 0, 2, 3) - -def KL(mu, logvar): - return -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), axis=1)) - -def kl_conv(mu,logvar): - mu = mu.reshape(mu.size(0),-1) - logvar = logvar.reshape(logvar.size(0),-1) - - var = torch.exp(logvar) - - return torch.mean(0.5 * torch.sum(torch.pow(mu, 2) + var - 1.0 - logvar, dim=-1)) - -def fmap_loss(fmap1, fmap2, loss): - recp_loss = 0 - for idx in range(len(fmap1)): - if loss == 'l1': - recp_loss += torch.mean(torch.abs((fmap1[idx] - fmap2[idx]))) - if loss == 'l2': - recp_loss += torch.mean((fmap1[idx] - fmap2[idx]) ** 2) - return recp_loss / len(fmap1) - -class VGGLoss(nn.Module): - def __init__(self, weighted=False): - super(VGGLoss, self).__init__() - self.vgg = VGG().cuda() - self.weighted = weighted - self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] - self.criterion = nn.L1Loss() - - def forward(self, x, y): - fmap1, fmap2 = self.vgg(x), self.vgg(y) - if self.weighted: - recp_loss = 0 - for idx in range(len(fmap1)): - recp_loss += self.weights[idx] * self.criterion(fmap2[idx], fmap1[idx]) - return recp_loss - else: - return fmap_loss(fmap1, fmap2, loss='l1') \ No newline at end of file From c95838a2fdcf7c3438aeb347113a7205fb216683 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 16:42:46 +1000 Subject: [PATCH 129/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- framedecoder.py | 27 
+++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/framedecoder.py b/framedecoder.py index 4a3cf3b..681d79a 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -4,7 +4,7 @@ import math from resblock import UpConvResBlock,FeatResBlock -DEBUG = False +DEBUG = True def debug_print(*args, **kwargs): if DEBUG: print(*args, **kwargs) @@ -14,29 +14,34 @@ def __init__(self, use_attention=True): self.use_attention = use_attention + # Updated channel counts to match concatenated features self.upconv_blocks = nn.ModuleList([ - UpConvResBlock(512, 512), - UpConvResBlock(1024, 512), - UpConvResBlock(768, 256), - UpConvResBlock(384, 128), - UpConvResBlock(128, 64) + UpConvResBlock(512, 256), + UpConvResBlock(512, 256), + UpConvResBlock(512, 128), + UpConvResBlock(256, 64), + UpConvResBlock(128, 32) ]) + # Added an additional FeatResBlock for the last encoder features self.feat_blocks = nn.ModuleList([ nn.Sequential(*[FeatResBlock(512) for _ in range(3)]), nn.Sequential(*[FeatResBlock(256) for _ in range(3)]), - nn.Sequential(*[FeatResBlock(128) for _ in range(3)]) + nn.Sequential(*[FeatResBlock(128) for _ in range(3)]), + nn.Sequential(*[FeatResBlock(64) for _ in range(3)]) ]) if use_attention: + # Added an additional attention layer self.attention_layers = nn.ModuleList([ SelfAttention(512), SelfAttention(256), - SelfAttention(128) + SelfAttention(128), + SelfAttention(64) ]) self.final_conv = nn.Sequential( - nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1), + nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1), nn.Sigmoid() ) @@ -56,7 +61,7 @@ def forward(self, features): x = self.upconv_blocks[i](x) debug_print(f"After upconv_block {i+1}: {x.shape}") - if i < len(self.feat_blocks): + if i < len(features) - 1: # Process all encoder features debug_print(f"Processing feat_block {i+1}") feat_input = features[-(i+2)] debug_print(f"feat_block {i+1} input shape: {feat_input.shape}") @@ -76,8 +81,6 @@ def forward(self, features): debug_print(f"EnhancedFrameDecoder final output shape: {x.shape}") return x - - class SelfAttention(nn.Module): def __init__(self, in_dim): super().__init__() From bda32c2073e6e2f133951ed8a814f5b7e059fff9 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 16:42:51 +1000 Subject: [PATCH 130/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 10985 -> 11316 bytes __pycache__/resblock.cpython-311.pyc | Bin 44929 -> 44951 bytes 2 files changed, 0 insertions(+), 0 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 38d2052a4ae686b73579dbc440dc9572e1437168..5a36a9b8be0d1df5bee6aa1aa352af73dc308a0d 100644 GIT binary patch delta 1652 zcmaJ>O>7%Q6yEWA*IwI+odm~8>^SSzPU?`P&0kAILR<(WhqN?FQ&36eI^K=#Vtboe zr%5AzBH+*qKjncaQjp4p9~A+S3u+H=;K-qBkU6kbB|xfB4oI#@g+x_gp53HBN;~#` z^WHb#o7s8q*=v*MpRfG1qQb_p@%xt-XPVtBm1aT?=4;3(vD+ETv*)iLcsT7p3O}0a zi3w!h%M)c|=Fw-YpKs)OMSv;(xhnAuE?0$HmWN^ihj~wxbxT&-iiMlJi(r&2C(NG^ zjyuPm@_31hN}D#JWAMGn1Lw>;8cp}=a_$Wk$0VJp$iK+ekJ1 zB2B|B$+8Cz{6@X7cO`i~$;?&nV1d1aBT|%vR@-$>@Y&Z9aHqOrCFiIwau&&X4G!Bv zfvc{@Gt;ZCK*1GY^6p9|4>_4UNrR8aasmNH@ZP}L&efXWPtM?FXE1-D<^$pjVfZkH zSKLl>TurB=(a#C(1>W7^4+^`B+=E~go^@wPd;UB30`bduCZPeG>@>(GTH3t@t(`F8 z>FkX#(-0O_s_BN1h-EUdFi(4s6jRx#lFHCt=GV3pi+4ZHmMJzW?MILi z24KM(9>pvfWe(Rz6LUtPx&TD(+FzzG0^b`8!0znJCXSWX0payyAYkKDW+?@MP9(g>v+o0%ow)!8r0+sKRgvvP~jaOcVcDEA@A 
z2>tL|XxFQy?uanoux3;}nV>0!Qglr@%@7s!L@KUQ9SLbJtIy7`XmFs9N zwR=Z4OIu(f91@S9RD)N-JACzeXZ{DdCWf;B delta 1452 zcmaJ=TWB0r7@o6p+3fCYlI8dG_v};o&>wEPUx`BLt#+ z#KsWMYXZ|kyr@a`SJIp?%RdKq`9|UbSr~##!hOl-d4ntZuuN@fE?5_0gohhKgdBz& zjuzMzl0yV6>)nHgz-fyu(SXq4jb(KZHfQyP>)s0E==H>K=B@qgWi zL>k1v5!c<$E+&*+s+!5onWn0m$XGPGR!iT>l+_4qNfD>aMDtAoJF!sWhe-OOP^t_5 z98uO^xGe`S4{q|?a#u<2Lgled|F#@2$#K|n`{6UU7k+eyNh92Izo=B5LfXB<31QD2 z1kBP}l{vYs%89Oq+pyt@2bj-bJ@bh-FrRqC#Umv-_L~epdFIJ+xYFUbR9Pfs7Ig0w zo?Db)sx@GR>i;C<474@=PFld{^Gis~Y9H&o8OD7f(hH}3uakD`dtZQz4zWGjkC0+e z=5uTFO&9m zEKI0^9}n%o>~@JbN>~6#iE`s8ijnX zsv3EtsH)Fc(YCkw=7X3Lr;83FU}ZW2pNCU%WUMTh#;Q)Aw#D;?exdAK*6GqpQKu)- za}0cu<~>s=KFc5_Q|yRNz~hm#VI*X=`QO#E1$tqQYNxF0k!hkFL+7U$n0R##)#k4o zjR;e)SD7K*@Tziz^jTX{)#CEZ+@r-nwCQFRmxWl7V|8yTIR zLxXE2Z=PGxbw-%o#yA>z)5x=FDwy;XQ|U_#QeK;-vc!(WEKu4CHxm;i3nUpLMXDrhS?eT-g$aQEVI@0wV3IEYf^9)NV15MPyhI9?uDKe%J>-T9}-Bz)bS{R;?S BS)2d> diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index bb4cdd70251ff9e65c8345a18f7723339d8b2258..c9c05474510705da2473f93d762d7528a2c94a82 100644 GIT binary patch delta 915 zcmZ9Kdq~q!6vuadzu%^9t`rq>kCcquXt80a`B;%j86yc_WkF(PwV||rF7rjYG1C!+ za17~XwlY&GlhcF#NGPJPe{9envS|HP5M*f&dJxpTv-+d^_sVB^oeZe*%JZUVm4LdBQz~fG6V& z)RdktLKr79-l!gPxfUa8?y3LX~2?Jpf$lrQ(3P?rjF;NDP=c{z)(&%tfcDPtDtV@ zX*iZ)#GZobc^jeqUyY-Asi5y*&I4!9@G-UCtd^5_L~=9?Bl%U)?b5NbAPMZ`D)0fg zakj7qigByG7YZq2_lQu55*DV|@n`a5&nxanFMz@OW670o~o5o=` z1vgiSD!!h%=UWqTiNgZ*ROr|(YND!Z!8KG`1}*fh^s5MdoGVL(0IJImK^L`_S1Fkk z+OeETpLh77pIWLH3*e=mT@_mJ;md;?HOF`oM9hsW$%0aSB}zo7ZCUhy~A?;IQe delta 921 zcmZ9Ke@IhN6vubodvj}ZVx=rM&6R>18Z%+mW+pDvVrZmhsgzk+&6MezKT1ocLtIg* z93_-Ah0@d@$0anx$dW?;DZxngkAJ1qA7N!h|8?&Z`=iU_^X~oL?>Xn8-oEOodK0Nx3XBy|QG4zH6((p7;p1Xi|z!A;uj0r4>*&gC-Vv8TzDN7#4-ju zLp?(VcIYB;%)AveIBlL#yP3%FQ-%g>7-HgKaI?&~V%$Tb_m%=QVRModBFR*eLWpT% z{(OpnTy$59H37cLRUhEpwUq(E?QAFuKqjsxw?PbX?HdABBeOH;G1^)LjsG&gv_^oQ zCA-wy!^T{iNhO9cceoUS>$U>v%LcrX5(YNnNofJNg?wrWq~rP24zSXUTB1I9nulC; zrq5};G%z5qsjaFg+bGz=GWqbR-2>TV-kvJJSyW~{g<>+9mCpw;Ki%HuNCPLyb)4bV zl2X1QPvzu79SJ%1o7YQ;g$k#;MoeOxlA~n9X2^3v3+c!!kb@UjPwfUD8uG$1y3hq( zq`feN0}tsh&euRQ{1qO~GfuR{W-;vP&I{}M)$>1fOK0APg zEZ)?s==s+01g#;n>jk)oyxXaylB>^5r3ZbpFiI};KLolsKBPs%Kqd5$y94Lsw7PGP ztN?FG{BS=+vg3XBP?$K8y>d(-{$WcHJ^1Ny3cbKxBdu!I;!@Lzjp!d~36lPbmsm$n a0>onHSUhc_>9JG|z|J+ Date: Mon, 12 Aug 2024 16:43:08 +1000 Subject: [PATCH 131/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/loss.cpython-311.pyc | Bin 10304 -> 3586 bytes train.py | 1 - 2 files changed, 1 deletion(-) diff --git a/__pycache__/loss.cpython-311.pyc b/__pycache__/loss.cpython-311.pyc index df21e34ead641eaa2066834679512a562f042222..6f51feefcfb365803a961d0ad4458dbf9d99384b 100644 GIT binary patch delta 857 zcmaJ<%WD%s7~k1!UK{hUt+6zlWV_ZR3nc{FfJhNUD}vaEf(H?nwL49hZg#^=wmvQu z^$$?y*o!xbP(0-1souq_hryeto_q5k&bLYQ5<0NI+3)wge&5V|-FmInKW4M4Bv#|o zm))OnN%{rGWXa>7Gv|H%UfYSil^$mtMc@S|E->pP1m+y{P%@LJ5W8U_o^kNguw(Uz z>-TqLwvFb7pBH2_&y_B={ejHoW{byyuuq}9uCs4)nSC#{P>%g6wdFI3wY4W&Lik7$ z0OXBBNfwA4JJuQ~&wgkt0@dbm@?PJ%syns?NFKCtdySo|IQm^ArhY zW1x@})Ja1Yy2BNMy>MC^Q#VkNtygaD|2G*`m zTEb(ncr9#$I5R(*DAfE2H7cyoSdJ$D3!BBj$^ZZW literal 10304 zcmeHNeQXqGdY{=Z&+d-xwRa6Rw(%GeU=PQ}fZ;G34B^XAFgP}x;smCfo$<2vdTriW z2fSSi6y;PmLakE~;aou7R1s$%QM)}wx>UMKX?1Py&lxlo)<~8jRZ+`-j>t(BfAx9Z z*&Xk!Z9>vsa;iFh_nG%=-jC;fpO4@B`sbCEJ_eqJOMgH7YKUR}9TUaJU5Gr%L*x!4 zF%p|(hRioR<`{CYX52aE8gdol?lDIU^2a=06+5?;%O~6)bc!|R;5(&4l`6GaX_mjFVKMG16m`o15DWe7#r&gvxZ-i 
z-<*t#6C{yJ#|mB0SFyt<+aUR;Fh!XeVkKtCA-QLmVHSqz9db%8Xwo&qTxW;e(5%M@ zo;iQ&{J?pc#N>(eWHfnZTu}^HG$~&nj1^nN=3Vf^^b;M(`|KSi%|fK`q#eZ=!(^HJ zka}R{gi54}FHdWLU2#hO_rS5b_Y zzU@67V=j58osuK#6uDxpTRw>aUO17v&Xk-!5%wBxWpY9$1p78PQB0%~X;C!Xl=p@^ zIgyx93^!$+;ioBaY+RC)hHoNDqGNJeCW=u>@ zI8MFuOFycanEg^)7z}XO*=~^K`NUy?B#!+c9qJZ!g!--mh+XkDy7&8}15u}*!5ZCrtGy`Jkx zbC%6I%JhpR$2iX7W0(4Ach;G8XDVSHPTRO@O88Zqb>~X-PO+tOT_puSbgkZhxBkV} zZYb+~?zSAOR>=w4sID)w#kYop3n#|OwJ4E<^Pn-hpVpj8O-!Z*gB$|rL27AKrgL#z6#CS3auPXe;H6}+>hAWYl$HE>`gIVl#gPk&b zY3xVDz6~~E_{=$BQbD=EVG05Pe}XY8hE48g8rN8w_Cg^xt^5rz4pnbZ4RveWL7h9O zatCR$XP*6}LJjpm$VWeon=y?$taFD|?(m|Y|BJ47yZ&tdt^KpncOx^Axz4}p`dQZp z`|s}0MSm8Vi!9YPsm+%^Z~fO7|LMhl-2R*Gkial_;nCo04+me<24B|)UoT`fU)E|x zy;f9t@r$|*Grpyo4YOC@tG`{p%<$gxtVw5%Eb$xF`klG{+|`c;AYulMZ`1iUm2X?D z4$OAGJ2^AC6so&@HrJZp`q^pq;_GVY4K4JB9(rSzUAN#%Z1_~Z_0#R}g77)}n+h|g z@uzkEw921e_Ar|^KWYd+YzWV{&i88#9eP8@qlUc?8}@1q`}Br=Gp85%4Rb%z_!nkQ z)Ass`1MO4m`|_Dz-T37VZGW%6zZa5b(D-vYe@><0OI(3h=DYv4=NCQN&WOG<0!cGy z+)btNc4eWjcM0;ak6x9sq^!MLtR0vp_$mYHV}I(XWA z3|h7xYV9Dk+9w1b5PU)h^$Ed&E%6E6M$N^OC+!~LV64=O9&ZM6B3w;y7m`Cr(4{v~frlaqfERDY*q~S_p4i~$UNJ?O74Udvc#AJn_Admo% zu$!7ZvkNCQO=_`-YAFkKn3{^Bi%qL^fQ2o71RpAUK(!UIMN2NEaocom+as>yA=jaC zojTX4a-BtNQ5CwC%t>n1POWOEj$!YNb7@2U?Xe}k>ektNdvng8^MB2MRP}L{2|n_U zd`;e=H||mSz3WwB^R0QO-ndugyVkF=CGXN3yHtMPdQ}|x-Fdg(xKHKxKSz~izvo5Y zSBz()|7#?GoBYcGX2Yg$0g|Ymwt*z%HdUrNcx?XZvYlebJG>2rYBos`s^kP)Yr%xpq8e+({O^7^<^vU<$RX} zZMLEmo&#LxO4Ii|_LcyMt^>8oM*=lHYXhll+2~m2)+hjGe<@(Lnf5ZJZyynGWMJBB zgSeI>kacIhE9O1xDNWg-=ro`8XL$-U{We|&OL!H^@;S@gSX&(;_T4VktqC&+0m%B- z0J0lOfb7Ps|GW1v%ij9HHuidevd=r-GM5Zi520*BUql3giUYg3)b2G8;^^c4Qb(p8$Od)D8+6H^s&o-Yh;7or2;_ zMFDS;9r*or{Jv74V6ae-sV`tqI14PCj}}V8l@y4QFn);!k?cg$hU6tA?LZ8#oSGby z2}005Onage6LLy2oJlaz70FJDIms?4g7y$Ll1?PMk?cXT7l`4zE?Pc-Z)zp(y)MQ~ zTy1!#tSCDI=opPZ>7pMeajvu@1w9*+0R?K>d&Do9hE)qxgtG^rSosuIpa5rotZ@f) z?tsc2_ydBoO#o-VVvzpeaCX~{`Pw-jz}BrVE52X%m+Ppd$hWpdfm=t z##M=D14seeH*dcCS}v1s{dhV*@_EgtV_Neuz4_Rz4_ciY(E>YWeP7}r`{q06`xo{= z1cD8f(;Udu0_XI=IhBSlzB8l;Z2vnC;|-~ylX>5-s()Fn?K-0GIs!>EXk4$(^{QNN zDX3=$$$r?yB!XMH1d~(G8wrZ>YD`{8YVc(U^TGEo+#bp~-xF_(YUl+svy__tc*!(YCW{MaO}74dCtH|m z15TA(h9-!JBnqT-c0_`V=bG0Q{1Popab6KM7@2pzfBg38Ipw``x6eVcEYL|Uz-5*) zIkw<3cp}5`Ca@PyiWP$&j*iM!%1g`W(WjRd&|^qspKyEU> zB9N&Bv-dZrC0}VvI1*|lY&+?d!AS7F85ENvThnF1iZnMg$mc*O*+ zaYQ&IBjXcO;vO+7szuK0k@H%pUk~+Tu{}IPXnqLxk?SxUlS96Y z3@rSgvIyH#4wkTJR1J=*KN+%Eq({#Dc0h~t>ydsf^r{|u70d16*^(G2Unr8rKJj-7 zh-IWG=PxCI|5|WF2O1dB10w~2q+MQyDx+F(R0kRu)dQnzm?=wTL(`}csG0f&M8Gq0 zNbD#zV3q z>LTq4`DFPf{-Rup!zuDMl>7<&lq8Uw%reJtHM3_QR$Mjd4(Is3mAI+M-1HOkumO)@ z43BvOVBRO9k6O+<>|V8q<@P{ooWw@ z%vRMN7MT`xov_F>tM*{EyzFXW0V%B}b6f6)zry^lR|;QY_haSNWX^Hd`xWMYy;7KF bJ!~u7nifgGiqX9Cq@O(uuy$=qSMa|9M&rG% diff --git a/train.py b/train.py index e8b2839..c2c9930 100644 --- a/train.py +++ b/train.py @@ -22,7 +22,6 @@ from loss import wasserstein_loss,hinge_loss,vanilla_gan_loss,gan_loss_fn from torch.optim.lr_scheduler import ReduceLROnPlateau import random -from vggloss import VGGLoss from stylegan import EMA from torch.optim import AdamW, SGD from transformers import Adafactor From 5553095588a0312f8dc972041879f9a1a9a897a9 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 16:44:54 +1000 Subject: [PATCH 132/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- framedecoder.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/framedecoder.py b/framedecoder.py index 681d79a..7681278 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ 
-14,34 +14,32 @@ def __init__(self, use_attention=True): self.use_attention = use_attention - # Updated channel counts to match concatenated features + # Adjusted channel counts to handle concatenated features self.upconv_blocks = nn.ModuleList([ - UpConvResBlock(512, 256), - UpConvResBlock(512, 256), - UpConvResBlock(512, 128), - UpConvResBlock(256, 64), - UpConvResBlock(128, 32) + UpConvResBlock(512, 512), + UpConvResBlock(1024, 512), # 512 + 512 = 1024 input channels + UpConvResBlock(768, 256), # 512 + 256 = 768 input channels + UpConvResBlock(384, 128), # 256 + 128 = 384 input channels + UpConvResBlock(256, 64) # 128 + 128 = 256 input channels ]) - # Added an additional FeatResBlock for the last encoder features self.feat_blocks = nn.ModuleList([ nn.Sequential(*[FeatResBlock(512) for _ in range(3)]), nn.Sequential(*[FeatResBlock(256) for _ in range(3)]), nn.Sequential(*[FeatResBlock(128) for _ in range(3)]), - nn.Sequential(*[FeatResBlock(64) for _ in range(3)]) + nn.Sequential(*[FeatResBlock(128) for _ in range(3)]) # Changed from 64 to 128 ]) if use_attention: - # Added an additional attention layer self.attention_layers = nn.ModuleList([ SelfAttention(512), SelfAttention(256), SelfAttention(128), - SelfAttention(64) + SelfAttention(128) # Changed from 64 to 128 ]) self.final_conv = nn.Sequential( - nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1), + nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1), nn.Sigmoid() ) @@ -61,7 +59,7 @@ def forward(self, features): x = self.upconv_blocks[i](x) debug_print(f"After upconv_block {i+1}: {x.shape}") - if i < len(features) - 1: # Process all encoder features + if i < len(features) - 1: # Process all encoder features except the last one debug_print(f"Processing feat_block {i+1}") feat_input = features[-(i+2)] debug_print(f"feat_block {i+1} input shape: {feat_input.shape}") From 1aedf66671ddb898e6c1922165d84270e8192d15 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 16:47:08 +1000 Subject: [PATCH 133/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 11316 -> 11292 bytes framedecoder.py | 45 +++++++++++------------ 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 5a36a9b8be0d1df5bee6aa1aa352af73dc308a0d..49654ed512de4513bff7e21963378382903e5c44 100644 GIT binary patch delta 587 zcmaiwO(+FX6vywGxie#C-tflD!y6xC%J;@sBnz^Th28>1C<~L`TgkAY^xi^Q$i~@N zyksY(q$^^BvQtXQ&cgC#;d7_t&ahCv&aK})_kaH9bno38&I}uyhM^H26FY~KOS{Ik z0D!hw8qg9&m-3S(O0FF2)Kzdy9Ku&I3@V0b8!pq&%+wdjtm%*X2sfykp>Jx2rhu+W z7RUi`P#V$|SLMrajqP0Am*$+jc$Nr+BnkB)5|SmrvouFG&F5#k@LcIbl*hrsXE_}! 
z;G!4hnBi^2wmtt*UNHwA%WZ-e!>OFOvjtwhum9<8#uG(?puzUa{nYau)S7C=;=-m<&qG?8C z|I+gfcq$iQr)fbW8#NmRXkaUmd!gKsS|7-7=mp$}mav0pLVy8g7bArcwpTh0(9f#M z&jd)ayNXtcf0Q@WCI_;uW|W!ym#slu z3@A_}4kU_!3Kia=f~lA3JGVa+H#`2xG2ye4Cj5J(KSN;5V&kOH`JVFJoE<$(%{ z6@Wwo!v$%32-+Yx*Jp*$MH!tdGCFX<&B~m$jFWe9@lW2vB{KOimlC7U8 zUcfB|#Qf|@7X=ir2qpEQ|QQfBiP zMNw8@G;pcUVU(P_T)m#Lc(c03Cq~As%~Dz!j3yq8jDjB+u#+>4=2(6N%Y8vs5j1(5 Gb{YWrY@e3^ diff --git a/framedecoder.py b/framedecoder.py index 7681278..6566b5a 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -14,28 +14,25 @@ def __init__(self, use_attention=True): self.use_attention = use_attention - # Adjusted channel counts to handle concatenated features self.upconv_blocks = nn.ModuleList([ UpConvResBlock(512, 512), - UpConvResBlock(1024, 512), # 512 + 512 = 1024 input channels - UpConvResBlock(768, 256), # 512 + 256 = 768 input channels - UpConvResBlock(384, 128), # 256 + 128 = 384 input channels - UpConvResBlock(256, 64) # 128 + 128 = 256 input channels + UpConvResBlock(1024, 512), + UpConvResBlock(768, 256), + UpConvResBlock(384, 128), + UpConvResBlock(128, 64) ]) self.feat_blocks = nn.ModuleList([ nn.Sequential(*[FeatResBlock(512) for _ in range(3)]), nn.Sequential(*[FeatResBlock(256) for _ in range(3)]), - nn.Sequential(*[FeatResBlock(128) for _ in range(3)]), - nn.Sequential(*[FeatResBlock(128) for _ in range(3)]) # Changed from 64 to 128 + nn.Sequential(*[FeatResBlock(128) for _ in range(3)]) ]) if use_attention: self.attention_layers = nn.ModuleList([ SelfAttention(512), SelfAttention(256), - SelfAttention(128), - SelfAttention(128) # Changed from 64 to 128 + SelfAttention(128) ]) self.final_conv = nn.Sequential( @@ -59,26 +56,28 @@ def forward(self, features): x = self.upconv_blocks[i](x) debug_print(f"After upconv_block {i+1}: {x.shape}") - if i < len(features) - 1: # Process all encoder features except the last one - debug_print(f"Processing feat_block {i+1}") - feat_input = features[-(i+2)] - debug_print(f"feat_block {i+1} input shape: {feat_input.shape}") - feat = self.feat_blocks[i](feat_input) - debug_print(f"feat_block {i+1} output shape: {feat.shape}") - - if self.use_attention: - feat = self.attention_layers[i](feat) - debug_print(f"After attention {i+1}, feat shape: {feat.shape}") - - debug_print(f"Concatenating: x {x.shape} and feat {feat.shape}") - x = torch.cat([x, feat], dim=1) - debug_print(f"After concatenation: {x.shape}") + # if i < len(self.feat_blocks): + debug_print(f"Processing feat_block {i+1}") + feat_input = features[-(i+2)] + debug_print(f"feat_block {i+1} input shape: {feat_input.shape}") + feat = self.feat_blocks[i](feat_input) + debug_print(f"feat_block {i+1} output shape: {feat.shape}") + + if self.use_attention: + feat = self.attention_layers[i](feat) + debug_print(f"After attention {i+1}, feat shape: {feat.shape}") + + debug_print(f"Concatenating: x {x.shape} and feat {feat.shape}") + x = torch.cat([x, feat], dim=1) + debug_print(f"After concatenation: {x.shape}") debug_print("\nApplying final convolution") x = self.final_conv(x) debug_print(f"EnhancedFrameDecoder final output shape: {x.shape}") return x + + class SelfAttention(nn.Module): def __init__(self, in_dim): super().__init__() From 3a5886d191a8bb18e0d766d0766d1915db7fb652 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 16:52:47 +1000 Subject: [PATCH 134/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 11292 -> 10919 bytes framedecoder.py | 
28 +++++++++++------------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 49654ed512de4513bff7e21963378382903e5c44..a8ad07a48f9e53d2a7f3038650f38b7d176f9c82 100644 GIT binary patch delta 1396 zcmZ`&UuauZ7|%&^Z_?x@y)kY6woTHcO_Qwu|BtAvdu}>M zAKVli1=%S3_E4l?2I4w8F(o1s%McMCbixhXhY~3GvWI~WDQkTY1kbru9OA-ve)o5N z=ljn6zVGJSdhO1f@s7cuQ;^@|FTYvbGw&K_lzO^Xk@e(F73i%oC@SwFw)DEnfs`o~ zJWw~HQFx%V!V~o@6%x#(o=b4|44*R!#SR0!yADL=0y07u`w%(cPxcq&gr7CDC;%?) zIrvFqgtnUBkq=zD2IPfdonQG80;AmtS9My|qed58*L~})YScm9p+eoDKccX%mbA7i zr{L^W1yfUd;A{PC5(;_?S$Yfc?KKkLmRjQ5QoBATah_ui9_mxbR_fzy2#vv{;buv1 zJVq#1LVOgV4j46kh?+`2nk;B|kd%wNDFg^qR2P$3?5Eox1s{b|^nl6d@WN6>wGhwa zP9g+Y+ygfAdA6Nwg!8awzQqn5DxRwug1Wl-E61$DU=wWvMN_eR*GvepGB+r3gCaM$ zZ}$|Uh4}5!LaSsCtiOD~Sw-8KLVO$DULb_neTNi&>4cWK35lB!xru#q>qehs_7Z;l7-b`!;woM`*D)S+U53NV#Ov6R=l}5CMiIusC#6?6dLI`|OpJBbU z+o!>8sYksqVCitVn1QO|S@0r)>lQ0~YGKHlMxj#Gav3!>(QkvWi$3(YncgU0+d^48 zT(foej}Tg6){;0URG3selZgjZxQB{tI(s3J&fs3M5$GVe-B&Rk3hag7Y)+-(K!MEPR<2K!QJ>4It?E*I?#M+t8o^&V$|mh0Ra(s;Uu#V6)DVs(KUxc zuwxlRU654!z$uijb+Q_J?8EAHLr_*xfkhgBfBZnBF~>RdqwE@x=ced7fDXzn zK!^UKJDfA}F91cY;jcU{Qen*#4qy=7lUdN3JcNFc+kyPLVpO4z>c2Fg3B|j>h#n{& z04utw{0SM=29#0_fmL)@brRUoj4C994s=fy1dbe|rh%ggA6wLf!;$#3AcVsL3C-!e z3`y9~jl7y>u#V1XT#7v<<#3|Qn!D(+rruW4C|b0;l)_f=kib>BsIDy;DBZpir9`K- z2N0*NsKG`)12%FSv60(&bx_nb>>1Dv?F<;qJ=WF&FokM#*X0-Klm=yM>kLuVBKG^1 zlqr0a%jot1IGg*_FaUrT?HE4*&Rm440lnQMI3Ix*Fcj%Hn}8l79wb0gf^B3&kxIhI zY>|#kr=S-rJWA+C-nei=4OFAW5nlnJOv}BGHEM20dD=pu?L>wLe_wVDF--&*A9xr$n zm=Tc~$ulE6RTiBlVHp0TlS0Aj;u+v`m%!sYqx0aP7iLI>+{ZutT!86_hEwP z{pjZTGAo!vqB)c|hkhl_#y=z zp`o8Bm0Thm<#>1q3;(&-ShH9YV&k!aREkZcxMZR&9Oe>SDjW_G>njwR(pYmF7M`x| z!3roN&`ZEcAb`58o>#~{2nRK3kYem=#*31Zn#Ht08bz6gejDj^w5 z@zV>CIDw}ycyv;%@FlcnJqL!7&Q=Gea?Q37s6R$*MllHF8R#!}WGhh;m_Rq}lb{v- zW*-EdIlm(YKri~FzEN2&656b91^wvv`ld!w3pht0PC%-L95d%LZQkSD54hYL&Spj3 z({uyG+)BVnpcj4S?l>j2UO$$LYMvElXCN1akchHWvZ#!*7q|!u1)|W*Cxp3qmc`YP z91jpBEpQ3kk4XW(hDFF@@aS<eeXR#pYG2qCaj&k0^DBaWm=5k**iO;D_CpnM6 zEc*Pwq=hgKQ!c++dIs7O3D}HG9=Gy!Vwyy~o_3d1PwCggM5!Y{p1wcTy?in{A7}gE gN9da8eg6;urV{o4xZHWGdso_f_^ibVG|&?J8wH!F!vFvP diff --git a/framedecoder.py b/framedecoder.py index 6566b5a..f43b75f 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -56,20 +56,20 @@ def forward(self, features): x = self.upconv_blocks[i](x) debug_print(f"After upconv_block {i+1}: {x.shape}") - # if i < len(self.feat_blocks): - debug_print(f"Processing feat_block {i+1}") - feat_input = features[-(i+2)] - debug_print(f"feat_block {i+1} input shape: {feat_input.shape}") - feat = self.feat_blocks[i](feat_input) - debug_print(f"feat_block {i+1} output shape: {feat.shape}") - - if self.use_attention: - feat = self.attention_layers[i](feat) - debug_print(f"After attention {i+1}, feat shape: {feat.shape}") - - debug_print(f"Concatenating: x {x.shape} and feat {feat.shape}") - x = torch.cat([x, feat], dim=1) - debug_print(f"After concatenation: {x.shape}") + if i < len(self.feat_blocks): + debug_print(f"Processing feat_block {i+1}") + feat_input = features[-(i+2)] + debug_print(f"feat_block {i+1} input shape: {feat_input.shape}") + feat = self.feat_blocks[i](feat_input) + debug_print(f"feat_block {i+1} output shape: {feat.shape}") + + if self.use_attention: + feat = self.attention_layers[i](feat) + debug_print(f"After attention {i+1}, feat shape: {feat.shape}") + + debug_print(f"Concatenating: x {x.shape} and feat {feat.shape}") + x = torch.cat([x, feat], dim=1) + debug_print(f"After concatenation: {x.shape}") 
debug_print("\nApplying final convolution") x = self.final_conv(x) From 79eeb4841f0b6ad0a119a7337bc0c1fb306d9eb7 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 16:56:03 +1000 Subject: [PATCH 135/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/framedecoder.cpython-311.pyc | Bin 10919 -> 10985 bytes framedecoder.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index a8ad07a48f9e53d2a7f3038650f38b7d176f9c82..05988a6af3409e32666c4a575e4e4704fd4caaab 100644 GIT binary patch delta 422 zcmZ1;`ZAPvIWI340}wQB*qQcjBky}IMvu*Zxg;1FVfk`lg@rJbV43jI;1{(ygNE>f( zxhQRUMcQ(L%Uu!q3kpFS3^xRBGzKz&=s?PmEC}nONbnVr;0q$bPeGRXZ(!WQ4sqQ9 zlOs+*zk%Egcj#Ra$tj*Q3Kyu%$XpS4K}7$ei2fB3{Rxaaj1rswa!D{Ux=a@1nKW6Ck9YG^o)?UaZkx~X zRWLKgOm-KV!e;%EL5S0O@`mbyZyb>2A%r*#aFx)5%WB}1c9`h?a<`;O(Z%8UG za9ohM(0PH{6-mtwue&_r7bLtFB(7jwm<(h9(GICAmYxviMIN6kJU$nAd~OJl^T)>*$V4B@7NVx4#*;B6tq=D!nkNXuK_X|Ak5BP3c(8k+7|`1uLx*&ID%bj exq)#D`vnP0AiBt7b%n?30*@7tvAIelUkm`j8IGy| diff --git a/framedecoder.py b/framedecoder.py index f43b75f..4a3cf3b 100644 --- a/framedecoder.py +++ b/framedecoder.py @@ -4,7 +4,7 @@ import math from resblock import UpConvResBlock,FeatResBlock -DEBUG = True +DEBUG = False def debug_print(*args, **kwargs): if DEBUG: print(*args, **kwargs) From 839b2ab3d49f45dbd32e9e6465e2f4a32fc63745 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 17:02:45 +1000 Subject: [PATCH 136/152] ok --- __pycache__/framedecoder.cpython-311.pyc | Bin 10985 -> 10985 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 05988a6af3409e32666c4a575e4e4704fd4caaab..84be0fb555c947719c50e1eca66b3f6b1e95774b 100644 GIT binary patch delta 35 qcmaDE`ZAPvIWI340}%Y#xHIkjMqXA%Mz_tvj2fDZypwIT)&c<6XbR^5 delta 35 pcmaDE`ZAPvIWI340}wQB*qQcjBQGl>W5{M^YXQwt3U>ei From 163b5741c3243bda9c77ddb6067e2d1ab9016765 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 17:26:03 +1000 Subject: [PATCH 137/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/config.yaml b/config.yaml index 367ed8c..560de51 100644 --- a/config.yaml +++ b/config.yaml @@ -1,6 +1,6 @@ # Loss function loss: - type: "vanilla" # Changed to Wasserstein loss for WGAN-GP + type: "hinge" # vanilla worked ok - https://wandb.ai/snoozie/IMF/runs/sa0i7gs6?nw=nwusersnoozie # Model parameters model: @@ -13,14 +13,16 @@ model: use_skip: False # Training parameters training: + use_multiscale_discriminator: True + use_ema: False ema_decay: 0.999 style_mixing_prob: 0.9 initial_noise_magnitude: 0.01 final_noise_magnitude: 0.0001 - use_multiscale_discriminator: False + initial_video_repeat: 5 final_video_repeat: 2 - use_ema: False + use_r1_reg: True batch_size: 1 # need to redo emodataset to remove the cache npz - smaller numbers won't work... 
num_epochs: 1000 From 2272789cd9d98732509ff0ab5021e121e1fce5d0 Mon Sep 17 00:00:00 2001 From: John Pope Date: Mon, 12 Aug 2024 17:33:35 +1000 Subject: [PATCH 138/152] ok --- loss.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/loss.py b/loss.py index 5de9717..33e5435 100644 --- a/loss.py +++ b/loss.py @@ -5,7 +5,42 @@ from model import debug_print import lpips +# Wasserstein Loss: +# Pros: +# Provides a meaningful distance metric between distributions. +# Often leads to more stable training and better convergence. +# Can help prevent mode collapse. + +# Cons: + +# Requires careful weight clipping or gradient penalty implementation for Lipschitz constraint. +# May converge slower than other losses in some cases. + + +# Hinge Loss: +# Pros: + +# Often results in sharper and more realistic images. +# Good stability in training, especially for complex architectures. +# Works well with spectral normalization. + +# Cons: + +# May be sensitive to outliers. +# Can sometimes lead to more constrained generator outputs. + + +# Vanilla GAN Loss: +# Pros: + +# Simple and straightforward implementation. +# Works well for many standard GAN applications. + +# Cons: + +# Can suffer from vanishing gradients and mode collapse. +# Often less stable than Wasserstein or Hinge loss, especially for complex models. def wasserstein_loss(real_outputs, fake_outputs): """ From 19502651d145f0a4204d8b240630d264305a1d8f Mon Sep 17 00:00:00 2001 From: John Pope Date: Tue, 13 Aug 2024 10:56:44 +1000 Subject: [PATCH 139/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/EMODataset.cpython-311.pyc | Bin 23135 -> 23130 bytes __pycache__/framedecoder.cpython-311.pyc | Bin 10985 -> 10980 bytes __pycache__/resblock.cpython-311.pyc | Bin 44951 -> 44946 bytes __pycache__/stylegan.cpython-311.pyc | Bin 6627 -> 6622 bytes __pycache__/vit_mlgffn.cpython-311.pyc | Bin 23160 -> 23155 bytes 5 files changed, 0 insertions(+), 0 deletions(-) diff --git a/__pycache__/EMODataset.cpython-311.pyc b/__pycache__/EMODataset.cpython-311.pyc index e638473279edab6fb5ac560b1415e509dd873c17..19e52bb0595fe1633e054a8e9a05ab820cf5ab42 100644 GIT binary patch delta 36 rcmcb=h4I!FM(*Xjyj%=G;JA3#M(#FN4jKL2)RfFb{mt`O4~75$*USsd delta 41 vcmcb$h4KCtM(*Xjyj%=GV9v9CBX=7sucCf#YD#9Jetv4MzTxJXtOr8?1Vs$S diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 84be0fb555c947719c50e1eca66b3f6b1e95774b..41d27d35f67d2cb05aa21dc32062eb0afb6259f9 100644 GIT binary patch delta 34 ocmaDE`XrQlIWI340}wba-nEgtmXSkFKQ}ccGf{u@RK^fZ0K38pZ~y=R delta 39 tcmaD7`ZAPzIWI340}%Y#xN{?SEhDe8er{??W}<$6YOcQF<_U};ng9vB4Bh|$ diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index c9c05474510705da2473f93d762d7528a2c94a82..24f3e21eee82fe84bfa3b0c13d1392d376016f3a 100644 GIT binary patch delta 36 qcmbP!pJ~#4Chq0Dyj%=G;JA3#M(z++4k`WI)RfFb{mnV7+N%J|4hrr7 delta 41 wcmbPqpK1DiChq0Dyj%=G@MqP|jocxuyz=_FsVSL>`uVB3`i7g+S+!RI02#{-RR910 diff --git a/__pycache__/stylegan.cpython-311.pyc b/__pycache__/stylegan.cpython-311.pyc index 1feb2386a31566607a1ec7f9d1bb46185fb718fa..14abf6853e5fdbb152e03091467d6c3ba2fae0e4 100644 GIT binary patch delta 34 ocmaECe9xGBIWI340}wba-nEgtjgdo2KQ}ccGf{u@JVsXu0I~-P4gdfE delta 39 tcmca-{MeX#IWI340}w>S@7T!Q#>gwLpPQPJnW&$inyYWPc_yQ)1OVl^3xxmx diff --git a/__pycache__/vit_mlgffn.cpython-311.pyc b/__pycache__/vit_mlgffn.cpython-311.pyc index 
6abf115401ad788bdc67364083affe539d774170..009a5663f0c77b9b3957483b906cacbf48be21b1 100644 GIT binary patch delta 36 rcmeydh4J$iM(*Xjyj%=G;JA3#M(!R~4jKL2)RfFb{mn~QuY>>q-zN+m delta 41 wcmeyoh4IH0M(*Xjyj%=G@K1L8M(!R~UPb-f)RfFb{ruEieZ$T3S+9fu03%lo_W%F@ From b7078bf0ac620d22ac8e91401edc1489decb6d69 Mon Sep 17 00:00:00 2001 From: John Pope Date: Tue, 13 Aug 2024 10:57:16 +1000 Subject: [PATCH 140/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __pycache__/helper.cpython-311.pyc | Bin 23843 -> 23838 bytes __pycache__/loss.cpython-311.pyc | Bin 3586 -> 3582 bytes train.py | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 3c4b811aae70360c07570c13dcb05c9eb06877c9..6e14c9083d246a5cb62129234f634a32a4cbace7 100644 GIT binary patch delta 36 qcmZ3yi*epAM(*Xjyj%=Gz_WhWMs7PU4hj9-)RfFb{mo%qXM+I1WeT?d delta 41 vcmbQYi*fNTM(*Xjyj%=Gpq00CBexwFudIGb diff --git a/__pycache__/loss.cpython-311.pyc b/__pycache__/loss.cpython-311.pyc index 6f51feefcfb365803a961d0ad4458dbf9d99384b..5d21b2533c6f62d4a43c51205c547a6a4046c2a2 100644 GIT binary patch delta 115 zcmV-(0F3{F9R3>)rwt7X0000046nLoZV9mt$N>lzFKuOHX<;vu;{iVbF_R+#*8wt< z`vR8%EwhsY)ByoGvuFhd0RcLbZw592HM6D$^Z@}uvzrJa0Rcjjz6o~$J+n3n_yGY$ Vvo8!60Rgm=M-6NOP63m&4Nj!-FKuOHX<;vKWo<7plimS80S=QM z0@nc#llB6a0SdE@1JnTl7PDgo1_1#VlWGPw0TQ#J2J`^|9 Date: Tue, 13 Aug 2024 11:08:16 +1000 Subject: [PATCH 141/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model.py | 101 ++++++++++++++++++++++++++++++++----------------------- train.py | 84 ++------------------------------------------- 2 files changed, 60 insertions(+), 125 deletions(-) diff --git a/model.py b/model.py index 4c5d705..03da4ce 100644 --- a/model.py +++ b/model.py @@ -14,6 +14,7 @@ import math import random import colored_traceback.auto # makes terminal show color coded output when crash +from helper import normalize from framedecoder import EnhancedFrameDecoder DEBUG = False def debug_print(*args, **kwargs): @@ -382,6 +383,21 @@ def forward(self, token, condition): +''' +DenseFeatureEncoder (EF): Encodes the reference frame into multi-scale features. +LatentTokenEncoder (ET): Encodes both the current and reference frames into latent tokens. +LatentTokenDecoder (IMFD): Decodes the latent tokens into motion features. +ImplicitMotionAlignment (IMFA): Aligns the reference features to the current frame using the motion features. + +The forward pass: + +Encodes the reference frame using the dense feature encoder. +Encodes both current and reference frames into latent tokens. +Decodes the latent tokens into motion features. +For each scale, aligns the reference features to the current frame using the ImplicitMotionAlignment module. +''' + + ''' DenseFeatureEncoder (EF): Encodes the reference frame into multi-scale features. LatentTokenEncoder (ET): Encodes both the current and reference frames into latent tokens. 
@@ -427,7 +443,7 @@ def __init__(self,use_resnet_feature=False,use_mlgffn=False,use_skip=False,use_e self.implicit_motion_alignment.append(model) FrameDecode = EnhancedFrameDecoder if use_enhanced_generator else FrameDecoder - self.frame_decoder = FrameDecode() #CIPSFrameDecoder(feature_dims=self.motion_dims) + self.frame_decoder = FrameDecode() self.noise_level = noise_level self.style_mix_prob = style_mix_prob @@ -435,20 +451,8 @@ def __init__(self,use_resnet_feature=False,use_mlgffn=False,use_skip=False,use_e self.mapping_network = MappingNetwork(latent_dim, latent_dim, depth=8) self.noise_injection = NoiseInjection() - def add_noise(self, tensor): - return tensor + torch.randn_like(tensor) * self.noise_level - - def style_mixing(self, t_c, t_r): - device = t_c.device - if random.random() < self.style_mix_prob: - batch_size, token_dim = t_c.size() - mix_mask = (torch.rand(batch_size, token_dim, device=device) < 0.5).float() - t_c_mixed = t_c * mix_mask + t_r * (1 - mix_mask) - t_r_mixed = t_r * mix_mask + t_c * (1 - mix_mask) - return t_c_mixed, t_r_mixed - return t_c, t_r - def forward(self, x_current, x_reference): + def forward(self, x_current, x_reference,style_mixing_prob=0,noise_magnitude=0): x_current = x_current.requires_grad_() x_reference = x_reference.requires_grad_() @@ -459,41 +463,61 @@ def forward(self, x_current, x_reference): t_r = self.latent_token_encoder(x_reference) t_c = self.latent_token_encoder(x_current) - # StyleGAN2-like mapping network - t_r = self.mapping_network(t_r) - t_c = self.mapping_network(t_c) - # Add noise to latent tokens - t_r = self.add_noise(t_r) - t_c = self.add_noise(t_c) + noise_r = torch.randn_like(t_r) * noise_magnitude + noise_c = torch.randn_like(t_c) * noise_magnitude + t_r = t_r + noise_r + t_c = t_c + noise_c - # Apply style mixing - t_c, t_r = self.style_mixing(t_c, t_r) + # StyleGAN2-like mapping network + # t_r = self.mapping_network(t_r) + # t_c = self.mapping_network(t_c) + + # Style mixing (optional, based on probability) + if torch.rand(()).item() < style_mixing_prob: + batch_size = t_c.size(0) + rand_indices = torch.randperm(batch_size) + rand_t_c = t_c[rand_indices] + rand_t_r = t_r[rand_indices] + + # print(f"rand_t_c shape: {rand_t_c.shape}") + # print(f"rand_t_r shape: {rand_t_r.shape}") + + # Create a mask for mixing + mix_mask = torch.rand(batch_size, 1, device=t_c.device) < 0.5 + mix_mask = mix_mask.float() + + # print(f"mix_mask shape: {mix_mask.shape}") + + # Mix the tokens + mix_t_c = t_c * mix_mask + rand_t_c * (1 - mix_mask) + mix_t_r = t_r * mix_mask + rand_t_r * (1 - mix_mask) + else: + # print(f"no mixing...") + mix_t_c = t_c + mix_t_r = t_r # Latent token decoding - m_r = self.latent_token_decoder(t_r) - m_c = self.latent_token_decoder(t_c) + m_r = self.latent_token_decoder(mix_t_c) + m_c = self.latent_token_decoder(mix_t_r) - # Implicit motion alignment with noise injection + # Implicit motion alignment aligned_features = [] for i in range(len(self.implicit_motion_alignment)): f_r_i = f_r[i] - m_r_i = self.noise_injection(m_r[i]) - m_c_i = self.noise_injection(m_c[i]) align_layer = self.implicit_motion_alignment[i] + m_c_i = m_c[i] + m_r_i = m_r[i] aligned_feature = align_layer(m_c_i, m_r_i, f_r_i) aligned_features.append(aligned_feature) # Frame decoding - reconstructed_frame = self.frame_decoder(aligned_features) + x_reconstructed = self.frame_decoder(aligned_features) + x_reconstructed = normalize(x_reconstructed) # 🤷 images are washed out - or over saturated... 
- return reconstructed_frame, { - 'dense_features': f_r, - 'latent_tokens': (t_c, t_r), - 'motion_features': (m_c, m_r), - 'aligned_features': aligned_features - } + return x_reconstructed + def set_noise_level(self, noise_level): self.noise_level = noise_level @@ -501,15 +525,6 @@ def set_noise_level(self, noise_level): def set_style_mix_prob(self, style_mix_prob): self.style_mix_prob = style_mix_prob - def process_tokens(self, t_c, t_r): - if isinstance(t_c, list) and isinstance(t_r, list): - m_c = [self.latent_token_decoder(tc) for tc in t_c] - m_r = [self.latent_token_decoder(tr) for tr in t_r] - else: - m_c = self.latent_token_decoder(t_c) - m_r = self.latent_token_decoder(t_r) - - return m_c, m_r class MappingNetwork(nn.Module): def __init__(self, latent_dim, w_dim, depth): diff --git a/train.py b/train.py index f27b7f7..41fc321 100644 --- a/train.py +++ b/train.py @@ -159,88 +159,8 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato continue # Skip when current frame is the reference frame x_current = source_frames[:, current_idx] - - # A. Forward Pass - # 1. Dense Feature Encoding - f_r = model.dense_feature_encoder(x_reference) - - # 2. Latent Token Encoding (with noise addition) - t_r = model.latent_token_encoder(x_reference) - t_c = model.latent_token_encoder(x_current) - - - - # Add noise to latent tokens - noise_r = torch.randn_like(t_r) * noise_magnitude - noise_c = torch.randn_like(t_c) * noise_magnitude - t_r = t_r + noise_r - t_c = t_c + noise_c - - # Style mixing (optional, based on probability) - # Style mixing (optional, based on probability) - # print(f"Original t_c shape: {t_c.shape}") - # print(f"Original t_r shape: {t_r.shape}") - - if torch.rand(()).item() < style_mixing_prob: - batch_size = t_c.size(0) - rand_indices = torch.randperm(batch_size) - rand_t_c = t_c[rand_indices] - rand_t_r = t_r[rand_indices] - - # print(f"rand_t_c shape: {rand_t_c.shape}") - # print(f"rand_t_r shape: {rand_t_r.shape}") - - # Create a mask for mixing - mix_mask = torch.rand(batch_size, 1, device=t_c.device) < 0.5 - mix_mask = mix_mask.float() - - # print(f"mix_mask shape: {mix_mask.shape}") - - # Mix the tokens - mix_t_c = t_c * mix_mask + rand_t_c * (1 - mix_mask) - mix_t_r = t_r * mix_mask + rand_t_r * (1 - mix_mask) - else: - # print(f"no mixing...") - mix_t_c = t_c - mix_t_r = t_r - - # print(f"Final mix_t_c shape: {mix_t_c.shape}") - # print(f"Final mix_t_r shape: {mix_t_r.shape}") - - # Now use mix_t_c and mix_t_r for the rest of the processing - m_c = model.latent_token_decoder(mix_t_c) - m_r = model.latent_token_decoder(mix_t_r) - - - # Visualize latent tokens (do this every N batches to avoid overwhelming I/O) - # if batch_idx % config.logging.visualize_every == 0: - # os.makedirs(f"latent_visualizations/epoch_{epoch}", exist_ok=True) - # visualize_latent_token( - # t_r, # Visualize the first token in the batch - # f"latent_visualizations/epoch_{epoch}/t_r_token_reference_batch{batch_idx}.png" - # ) - # visualize_latent_token( - # m_c[0], # Visualize the first token in the batch - # f"latent_visualizations/epoch_{epoch}/m_c_token_current_batch{batch_idx}.png" - # ) - - - # 4. Implicit Motion Alignment - # Implicit Motion Alignment - aligned_features = [] - for i in range(len(model.implicit_motion_alignment)): - f_r_i = f_r[i] - align_layer = model.implicit_motion_alignment[i] - m_c_i = m_c[i] - m_r_i = m_r[i] - aligned_feature = align_layer(m_c_i, m_r_i, f_r_i) - aligned_features.append(aligned_feature) - - - # 5. 
Frame Decoding - x_reconstructed = model.frame_decoder(aligned_features) - x_reconstructed = normalize(x_reconstructed) # 🤷 images are washed out - or over saturated... - + x_reconstructed = model(x_current, x_reference, style_mixing_prob, noise_magnitude) + # B. Loss Calculation # 1. Pixel-wise Loss l_p = pixel_loss_fn(x_reconstructed, x_current).mean() From 547d0305d186c1cf9100ab346d9bc7d6d5b94508 Mon Sep 17 00:00:00 2001 From: John Pope Date: Tue, 13 Aug 2024 11:56:15 +1000 Subject: [PATCH 142/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 56 +++++++++++++++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/config.yaml b/config.yaml index 560de51..3f3b78e 100644 --- a/config.yaml +++ b/config.yaml @@ -1,7 +1,7 @@ # Loss function loss: - type: "hinge" # vanilla worked ok - https://wandb.ai/snoozie/IMF/runs/sa0i7gs6?nw=nwusersnoozie - + type: "vanilla" # vanilla worked ok - https://wandb.ai/snoozie/IMF/runs/sa0i7gs6?nw=nwusersnoozie + lpips_spatial: True # Model parameters model: latent_dim: 32 @@ -13,48 +13,54 @@ model: use_skip: False # Training parameters training: + use_r1_reg: True + r1_gamma: 10 + r1_interval: 16 + use_multiscale_discriminator: True + ada_augmentation: False + ada_target_r_t: 0.6 + ada_kimg: 500 + ada_interval: 4 + use_ema: False ema_decay: 0.999 style_mixing_prob: 0.9 initial_noise_magnitude: 0.01 final_noise_magnitude: 0.0001 + initial_update_ratio: 1 + update_ratio_adjustment: 0.1 + + + initial_video_repeat: 5 final_video_repeat: 2 - use_r1_reg: True + batch_size: 1 # need to redo emodataset to remove the cache npz - smaller numbers won't work... num_epochs: 1000 - save_steps: 250 + learning_rate_g: 1.0e-4 # Reduced learning rate for generator - initial_learning_rate_d: 1.0e-4 # Set a lower initial learning rate for discriminator - # learning_rate_g: 5.0e-4 # Increased learning rate for generator - # learning_rate_d: 5.0e-4 # Increased learning rate for discriminator + learning_rate_d: 4e-4 gradient_accumulation_steps: 1 - lambda_pixel: 10 # in paper lambda-pixel = 10 Adjust this value as needed + lambda_perceptual: 10 # lambda perceptual = 10 - lambda_adv: 1 # adverserial = 1 + lambda_pixel: 5 # in paper lambda-pixel = 10 + lambda_adv: 0.5 # 1 # Can sometimes lead to artifacts if weighted too heavily lambda_gp: 10 # Gradient penalty coefficient - lambda_mse: 1.0 - n_critic: 2 # Number of discriminator updates per generator update + + clip_grad_norm: 0.75 # Maximum norm for gradient clipping - r1_gamma: 10 - r1_interval: 16 - label_smoothing: 0.1 - min_learning_rate_d: 1.0e-6 - max_learning_rate_d: 1.0e-3 - d_lr_adjust_frequency: 100 # Adjust D learning rate every 100 steps - d_lr_adjust_factor: 2.0 # Factor to increase/decrease D learning rate - target_d_loss_ratio: 0.6 # Target ratio of D loss to G loss + + every_xref_frames: 16 use_many_xrefs: False - scales: [1, 0.5, 0.25, 0.125] - enable_xformers_memory_efficient_attention: True - learning_rate_d: 1.0e-4 + + weight_decay: 1e-5 - lte_learning_rate: 1.0e-5 + # Dataset parameters dataset: # celeb-hq torrent https://github.com/johndpope/MegaPortrait-hack/tree/main/junk @@ -75,6 +81,7 @@ logging: visualize_every: 100 # Visualize latent tokens every 100 batches print_model_details: False log_every: 100 + save_steps: 250 # Accelerator settings accelerator: mixed_precision: "fp16" # Options: "no", "fp16", "bf16" @@ -87,8 +94,7 @@ discriminator: # Optimizer parameters 
optimizer: - beta1: 0.5 + beta1: 0.0 beta2: 0.999 - From 40f1fcee8a2f3a7ee366fe292be3ae1283e73a33 Mon Sep 17 00:00:00 2001 From: John Pope Date: Tue, 13 Aug 2024 12:12:12 +1000 Subject: [PATCH 143/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index 3f3b78e..81f5032 100644 --- a/config.yaml +++ b/config.yaml @@ -23,7 +23,7 @@ training: ada_kimg: 500 ada_interval: 4 - use_ema: False + use_ema: True ema_decay: 0.999 style_mixing_prob: 0.9 initial_noise_magnitude: 0.01 From 622dd487c3a8dc65b6668f968c294dbd506c0437 Mon Sep 17 00:00:00 2001 From: John Pope Date: Tue, 13 Aug 2024 12:36:13 +1000 Subject: [PATCH 144/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index 81f5032..ca4e7c5 100644 --- a/config.yaml +++ b/config.yaml @@ -7,7 +7,7 @@ model: latent_dim: 32 base_channels: 64 num_layers: 4 - use_resnet_feature: False + use_resnet_feature: True use_mlgffn: False use_enhanced_generator: True use_skip: False From e57677a96074f99fde8f36ed2eecdefc67d5d57d Mon Sep 17 00:00:00 2001 From: John Pope Date: Tue, 13 Aug 2024 12:39:11 +1000 Subject: [PATCH 145/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index ca4e7c5..81f5032 100644 --- a/config.yaml +++ b/config.yaml @@ -7,7 +7,7 @@ model: latent_dim: 32 base_channels: 64 num_layers: 4 - use_resnet_feature: True + use_resnet_feature: False use_mlgffn: False use_enhanced_generator: True use_skip: False From 1b4021c276b704011cf69a8e5c4ec08b1ad55162 Mon Sep 17 00:00:00 2001 From: John Pope Date: Tue, 13 Aug 2024 12:39:32 +1000 Subject: [PATCH 146/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config.yaml b/config.yaml index 81f5032..9997cf6 100644 --- a/config.yaml +++ b/config.yaml @@ -31,10 +31,10 @@ training: initial_update_ratio: 1 update_ratio_adjustment: 0.1 - + - initial_video_repeat: 5 + initial_video_repeat: 6 final_video_repeat: 2 From b8cdf5b52b46be4acbbc17906fd0437d0e3d99f3 Mon Sep 17 00:00:00 2001 From: John Pope Date: Tue, 13 Aug 2024 13:21:59 +1000 Subject: [PATCH 147/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/config.yaml b/config.yaml index 9997cf6..46fd900 100644 --- a/config.yaml +++ b/config.yaml @@ -9,21 +9,24 @@ model: num_layers: 4 use_resnet_feature: False use_mlgffn: False - use_enhanced_generator: True + use_enhanced_generator: False use_skip: False # Training parameters training: - use_r1_reg: True + + weight_decay: 1e-5 # lower maybe more stable?? 
+ + use_r1_reg: False r1_gamma: 10 r1_interval: 16 - use_multiscale_discriminator: True + use_multiscale_discriminator: False ada_augmentation: False ada_target_r_t: 0.6 ada_kimg: 500 ada_interval: 4 - use_ema: True + use_ema: False ema_decay: 0.999 style_mixing_prob: 0.9 initial_noise_magnitude: 0.01 @@ -59,7 +62,7 @@ training: use_many_xrefs: False - weight_decay: 1e-5 + # Dataset parameters dataset: From daa72756edd566c1dd45bf85619518577dcabd89 Mon Sep 17 00:00:00 2001 From: John Pope Date: Tue, 13 Aug 2024 13:33:13 +1000 Subject: [PATCH 148/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 3 +++ model.py | 14 +++++++++----- train.py | 5 +++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/config.yaml b/config.yaml index 46fd900..05e858e 100644 --- a/config.yaml +++ b/config.yaml @@ -1,3 +1,6 @@ +data: + image_size: 256 + # Loss function loss: type: "vanilla" # vanilla worked ok - https://wandb.ai/snoozie/IMF/runs/sa0i7gs6?nw=nwusersnoozie diff --git a/model.py b/model.py index 03da4ce..32774fa 100644 --- a/model.py +++ b/model.py @@ -412,16 +412,20 @@ def forward(self, token, condition): For each scale, aligns the reference features to the current frame using the ImplicitMotionAlignment module. ''' class IMFModel(nn.Module): - def __init__(self,use_resnet_feature=False,use_mlgffn=False,use_skip=False,use_enhanced_generator=False, latent_dim=32, base_channels=64, num_layers=4, noise_level=0.1, style_mix_prob=0.5): + def __init__(self,use_resnet_feature=False,use_mlgffn=False,use_skip=False,image_size=256,use_enhanced_generator=False, latent_dim=32, base_channels=64, num_layers=4, noise_level=0.1, style_mix_prob=0.5): super().__init__() - self.encoder_dims = [64, 128, 256, 512] + + + # Adjust your encoder dimensions based on the new image size + self.encoder_dims = [64, 128, 256, 512] if image_size == 256 else [32, 64, 128, 256] + self.output_channels = [128, 256, 512, 512,512, 512] if image_size == 256 else [64, 128, 256, 256,256, 256] self.latent_token_encoder = LatentTokenEncoder( - initial_channels=64, - output_channels=[128, 256, 512, 512,512, 512], + initial_channels=self.encoder_dims[0], + output_channels=self.output_channels, dm=32 ) - self.motion_dims = [128, 256, 512, 512] + self.motion_dims = [128, 256, 512, 512] if image_size == 256 else [64, 128, 256, 256] self.latent_token_decoder = LatentTokenDecoder() FeatureExtractor = ResNetFeatureExtractor if use_resnet_feature else DenseFeatureEncoder diff --git a/train.py b/train.py index 41fc321..16c8850 100644 --- a/train.py +++ b/train.py @@ -330,7 +330,8 @@ def main(): use_resnet_feature=config.model.use_resnet_feature, use_mlgffn=config.model.use_mlgffn, use_enhanced_generator=config.model.use_enhanced_generator, - use_skip=config.model.use_skip + use_skip=config.model.use_skip, + image_size=config.data.image_size ) add_gradient_hooks(model) @@ -342,7 +343,7 @@ def main(): add_gradient_hooks(discriminator) transform = transforms.Compose([ - transforms.Resize((256, 256)), + transforms.Resize((config.data.image_size, config.data.image_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) From 756ef4c40bc2f27b6081374602aee9cee5420307 Mon Sep 17 00:00:00 2001 From: John Pope Date: Tue, 13 Aug 2024 13:34:28 +1000 Subject: [PATCH 149/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index 05e858e..cf70ec1 100644 --- a/config.yaml +++ b/config.yaml @@ -17,7 +17,7 @@ model: # Training parameters training: - weight_decay: 1e-5 # lower maybe more stable?? + weight_decay: 1e-4 # lower maybe more stable?? use_r1_reg: False r1_gamma: 10 From 065553f2c8ea2ef24c80736d76be1ffeb8f65289 Mon Sep 17 00:00:00 2001 From: John Pope Date: Tue, 13 Aug 2024 13:42:04 +1000 Subject: [PATCH 150/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index cf70ec1..8d7fc89 100644 --- a/config.yaml +++ b/config.yaml @@ -19,7 +19,7 @@ training: weight_decay: 1e-4 # lower maybe more stable?? - use_r1_reg: False + use_r1_reg: True r1_gamma: 10 r1_interval: 16 From df9109027d5bf8eee8c2c96538963fc024f3ef9b Mon Sep 17 00:00:00 2001 From: John Pope Date: Tue, 13 Aug 2024 13:51:45 +1000 Subject: [PATCH 151/152] =?UTF-8?q?=F0=9F=A7=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index 8d7fc89..99a6cfd 100644 --- a/config.yaml +++ b/config.yaml @@ -31,7 +31,7 @@ training: use_ema: False ema_decay: 0.999 - style_mixing_prob: 0.9 + style_mixing_prob: 0.5 initial_noise_magnitude: 0.01 final_noise_magnitude: 0.0001 From 9009deafd9443a5df57d7b51ba355ce7533eee94 Mon Sep 17 00:00:00 2001 From: John Pope Date: Fri, 16 Aug 2024 16:22:06 +1000 Subject: [PATCH 152/152] BROKEN --- __pycache__/EMODataset.cpython-311.pyc | Bin 23130 -> 23130 bytes __pycache__/framedecoder.cpython-311.pyc | Bin 10980 -> 10980 bytes __pycache__/helper.cpython-311.pyc | Bin 23838 -> 25381 bytes __pycache__/loss.cpython-311.pyc | Bin 3582 -> 3582 bytes __pycache__/resblock.cpython-311.pyc | Bin 44946 -> 44946 bytes config.yaml | 8 +++--- helper.py | 30 +++++++++++++++++++++++ train.py | 16 ++++++++++-- 8 files changed, 48 insertions(+), 6 deletions(-) diff --git a/__pycache__/EMODataset.cpython-311.pyc b/__pycache__/EMODataset.cpython-311.pyc index 19e52bb0595fe1633e054a8e9a05ab820cf5ab42..6080b16e2b2f1d56c91d9d6340922ff362ae239a 100644 GIT binary patch delta 22 ccmcb$h4I!FM(*Xjyj%=G;1;lFBe#DP09G^xtN;K2 delta 22 ccmcb$h4I!FM(*Xjyj%=G;JA3#MsEKo09W}2_y7O^ diff --git a/__pycache__/framedecoder.cpython-311.pyc b/__pycache__/framedecoder.cpython-311.pyc index 41d27d35f67d2cb05aa21dc32062eb0afb6259f9..c1c4aee466244350198f398404eafa17a997336d 100644 GIT binary patch delta 20 acmaD7`XrQlIWI340}v?f-o26grWOE4=mwww delta 20 acmaD7`XrQlIWI340}wba-nEhYrWOE5EC!|k diff --git a/__pycache__/helper.cpython-311.pyc b/__pycache__/helper.cpython-311.pyc index 6e14c9083d246a5cb62129234f634a32a4cbace7..0b654a8f508c640a955bcad4e39269557a324e60 100644 GIT binary patch delta 4118 zcma(UX>1hNdEXwh>s_zc2N16@_V~cCU@$&04%p^0&OtDkVqzeh@y^&gusgGPGYhu0 z#%`5JDWN3fk~XNWswgxmsnjTKq{=2vf<~=Gt{+rYLlaf1sH*n+-mD$x zK+Wv#H{bi-cfap@@9p!ik)OOp@_rf&`W^86>(U#84V%yA6_d};uKIDW@Svl=SY9UQ zJV^S(vL8kzGLe?c0l-S-T!1U&+vGf$Et5kqDwm7p0vL&MST2Hhg}el)mGV-6RWj*y zL`qmOX&H~X&{*~n;Em%ien%XRryUu`(fJuihJ1-1b0XRB-|3LaOMqQi;Q9dp8IZLf z*|~_*^(~wZ$9R_Yu-oAnbB(#jJYzz}c^C|gdB=PiPsV%9gDh0Y2pOM^bHnbht1~&D zt?uYTP_hSgj6z0bVH751?@KUuVPT&6N~<@F<=j#jjDA!2@zUDyF@MG(`!YH3@uK2~ zoEblSd^SGl3frCl91CQ~cjU~?1TwjUPFS2ANrg1}MR>OBnxPt|qM3;i(MS)hGo)l{ 
zJfVo@30(xTp;JRVp_*~gj4L8cDk7CMSx<`k;Gm(H>$w&Vw$dS^jbpZ4y|~)cX)L~e zKUWiJ6T4Jun4+qs(q>jSXc3LLlv2di+iJz_wPJUz*cTaJq^02-T6&e3uF2m)4c7uO zMKPv?5;fFOuw$kH2WNIDhDsHAA!vx%W{B60izgynn1g+anWox&Xeiw#_GUAc00qYt zYN(J#jAcG@NKqDV;5tTY?pPGHAtuw1h`4kd@;skbo`&mCsgBX{#;AR1{d}2Mstdg5WZi^1^eoxGt&4s#Mpoe_P!h zySwV*N+PAu^{EkSDU?pO-cdGN8MF*k6vEF~0%@3VjNd4BiVYWc%!W42 zg*JTj$ANS3+B6LmhLWj*0GT zh2hDn$>H<9i%X~SX3AS$^<6H0d*z#DGp&1OO7_eZ_DtM$^_DE*1)ghxWs`xKK-G-9 z>I=w^WBML)(^o)O0D9BEIRQIIOi!lTStvI`>e-}!wUd^!bAkP9s)1w5NpWo~KbDOa zO{m8dS_1?F3r-Zd78+rD^J+;G(*w1HikT{tL~&wqh;8g-UJY3OKe7$qA{)3J&2muA zdfJz;-q6qzJ3Q4Q2fvhHe+spd?aZ6sR9cSS_)XoE648Wi7`)4A2P(>8efg|_FOA!H zfSR-ssh#ZK`PF1A6AQv#8bZoSwzc3t%eJ8+a+I#dw^Oc%UjyCA_7@&6+Kyap2(|z) zRsvuj7M8Iag>HXmXD5%|DuOH)3G@+|tedVZI_oCo)4#fHimbo@C~Bvt0Az#6R>clB z8eU6IvX{aKA!Yv!*N{h9O-bbtYN3Y^oJD}AvgJzXLlJ={kdQ>6A$S784*^&{n2*Y8 z%rpcT+Pj49S|QG0a-s>SP$LHC5PceWM%XJQkAR!EFFzmzzp~Ek`HHZB)em+1dPNO+ zk$qB8Q;E5bIIriz>pI$?-2jH*2hX*R2%#xfTR9~3z>w}`&s1jaL+>J$R*)(Wq~R1Y zqC-@YEjL0s2y>Q88Zaz(Qjs*i7M3SPapQF*rD4l|KuV;Qos{a79zmV=aftzm(mos> zWJXne9onaH1S$f&U@c0TTTdw(JqAS3scNwL?^dlRPcmoq#$vwG`B=zO0**|XJ7>^S zY;Se3@D%Df$MouDVa$US#2ljE(ZPg%g1!Z`)6>sazu|=Z9a&Z2tnEaBpChZwFlBl# z4ta>~qq?d^y!5w7xs2fV?AeGYUI5ha0fOl;d@2IY57${ z)eKY8Vv6N9M^Xy*B)idg+x%Juo7R$lvQ14bP)(_(;~U-sMN`D`gA6qdyx&Y^yLbX- zEOPwosU{VpJ^>>F`l^;%M8w66J7?M#Sk|8tfg^%5e&Y-UY&5v96mC0Uv%#9{S&sD{Sdk9 zm@)$AvVxZo1vPC`?Gn0ap7#5Rz$1T~{ko%y1ew(#)&|fWEEf7l0AIDM|v0vte>yOudagk9}A!jn3hKOCiaTQp3Xj)Bu6eU$Cq zb>~vdr{&f3Y}?_zkG{%Y-ql#d-3KAT?TfcwB+uS* z`dM>#uZI=^1AC&os?pvr6(bJ_*$Q4<$X0IlHp(5Nbr4D2^k2I#x=7dbb9cYzAzkc& zeIoe-dwk#KeS9^rA80)$hVRPQJNCKiQ4LMfV5`tZ`1)8tP_;qYi5xtMuORj+0&F|< zn2Eg`yze4~5Z2dwS2HTUrLP4^T~p!i$MdAZp9F($1^y$`AN1;iAKvVKbLO0x zZ_b=KXYM`vu6X7RQS^LaVSWxh1;^ePKe6&)QB-_)u<6J^#LRm@Y>%3Y9?01qGlfxM z=95%u1_`UoGP97Z)n*a(YRqyoM7>%wYKE!R%o5VpnWcpFrr46B$FN4M$c*OxmWTO) z3N%7pq6D7`HG6)_!84()twCCT&LjP42gY{Ed@yX&jwcg)EGQ(4>mkrI2vGyY*jwBx z$}mY7=E+q&8A=(hJUCd+HU&JP6gQR(G|hP~Fzeu)2O`u;uRE0@xrk+9F06=W!z)Au z{vBSa7UQ{tbrsB4#ZW`AFqRb>!E7Tsp_;M@%eMATLQf^@>lnfW_~TV&XgAd^QC>6X zUbQOJk}UnU1F25eOR}czzjVh{xUaO)_kSstU4bH4O{}n(A&isP#qie1XlZuC2`Ppy z&cOE~*NVmHE4!wO*OU2aCv7C-Nyl}8g9bXm1%X@3uvi7p=iUivQ1mVMYuVLe309ZK z@}w`3wo*`lYs&vDYh@;}8?NHkV|LPi8MNcg(dpn)=DeDr3C~5_MCkEY%s)f@;!I8K zpihJ|$E)YWVs;gn2Brz--JQ>hHqy*CCY=d9A88OH_+0G}9oX-+jbb0xYxNtL4f+}G zVek`3FAdiNFvN;shV2ad8Kwwi9@*n&Vl3?jsH>?5Linf_^&cR`4t!C2gnV0Bw=JL& z7%n+YDJEoI%Gev90K>c(S>`Ur*P9x}0X*N-SkK|rJr@e7xxkCSa)Jr^({sL87jQc^ zH%|msQy04NiRQbP?k7>NmxW;6Ih6p*RSQTTF9Vy&CcVa}D}5=;uwfnc=v9&RJb8mM zjT!(qaQ8-Z^v?DzEN*4Em0=sh#qju?Ny~=YNGZLkgbgkDq24C$!@u;EQFVwR&Q&9s zn6=F8HenhETB2R|GtWZ=F4to=ZdtruCA2K$z_Q2j@uV{YCux+OIn?r|hoU*WwAjLYuZyJy14}Nt7Qx4c^<%U5-I+vs z*W7&BCdnej^yLI?L{J~(+!=FHlRe+$VDIYe4HCo0R&T6+i#>ar;a!3&%q?A*`z?<; z%Mix4?xWET7CF=A1dLGQ3IeJE{I+|!=jR;M)D^=9C3C@=w!pQZ12?VN zUHJjo)O`5xlG&{M2+yt2gMYGkjvdD15PsjYL!8F-y_RUjW4-MngkSbX{Xf&-cQSwU-b58m_fKhNWbGyqQ2TiW zPxdv42!7h9wf@4yeE4vR;Ii1rMGGkhzmqYD_5F_qc(>4zIo0nE#Q3u9*S*qbgEJ#@ zCk@j~*b_7uKi;%c#cnpVHxAaAW4mdX8Y90o*oW%}Hbm4-mTNon73ET1!x4O8V0nq! 
zD%Xt+SbNi!ZNfeH!$5Q8eXJ_jJY}a7DXR~_0lg@Do%CVn;FbcJKWe1McEyMp@`rxk3GV`mBc<`WVj>uG;6-k!A-*(V@&LF zP0Sb2LLK4p;c|Z)S$i|@4?A3;-M9CNuFsh$pd5fYcB-CX2{Sg3B>iJa+8EIWI340}y=PwQnQ$XI=nFa|bX0 delta 20 acmew-{ZE>EIWI340}$}6-?fqZGcN!_%mz9D diff --git a/__pycache__/resblock.cpython-311.pyc b/__pycache__/resblock.cpython-311.pyc index 24f3e21eee82fe84bfa3b0c13d1392d376016f3a..5c92feff1ab0b004e962691a64b25e30bac9b3f4 100644 GIT binary patch delta 22 ccmbPqpJ~#4Chq0Dyj%=GptO7UM(+0Y08^v~M*si- delta 22 ccmbPqpJ~#4Chq0Dyj%=G;JA3#M(+0Y08{7(O#lD@ diff --git a/config.yaml b/config.yaml index 99a6cfd..69f0d15 100644 --- a/config.yaml +++ b/config.yaml @@ -12,11 +12,11 @@ model: num_layers: 4 use_resnet_feature: False use_mlgffn: False - use_enhanced_generator: False + use_enhanced_generator: True use_skip: False # Training parameters training: - + use_subsampling: False # saves ram? https://github.com/johndpope/MegaPortrait-hack/issues/41 weight_decay: 1e-4 # lower maybe more stable?? use_r1_reg: True @@ -29,7 +29,7 @@ training: ada_kimg: 500 ada_interval: 4 - use_ema: False + use_ema: True ema_decay: 0.999 style_mixing_prob: 0.5 initial_noise_magnitude: 0.01 @@ -90,7 +90,7 @@ logging: save_steps: 250 # Accelerator settings accelerator: - mixed_precision: "fp16" # Options: "no", "fp16", "bf16" + mixed_precision: "no" # Options: "no", "fp16", "bf16" cpu: false num_processes: 1 # Set to more than 1 for multi-GPU training diff --git a/helper.py b/helper.py index 5b6f8a6..a951ba4 100644 --- a/helper.py +++ b/helper.py @@ -14,6 +14,36 @@ from PIL import Image from mpl_toolkits.mplot3d import Axes3D + +def consistent_sub_sample(tensor1, tensor2, sub_sample_size): + """ + Consistently sub-sample two tensors with the same random offset. + + Args: + tensor1 (torch.Tensor): First input tensor of shape (B, C, H, W) + tensor2 (torch.Tensor): Second input tensor of shape (B, C, H, W) + sub_sample_size (tuple): Desired sub-sample size (h, w) + + Returns: + tuple: Sub-sampled versions of tensor1 and tensor2 + """ + assert tensor1.shape == tensor2.shape, "Input tensors must have the same shape" + assert tensor1.ndim == 4, "Input tensors should have 4 dimensions (B, C, H, W)" + + batch_size, channels, height, width = tensor1.shape + sub_h, sub_w = sub_sample_size + + assert height >= sub_h and width >= sub_w, "Sub-sample size should not exceed the tensor dimensions" + + offset_x = torch.randint(0, height - sub_h + 1, (1,)).item() + offset_y = torch.randint(0, width - sub_w + 1, (1,)).item() + + tensor1_sub = tensor1[..., offset_x:offset_x+sub_h, offset_y:offset_y+sub_w] + tensor2_sub = tensor2[..., offset_x:offset_x+sub_h, offset_y:offset_y+sub_w] + + return tensor1_sub, tensor2_sub + + def plot_loss_landscape(model, loss_fns, dataloader, num_points=20, alpha=1.0): # Store original parameters original_params = [p.clone() for p in model.parameters()] diff --git a/train.py b/train.py index 16c8850..880e2e3 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,7 @@ from VideoDataset import VideoDataset from EMODataset import EMODataset,gpu_padded_collate from torchvision.utils import save_image -from helper import log_loss_landscape,log_grad_flow,count_model_params,normalize, add_gradient_hooks, sample_recon +from helper import consistent_sub_sample,log_grad_flow,count_model_params,normalize, add_gradient_hooks, sample_recon from torch.optim import AdamW from omegaconf import OmegaConf import lpips @@ -25,6 +25,9 @@ from stylegan import EMA from torch.optim import AdamW, SGD from transformers import Adafactor +from torchvision.utils 
import save_image + + def load_config(config_path): return OmegaConf.load(config_path) @@ -160,7 +163,15 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato x_current = source_frames[:, current_idx] x_reconstructed = model(x_current, x_reference, style_mixing_prob, noise_magnitude) - + + save_image(x_reconstructed, "x_reconstructed.png", normalize=True) + + # Sub-sample tensors + if config.training.use_subsampling: + sub_sample_size = (128, 128) # As mentioned in the paper + x_current, x_reconstructed = consistent_sub_sample(x_current, x_reconstructed, sub_sample_size) + + # B. Loss Calculation # 1. Pixel-wise Loss l_p = pixel_loss_fn(x_reconstructed, x_current).mean() @@ -218,6 +229,7 @@ def train(config, model, discriminator, train_dataloader, val_loader, accelerato total_g_loss += g_loss.item() total_d_loss += d_loss.item() + progress_bar.update(1) progress_bar.set_postfix({"G Loss": f"{g_loss.item():.4f}", "D Loss": f"{d_loss.item():.4f}"})
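
A minimal smoke test for the consistent_sub_sample helper introduced in PATCH 152, written as a sketch that assumes helper.py from this series is importable and uses random tensors in place of real frames. The frame size follows config.data.image_size = 256 (PATCH 148) and the crop size follows the 128x128 used in train.py when training.use_subsampling is enabled.

import torch
from helper import consistent_sub_sample

x_current = torch.randn(2, 3, 256, 256)        # stand-in for a batch of current frames
x_reconstructed = torch.randn(2, 3, 256, 256)  # stand-in for the generator output

# Both tensors are cropped at the same random offset, so the pixel and LPIPS
# losses still compare corresponding regions after sub-sampling.
cur_sub, rec_sub = consistent_sub_sample(x_current, x_reconstructed, (128, 128))
assert cur_sub.shape == (2, 3, 128, 128)
assert rec_sub.shape == (2, 3, 128, 128)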
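
The loss comparison comments added to loss.py in PATCH 138, together with the loss.type switch in config.yaml between "vanilla" and "hinge", refer to the standard discriminator and generator objectives. A generic sketch of those two variants is given below; the signatures are illustrative and are not guaranteed to match this repo's loss.py, which exposes wasserstein_loss, hinge_loss, vanilla_gan_loss and gan_loss_fn.

import torch
import torch.nn.functional as F

def hinge_d_loss(real_logits, fake_logits):
    # Discriminator: push real logits above +1 and fake logits below -1.
    return F.relu(1.0 - real_logits).mean() + F.relu(1.0 + fake_logits).mean()

def hinge_g_loss(fake_logits):
    # Generator: raise the discriminator's score on generated samples.
    return -fake_logits.mean()

def vanilla_d_loss(real_logits, fake_logits):
    # Non-saturating BCE-with-logits form of the original GAN objective.
    real_term = F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
    fake_term = F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
    return real_term + fake_term

def vanilla_g_loss(fake_logits):
    return F.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))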
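
The use_r1_reg, r1_gamma and r1_interval settings, toggled back and forth in the later config.yaml patches (for example PATCH 147 and PATCH 150), control an R1 regulariser: a zero-centered gradient penalty on real samples that is typically applied lazily every r1_interval discriminator steps. The sketch below shows the penalty term only, assuming the discriminator returns a plain logits tensor, which may not hold for the multiscale discriminator used in this repo.

import torch

def r1_penalty(discriminator, x_real, gamma=10.0):
    x_real = x_real.detach().requires_grad_(True)
    real_logits = discriminator(x_real)
    grads, = torch.autograd.grad(outputs=real_logits.sum(), inputs=x_real, create_graph=True)
    # gamma corresponds to training.r1_gamma in config.yaml; the caller is expected
    # to add this term to the discriminator loss only every r1_interval steps.
    return (gamma / 2.0) * grads.pow(2).reshape(grads.shape[0], -1).sum(1).mean()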