SRGAN - Bypass check_array_lengths(X, Y, W) in training.py for different input and output batch sizes #3940
So I finally got the full model to work by duplicating most of the code in _make_train_function, fit and _fit_loop (and commenting out the check_array_lengths function). It compiles fully and can be trained without errors; however, the loss function must be a single dummy loss which accepts (y_true, y_pred) and returns a single Keras variable with value 0. So binary_crossentropy cannot be applied to the discriminator output and must be replaced by the AdversarialLossRegularizer. Still, training proceeds properly even though the assertion "Input arrays should have the same number of samples as target arrays. Found 8 input samples and 16 target samples." would normally fail.

However, when I begin pre-training the generator network with just the generator and the VGG loss, the generated output is extremely pixelated, as the two images below show. In fact, the generator and VGG model combined do not even trigger the check_array_lengths error, since I am freezing the VGG weights and training only via the ContentVGGRegularizer and TVRegularizer (using the dummy loss when compiling so that training is driven entirely by the regularizers).

The generator is trained on the MS COCO dataset (a random sample of 50k images) instead of the 50k validation set of ImageNet, since I don't have the validation set. The following two images are downscaled from their original size to 32x32 for input to the generator model and to 128x128 for validation against the generated results. The generator is pre-trained with only the VGG 2,2 loss as described in the paper, and upscales the 32x32 input to 128x128. I am using 32x32 as the input size instead of 96x96 because my GPU cannot handle such a large model combined with such a large input (for 96x96, the full model contains nearly 303 million parameters).

I can tell that the models are learning something: the dots are far more pronounced in the first ~1000 validation images, and as the loss drops steadily, the output images become clearer and contain fewer RGB dots. However, even at the end of 50k iterations, the dots do not disappear in either the pre-training network or the full SR + GAN + VGG model.

At iteration 23300 (using the SR + VGG model, for pre-training):

At iteration 50000 (using the full SR + GAN + VGG model, for final training):

As can be seen, the RGB dots form a grid pattern throughout the upscaled image. The same dots appear when the last two Deconvolution layers are replaced with Upscale + Convolution, although such a network produces very poor images. Perhaps the error lies in trying to upscale the original image 4x without providing the network any information about what data should fill the upsampled space. I'll have to look into the sub-pixel convolution layer mentioned in the paper "Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" to see if it can upscale the image properly without producing these RGB grids.

The following images use UpSampling + Convolution layers in place of the Deconvolution layers with stride 2.

At iteration 21000 (using the SR + VGG model, for pre-training):

At iteration 45000 (using the full SR + GAN + VGG model, for final training):

Since the original issue is resolved simply by commenting out one crucial check, the issue is more or less solved. I understand the need to assert that the input batch size be the same as the output batch size, but if the model is carefully constructed then the user can still train the model correctly.
I would appreciate a better way to do this that doesn't involve hackishly bypassing internal Keras code. The code defining the generator + VGG model is simple enough, if anyone wishes to test.
EDIT: I didn't mean to close the issue. I am still looking for better ways to train such models without hacking around Keras internals.
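For reference, a minimal sketch of the dummy loss described above (Keras 1.x with the Theano backend; the model name srgan_model is illustrative):

```python
import keras.backend as K

def dummy_loss(y_true, y_pred):
    # Return a constant zero so the only gradient signal comes from the
    # activity regularizers (VGG / adversarial / TV) attached inside the model.
    return K.variable(0.0)

# Hypothetical usage: compile the combined model with the dummy loss so that
# binary_crossentropy is never applied to the mismatched discriminator output.
srgan_model.compile(optimizer='adam', loss=dummy_loss)
```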
I found a bug in my normalization code and have fixed it. The images are still pixelated in grids; however, it is no longer an RGB grid but instead a pale grid of 4 colors. Sample images are below.

I believe this image distortion is mainly due to the network not having any idea of what it should fill into the upscaled space. To test this idea, I changed the architecture to be more like an auto-encoder, providing blurred and distorted (affine-transformed) inputs of size 128x128 and passing high-resolution images to Input 3 (the VGG network) to pre-train. The results are significantly better, as can be seen below. The model uses only the VGG 2,2 regularizer loss to train, by passing a dummy loss function to the optimizer. This suggests that the error is no longer in my normalization layer, but in the SRGAN model itself.

In hindsight, I am just wondering how the Twitter team managed to avoid this pixel-grid problem without passing any information about what the upscaled output should look like. They may have used a custom sub-pixel convolution layer which upsamples the previous layer using convolutional filter information, which is what I will be trying to implement next as well.

I am also facing another challenge in training the discriminator model: I can't use binary crossentropy to train it (it raises the same array-length assertion). This is not ideal, as GANs must alternate updates between the discriminator and generator, and due to the input mismatch error, I cannot continue training the discriminator network while training the full SRGAN model (SR + GAN + VGG).
Hi @titu1994, I don't understand the sentence. Could you explain it in detail? Thanks a lot.
@LLCF The auto-encoder model is not in the paper, which is why I didn't expand on that point. It was merely a test of the theory that the model implementation and learning with the VGG loss alone were correct. In any case, the auto-encoder-style model accepts a 128x128 image, downscales it 4x to 32x32, passes it through multiple residual blocks, and then upscales the image back to 128x128, using deep skip connections to speed up training. The full architecture is below (it's huge, so please zoom in to see the sizes):
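In code, the overall shape of such an auto-encoder-style generator might be sketched like this (Keras 1.x functional API; the filter counts, block depth, and activations are placeholder assumptions, not the actual values from my model):

```python
from keras.layers import Input, Convolution2D, UpSampling2D, merge
from keras.models import Model

def residual_block(x, filters=64):
    # Two 3x3 convolutions with an identity (skip) connection
    y = Convolution2D(filters, 3, 3, border_mode='same', activation='relu')(x)
    y = Convolution2D(filters, 3, 3, border_mode='same')(y)
    return merge([x, y], mode='sum')

ip = Input(shape=(3, 128, 128))  # Theano dim ordering (channels first)

# Downscale 4x: 128x128 -> 32x32 via two stride-2 convolutions
x = Convolution2D(64, 3, 3, border_mode='same', subsample=(2, 2), activation='relu')(ip)
x = Convolution2D(64, 3, 3, border_mode='same', subsample=(2, 2), activation='relu')(x)

skip = x
for _ in range(4):                # "multiple residual blocks"
    x = residual_block(x)
x = merge([x, skip], mode='sum')  # deep skip connection to speed up training

# Upscale 4x: 32x32 -> 128x128
x = UpSampling2D()(x)
x = Convolution2D(64, 3, 3, border_mode='same', activation='relu')(x)
x = UpSampling2D()(x)
x = Convolution2D(64, 3, 3, border_mode='same', activation='relu')(x)

out = Convolution2D(3, 3, 3, border_mode='same', activation='tanh')(x)
model = Model(ip, out)
```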
Hi @titu1994, maybe this link can be useful for the pixel-grid problem: http://distill.pub/2016/deconv-checkerboard/. Some questions: do you train the net with the total generator loss as in the paper (VGG loss + adversarial loss + TV loss)? How much does each part influence the total loss? How do you balance the GAN game (the discriminator-generator contest)?
@keeper121 Thanks for the link. I did check it out before, but was wondering how to implement linear upsampling in Keras, or whether it is already implemented in the current UpSampling2D layer. Sadly, I have not been able to train the full network yet; it keeps giving a Python recursion error. If it could run, it would use the parameters from the paper.
@titu1994
@Neltherion Well, I was able to implement the SubPixelConvolution layer in Keras (at least for Theano; I haven't tried a TensorFlow implementation yet). That did result in slightly fewer 'grid' deformations in the images. However, there seems to be a simpler solution posted in the above link, which is to use nearest-neighbour resize convolution, though I don't know how to implement that in Keras. I don't know whether the upscaling layer in Keras is equivalent to nearest-neighbour resize convolution, since the results I got from it were completely unusable.

As to what I was able to complete: the generator and discriminator are now pre-trainable, in the sense that they can be trained separately. However, merging the Discriminator, Generator and VGG models into one causes a Theano gradient recursion error. I did try setting the Python recursion limit to 20k and above, but my laptop doesn't have enough RAM to handle that. I haven't been able to move past that, since I am currently preparing for my Masters course at UIC. I plan to work on it and fix it when I have a little more time, but for now I am using the Theano implementation of the paper, Neural Enhance.
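For reference, the core of a sub-pixel (pixel-shuffle) upscaling operation can be sketched as a Lambda layer over the Theano backend (a sketch assuming Theano dim ordering; r is the upscale factor and the incoming layer must carry channels * r**2 feature maps, so function and argument names here are illustrative):

```python
import keras.backend as K
from keras.layers import Lambda

def _pixel_shuffle_th(x, r, channels):
    # x: (batch, channels * r * r, H, W) -> (batch, channels, H * r, W * r)
    shape = K.shape(x)                        # symbolic shape
    b, h, w = shape[0], shape[2], shape[3]
    x = K.reshape(x, (b, channels, r, r, h, w))
    x = K.permute_dimensions(x, (0, 1, 4, 2, 5, 3))  # (b, c, h, r, w, r)
    return K.reshape(x, (b, channels, h * r, w * r))

def SubPixelUpscaling(r=2, channels=3):
    def output_shape(input_shape):
        b, c, h, w = input_shape
        return (b, channels, h * r, w * r)
    return Lambda(lambda x: _pixel_shuffle_th(x, r, channels),
                  output_shape=output_shape)
```

A Convolution2D producing channels * r**2 feature maps followed by this layer then plays the role of a Deconvolution2D with stride r.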
Thanks for the quick reply... I hope you find the time to finish this; I learned some useful things from your implementation... Godspeed...
@Neltherion With a little help, I was able to get the full SRGAN (Generator + Discriminator + VGG) network to train. However, I am getting very erratic results from it. I ran the code for several thousand iterations; notice the iteration numbers and the losses:

Around the 3670 mark:

Sudden hiccup in training at the 3754 mark:

Another hiccup at the 3762 mark:

Sudden improvement at the 4147 mark:

And then a constant discriminator loss for the next 4000 iterations:

This is basically an implementation issue, I think. I have seen this happen when the discriminator overpowers the generator and vice versa. Or it may be an incorrect implementation of how the weights are updated. Please update your copy to the latest commit (71), as it fixed the training bug. I hope to have some more time to fix these bugs.

In any case, the original issue was closed a long time ago. The fact that the entire SRGAN model trains by bypassing the check means there is no point in keeping the issue open.
I will update as soon as I find a useful solution. I will probably take a good look at the Theano implementation by jcjohnson to see how he overcame this issue. I barely have a few weeks before my move, so I will try to find some time in between to solve this. I don't like keeping bugs around for long either.
@titu1994 By the way, could you please elaborate on how you replaced the Deconvolution layers?
I replaced the Deconvolution layers with SubPixelConvolution layers, which I have implemented in Theano. It works, and shape inference is a definite plus, but while the results are better than with the Deconvolution layers, I can still see grid-like structures. I will be studying how to implement the resize convolution layers mentioned in the blog article, which may fix the artifacts altogether. For now, SubPixelConvolution layers are a good enough substitute for Deconvolution layers with stride 2 or more. The boundary artifacts are only visible when attempting an upsampling task such as SRGAN. There are no such defects when training an auto-encoder-style JPEG artifact remover or sharpening network (which I have tried and tested to verify that the SubPixelConvolution layers were working correctly).
@Neltherion I have pushed a few commits to correct the discriminator training. The changes were:
All of these changes now allow the discriminator to learn properly, with a loss value between 3 and 6 (due to the adversarial loss regularizer) and an accuracy between 0.7 and 0.88 while training the discriminator. Also, I have switched to using the Keras UpSampling layer instead of the SubPixelConvolution layer. The Keras UpSampling2D layer is equivalent to the nearest-neighbour resize convolution mentioned in http://distill.pub/2016/deconv-checkerboard/, and I can see that the results are significantly better than before, but still not very good. I haven't had time to check the training of the full model; I will do so in the coming weeks.
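In other words, each strided Deconvolution is replaced by an upsample-then-convolve pair along these lines (a sketch in Keras 1.x syntax; the filter count is illustrative):

```python
from keras.layers import UpSampling2D, Convolution2D

# Instead of: Deconvolution2D(64, 3, 3, output_shape=..., subsample=(2, 2))
x = UpSampling2D(size=(2, 2))(x)  # nearest-neighbour resize (no learned weights)
x = Convolution2D(64, 3, 3, border_mode='same', activation='relu')(x)
```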
@titu1994 By the way, why do you think the batch size affected the training? Was it perhaps stuck in some local optimum because of 1-sample batches? Some papers claim they had to reduce the batch size for the GAN to work, while others (such as yourself) say they've seen improvements when increasing the batch size... Anyway, thanks for the update!
@Neltherion Ah, that link is useful. I didn't know about it, but I did watch many of the NIPS 2016 videos on Facebook and several slides which discussed how training methods for GANs can be changed to improve performance.

I believe the soft-label trick was from a paper I read, Improved Techniques for Training GANs. Flipping labels was mentioned in some slide about GANs at NIPS 2016. Now that I have read the post, though, I have to read up on how I can modify the loss function accordingly.

Normalization to [-1, 1] makes sense, since we are using a tanh activation to squash the activations to [-1, 1] and then denormalizing to [0, 255] in the denormalization layer for an image upsampling task. I never thought of it as a worthwhile insight :P

I set up the training of the discriminator so that each batch would have all positive or all negative samples. A simple batch trick, I suppose, but it's nice to know it is a formal method for improving performance. The leaky ReLU trick was in the original SRGAN paper, so I knew it was tested and important. I rarely use MaxPooling anymore; convolution with stride 2 is much more efficient in my opinion. The post suggests using Deconvolution with stride 2 and SubPixelConvolution upscaling, but those cause grid effects, so I guess that advice doesn't apply to image upscaling tasks.

Thankfully the new upsampling function works properly, but it has a chance of producing completely solid colors for every image and leaving the generator stuck there. That is the reason I have provided a pretrained SRGAN model, trained on 80k images with the new upsampling function, to avoid a bad initialization.

I think batch size plays a huge role in training GANs. I simply could not get any GAN I implemented to work with a batch size < 8. Maybe it is the design of the discriminator, or some error in my code. In my experience, I simply train the discriminator with a batch size of 16, and then train the full model with any batch size (even 1). This may be due to the fact that with batch size = 1, Batch Normalization acts like Instance Normalization. This may cause the discriminator to get stuck at local minima and always produce similar loss values (7.093~8.192).
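Put together, the discriminator-side tricks mentioned above (soft labels, occasional label flipping, and single-class batches) might look like this (a sketch; the flip probability, label ranges, and variable names are assumptions):

```python
import numpy as np

batch_size = 16

def soft_labels(n, positive, flip_prob=0.05):
    # Soft labels: positives drawn from [0.8, 1.0], negatives from [0.0, 0.2]
    lo, hi = (0.8, 1.0) if positive else (0.0, 0.2)
    y = np.random.uniform(lo, hi, size=(n, 1))
    # Occasionally flip labels to keep the discriminator from overpowering
    flip = np.random.rand(n) < flip_prob
    y[flip] = 1.0 - y[flip]
    return y

# Each batch contains only real or only generated samples, never a mix
d_loss_real = discriminator.train_on_batch(x_real, soft_labels(batch_size, True))
x_fake = generator.predict(x_lr)
d_loss_fake = discriminator.train_on_batch(x_fake, soft_labels(batch_size, False))
```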
@titu1994 Can you please explain your implementation of the resize convolution?
When I run the code I get the following error: "TypeError: standardize_input_data() got an unexpected keyword argument 'check_batch_dim'", and the errors come from the Keras training code. Can anyone please help in resolving this issue?
I am trying to implement the SRGAN model from the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. Since this uses both a GAN loss and VGG perceptual losses, I am using modified ActivityRegularizers to incorporate the various losses. I am using the latest Theano and Keras, and the model is Theano-only (I'm on Windows, so no TensorFlow yet).
The full model architecture and the implementation of the loss are here : SRGAN regularizers gist
A simplified view of the entire model is:
In this model, I am passing 96 x 96 blurred low-resolution (LR) images as input to the SR-ResNet network, and 384 x 384 high-resolution (HR) images as Inputs 2 and 3 to the Discriminator and VGG networks, alongside the outputs of the SR-ResNet.
The issue is that since the original inputs to the SR-ResNet have a batch size of 8, Keras assumes that the output batch size should also be 8. However, since we are merging the original high-resolution images into both the discriminator and VGG networks, the batch size becomes 16.
The reason for adding the original HR input images as Input 2 and Input 3 is that the ContentVGGRegularizer needs to compare the gram matrix of the HR inputs against the gram matrix of the generated outputs G(LR) from the generative model. Also, to train the discriminator network, we require the original images in addition to the generated ones, D(G(LR)).
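This merge along the batch axis is what doubles the sample count, e.g. (Keras 1.x merge; the tensor names are illustrative):

```python
from keras.layers import merge

# sr_output and hr_input each carry a batch of 8 images; concatenating along
# the batch axis (axis 0) hands the discriminator / VGG networks 16 samples,
# while Keras still expects the output batch size to match the 8 input samples.
d_input = merge([sr_output, hr_input], mode='concat', concat_axis=0)
```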
The error is fairly simple and understandable: "Input arrays should have the same number of samples as target arrays. Found 8 input samples and 16 target samples."
My question is: is there any way to train this network without manually creating the train function, getting the updates, adding the regularizers, and then calling the train function myself?
Or is there some other way to train such networks without appending the original inputs? I have not yet found a way to mask a portion of the input batch, which could potentially solve this problem (by passing both the blurred and HR images to the SR-ResNet input); however, this poses another challenge: the LR images are 96 x 96 while the HR images are 384 x 384.
Any solutions? The only one I can think of right now is to duplicate most of the _make_train_function and _fit_loop code to bypass this check.
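For completeness, that bypass boils down to building the train function by hand so that fit()'s array-length check never runs. A rough sketch for Keras 1.x follows; the exact attribute names vary between versions, and dummy_targets / dummy_weights are hypothetical placeholders, so treat this as illustrative rather than exact:

```python
import keras.backend as K

# Compile with the dummy loss so Keras collects the regularizer losses
model.compile(optimizer='adam', loss=dummy_loss)

# model.total_loss = dummy loss + all activity-regularizer terms
updates = model.optimizer.get_updates(model.trainable_weights,
                                      model.constraints,
                                      model.total_loss)

train_fn = K.function(model.inputs + model.targets + model.sample_weights
                      + [K.learning_phase()],
                      [model.total_loss],
                      updates=updates)

# Called directly, nothing validates that the input and target batch sizes match
loss = train_fn([x_lr, x_hr, x_hr, dummy_targets, dummy_weights, 1])
```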