From 2f575461dbea1c8ecd5fc09a447669bc10392a85 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 3 May 2024 10:34:09 -0700 Subject: [PATCH] remove groupnorms https://arxiv.org/abs/2312.02696 --- make_a_video_pytorch/make_a_video.py | 10 ++++------ setup.py | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/make_a_video_pytorch/make_a_video.py b/make_a_video_pytorch/make_a_video.py index ae671a9..3e1a6c9 100644 --- a/make_a_video_pytorch/make_a_video.py +++ b/make_a_video_pytorch/make_a_video.py @@ -346,12 +346,11 @@ def __init__( dim, dim_out, kernel_size = 3, - temporal_kernel_size = None, - groups = 8 + temporal_kernel_size = None ): super().__init__() self.project = PseudoConv3d(dim, dim_out, 3) - self.norm = nn.GroupNorm(groups, dim_out) + self.norm = RMSNorm(dim_out) self.act = nn.SiLU() def forward( @@ -376,7 +375,6 @@ def __init__( dim_out, *, timestep_cond_dim = None, - groups = 8 ): super().__init__() @@ -388,8 +386,8 @@ def __init__( nn.Linear(timestep_cond_dim, dim_out * 2) ) - self.block1 = Block(dim, dim_out, groups = groups) - self.block2 = Block(dim_out, dim_out, groups = groups) + self.block1 = Block(dim, dim_out) + self.block2 = Block(dim_out, dim_out) self.res_conv = PseudoConv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward( diff --git a/setup.py b/setup.py index a02f92d..1d3b330 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'make-a-video-pytorch', packages = find_packages(exclude=[]), - version = '0.3.1', + version = '0.4.0', license='MIT', description = 'Make-A-Video - Pytorch', author = 'Phil Wang',