Alexandru Dinu edited this page Feb 12, 2021 · 27 revisions

The models are inspired by [1] and [2]. They are just experiments I have tried and do not fully implement the architectures described in the referenced papers.

Training

The input consists of 720p (1280x720) frames from the YouTube-8M dataset (credit goes to gsssrao for the downloader and frame generator scripts).

Two datasets were used for training:

  • 121,827 frames (download: weights)
  • 2,286 frames (download: dataset / weights) - quickest way to start experimenting!

The images are padded to 1280x768 (24 rows added at the top and 24 at the bottom, since 720 + 2 * 24 = 768), so that each one can be split into 60 patches of 128x128 (10 across, 6 down). The model sees a single patch per forward pass, i.e. each image takes 60 forward passes and 60 optimization steps. The loss is computed per patch as MSELoss(orig_patch_ij, out_patch_ij), and the patch losses are averaged into a per-image loss.
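The padding and patching step can be sketched in NumPy as follows. This is an illustration, not the project's code, and it rests on assumptions this page does not state: frames are HxWxC float arrays (PyTorch code would typically use CxHxW), the 24-row padding is zero-valued, and patches are taken in row-major order. The helper names (pad_frame, to_patches, image_loss) are hypothetical. image_loss only mirrors the averaged per-image loss; during training, each patch's loss drives its own optimization step.

```python
import numpy as np

PATCH = 128


def pad_frame(frame: np.ndarray, pad_h: int = 24) -> np.ndarray:
    """Pad a (720, 1280, C) frame with pad_h rows above and below -> (768, 1280, C).

    Zero padding is an assumption; the page only gives the padded size.
    """
    return np.pad(frame, ((pad_h, pad_h), (0, 0), (0, 0)), mode="constant")


def to_patches(img: np.ndarray, patch: int = PATCH) -> np.ndarray:
    """Split an (H, W, C) image into (H/patch * W/patch, patch, patch, C) patches, row-major."""
    h, w, c = img.shape
    assert h % patch == 0 and w % patch == 0, "image must tile exactly"
    rows, cols = h // patch, w // patch
    return (
        img.reshape(rows, patch, cols, patch, c)
        .transpose(0, 2, 1, 3, 4)  # -> (rows, cols, patch, patch, C)
        .reshape(rows * cols, patch, patch, c)
    )


def image_loss(orig_patches, out_patches) -> float:
    """MSE computed per patch, then averaged over all patches of the image."""
    losses = [float(np.mean((o - r) ** 2)) for o, r in zip(orig_patches, out_patches)]
    return sum(losses) / len(losses)


frame = np.random.rand(720, 1280, 3).astype(np.float32)
patches = to_patches(pad_frame(frame))
print(patches.shape)  # (60, 128, 128, 3)
```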

Models and architecture

| Model | Patch latent size | Compressed size |
|---|---|---|
| cae_16x8x8_zero_pad_bin | 16x8x8 | 7.5 KB |
| cae_16x8x8_refl_pad_bin | 16x8x8 | 7.5 KB |
| cae_16x16x16_zero_pad_bin | 16x16x16 | 30 KB |
| cae_32x32x32_zero_pad_bin | 32x32x32 | 240 KB |

All models implement stochastic binarization [2], i.e. the encoded representation is binary. Each element of the patch latent is stored as one bit, so the number of bits per patch is the product of the latent dimensions (e.g. 16x8x8 = 1,024 bits), and the compressed size of an image is 60 * bits_per_patch / 8 / 1024 KB.
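The size formula can be checked against the model table with a few lines of Python. This is a sketch: the function names are hypothetical, and the 8-bit RGB baseline used for the compression ratio is an assumption, not something this page states.

```python
from math import prod

PATCHES_PER_IMAGE = 60  # 1280x768 padded frame = 10 x 6 patches of 128x128
RAW_PATCH_BITS = 128 * 128 * 3 * 8  # assumption: 8-bit RGB input patches


def compressed_size_kb(latent_shape, n_patches=PATCHES_PER_IMAGE):
    """Size of one binarized image: one bit per latent element, 1 KB = 1024 bytes."""
    bits_per_patch = prod(latent_shape)
    return n_patches * bits_per_patch / 8 / 1024


def compression_ratio(latent_shape):
    """Raw patch bits divided by binary latent bits (same ratio per patch and per image)."""
    return RAW_PATCH_BITS / prod(latent_shape)


for shape in [(16, 8, 8), (16, 16, 16), (32, 32, 32)]:
    print(shape, compressed_size_kb(shape), "KB", compression_ratio(shape))
# (16, 8, 8) 7.5 KB 384.0
# (16, 16, 16) 30.0 KB 96.0
# (32, 32, 32) 240.0 KB 12.0
```

Under that baseline, even the largest latent stores a patch in one twelfth of its raw size.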

Stochastic binarization

The benefits of stochastic binarization, as mentioned in [2], are:

(1) bit vectors are trivially serializable/deserializable for image transmission over the wire;

(2) control of the network compression rate is achieved simply by putting constraints on the bit allowance;

(3) a binary bottleneck helps force the network to learn efficient representations compared to standard floating-point layers, which may have many redundant bit patterns that have no effect on the output.
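To make the binarization and benefit (1) concrete, here is a minimal NumPy sketch of one common formulation of stochastic binarization. It assumes the encoder output lies in [-1, 1] (as the tanh on conv3 suggests): each element becomes +1 with probability (1 + x) / 2 and -1 otherwise, so the code equals x in expectation. The helper names (stochastic_binarize, pack, unpack) are hypothetical and not taken from the project's code; serialization uses NumPy's packbits/unpackbits.

```python
import numpy as np


def stochastic_binarize(x: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Map values in [-1, 1] to {-1.0, +1.0}; +1 is drawn with probability (1 + x) / 2.

    Since E[b] = (+1)(1 + x)/2 + (-1)(1 - x)/2 = x, the code is unbiased.
    """
    prob_plus_one = (1.0 + x) / 2.0
    return np.where(rng.random(x.shape) < prob_plus_one, 1.0, -1.0)


def pack(code: np.ndarray) -> bytes:
    """Serialize a {-1, +1} code with one bit per element (+1 -> 1, -1 -> 0)."""
    return np.packbits((code > 0).astype(np.uint8)).tobytes()


def unpack(payload: bytes, shape: tuple) -> np.ndarray:
    """Inverse of pack(): recover the {-1, +1} code for a known latent shape."""
    n = int(np.prod(shape))
    bits = np.unpackbits(np.frombuffer(payload, dtype=np.uint8))[:n]
    return np.where(bits.reshape(shape) == 1, 1.0, -1.0)


rng = np.random.default_rng(0)
latent = np.tanh(rng.standard_normal((32, 32, 32)))  # stand-in for a conv3 output
code = stochastic_binarize(latent, rng)
payload = pack(code)
print(len(payload))  # 4096 bytes = 32 * 32 * 32 bits / 8
```

With 60 patches per image, 60 such 4,096-byte payloads give the 240 KB listed for cae_32x32x32_zero_pad_bin.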

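Since the best-performing model is cae_32x32x32_zero_pad_bin, we only describe its high-level architecture; output shapes and block descriptions can be found in the code. In the diagrams below, each (+) adds a block's output to that block's input (a residual connection), and tanh marks the activation of the last encoder and decoder layers.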

encoder:
x => conv1 --> conv2 --> enc_block1 --> (+) --> enc_block2 --> (+) --> enc_block3 --> (+) --> conv3 (tanh) => enc
                 |-----------------------^   |------------------^   |------------------^

decoder:
enc => up_conv1 --> dec_block1 --> (+) --> dec_block2 --> (+) --> dec_block3 --> (+) --> up_conv2 --> up_conv3 (tanh) => x
           |------------------------^   |------------------^   |------------------^
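The wiring in these diagrams can be sketched in PyTorch as below. Only the layer names, the residual (+) connections, the two tanh outputs and the 3x128x128 to 32x32x32 shapes come from this page; channel counts, kernel sizes, strides, LeakyReLU activations, the internals of the blocks and the straight-through gradient used for binarization are hypothetical choices made so the sketch runs, not the project's actual layers.

```python
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    """Hypothetical enc_block/dec_block: shape-preserving convs plus a skip connection."""

    def __init__(self, ch: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.body(x)  # the (+) node in the diagram


def binarize(z: torch.Tensor) -> torch.Tensor:
    """Stochastic {-1, +1} code in the forward pass, identity gradient in the backward pass."""
    with torch.no_grad():
        ones = torch.ones_like(z)
        b = torch.where(torch.rand_like(z) < (1 + z) / 2, ones, -ones)
    return b + (z - z.detach())  # value is exactly b; d(out)/dz == 1


class CAE(nn.Module):
    def __init__(self, hidden: int = 128):
        super().__init__()
        act = nn.LeakyReLU
        # Encoder: 3x128x128 -> 32x32x32 (two stride-2 convs halve H and W twice).
        self.conv1 = nn.Sequential(nn.Conv2d(3, 64, 3, stride=2, padding=1), act())
        self.conv2 = nn.Sequential(nn.Conv2d(64, hidden, 3, stride=2, padding=1), act())
        self.enc_blocks = nn.Sequential(*[ResBlock(hidden) for _ in range(3)])
        self.conv3 = nn.Sequential(nn.Conv2d(hidden, 32, 3, padding=1), nn.Tanh())
        # Decoder: 32x32x32 -> 3x128x128 (two stride-2 transposed convs double H and W twice).
        self.up_conv1 = nn.Sequential(nn.Conv2d(32, hidden, 3, padding=1), act())
        self.dec_blocks = nn.Sequential(*[ResBlock(hidden) for _ in range(3)])
        self.up_conv2 = nn.Sequential(nn.ConvTranspose2d(hidden, 64, 2, stride=2), act())
        self.up_conv3 = nn.Sequential(nn.ConvTranspose2d(64, 3, 2, stride=2), nn.Tanh())

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return binarize(self.conv3(self.enc_blocks(self.conv2(self.conv1(x)))))

    def decode(self, enc: torch.Tensor) -> torch.Tensor:
        return self.up_conv3(self.up_conv2(self.dec_blocks(self.up_conv1(enc))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(x))
```

Under these assumptions, the per-patch step from the Training section would be `loss = nn.functional.mse_loss(model(patch), patch)`, followed by `loss.backward()` and an optimizer step, with patches scaled to [-1, 1] to match the tanh output.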

References
