Commit 532f199: first commit
Puiho Chan committed Aug 24, 2020
Showing 32 changed files with 3,441 additions and 0 deletions.
README.md

# Generating Visually Aligned Sound from Videos

This is the official PyTorch implementation of the TIP paper "[Generating Visually Aligned Sound from Videos][REGNET]" and the corresponding Visually Aligned Sound (VAS) dataset.

Demo videos containing sound generation results can be found [here][demo].

![](https://github.com/PeihaoChen/regnet/blob/master/overview.png)


# Contents
----

* [Usage Guide](#usage-guide)
* [Getting Started](#getting-started)
* [Installation](#installation)
* [Download Datasets](#download-datasets)
* [Data Preprocessing](#data-preprocessing)
* [Training REGNET](#training-regnet)
* [Generating Sound](#generating-sound)
* [Other Info](#other-info)
* [Citation](#citation)
* [Contact](#contact)


----
# Usage Guide

## Getting Started
[[back to top](#generating-visually-aligned-sound-from-videos)]

### Installation

Clone this repository into a directory. We refer to that directory as *`REGNET_ROOT`*.

```bash
git clone https://github.com/PeihaoChen/regnet
cd regnet
```
Create a new Conda environment.
```bash
conda create -n regnet python=3.7.1
conda activate regnet
```
Install [PyTorch][pytorch] and other dependencies.
```bash
conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=10.0
conda install ffmpeg -n regnet -c conda-forge
pip install -r requirements.txt
```

### Download Datasets

In our paper, we collect 8 sound types (Dog, Fireworks, Drum, and Baby from [VEGAS][vegas]; Gun, Sneeze, Cough, and Hammer from [AudioSet][audioset]) to build our [Visually Aligned Sound (VAS)][VAS] dataset.
Please first download the VAS dataset and unzip it into the *`$REGNET_ROOT/data/`* folder.

For each sound type from AudioSet, we download all of its videos from YouTube and clean the data on Amazon Mechanical Turk (AMT) in the same way as [VEGAS][visual_to_sound].


```bash
unzip ./data/VAS.zip -d ./data
```



### Data Preprocessing

Run `data_preprocess.sh` to preprocess data and extract RGB and optical flow features.

Note: the script we provide for computing optical flow is easy to run but resource-intensive and slow. We strongly recommend referring to the [TSN repository][TSN] and its prebuilt [docker image][TSN_docker] (our paper also uses this solution) to speed up optical flow extraction and to strictly reproduce the results.
```bash
source data_preprocess.sh
```


## Training REGNET

Train REGNET from scratch. The results will be saved to `ckpt/dog`.

```bash
CUDA_VISIBLE_DEVICES=7 python train.py \
save_dir ckpt/dog \
auxiliary_dim 64 \
rgb_feature_dir data/features/dog/feature_rgb_bninception_dim1024_21.5fps \
flow_feature_dir data/features/dog/feature_flow_bninception_dim1024_21.5fps \
mel_dir data/features/dog/melspec_10s_22050hz \
checkpoint_path ''
```

If training stops unexpectedly, you can resume it from a saved checkpoint.
```bash
CUDA_VISIBLE_DEVICES=7 python train.py \
-c ckpt/dog/opts.yml \
checkpoint_path ckpt/dog/checkpoint_018081
```

## Generating Sound


During inference, REGNET generates a visually aligned spectrogram and then uses [WaveNet][wavenet] as a vocoder to synthesize the waveform from that spectrogram. You should first download our trained WaveNet model for the desired sound category (
[Dog](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/dog_checkpoint_step000200000_ema.pth),
[Fireworks](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/fireworks_checkpoint_step000267000_ema.pth),
[Drum](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/drum_checkpoint_step000160000_ema.pth),
[Baby](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/baby_checkpoint_step000470000_ema.pth),
[Gun](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/gun_checkpoint_step000152000_ema.pth),
[Sneeze](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/sneeze_checkpoint_step000071000_ema.pth),
[Cough](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/cough_checkpoint_step000079000_ema.pth),
[Hammer](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/hammer_checkpoint_step000137000_ema.pth)
).
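
For example, a minimal Python sketch that fetches the Dog checkpoint (any downloader such as `wget` or `curl` works just as well):

```python
import urllib.request

# Fetch the trained WaveNet vocoder checkpoint for the Dog category.
url = ("https://github.com/PeihaoChen/regnet/releases/download/"
       "WaveNet_model/dog_checkpoint_step000200000_ema.pth")
urllib.request.urlretrieve(url, "wavenet_dog.pth")
```

Point the `wavenet_path` argument below at the downloaded checkpoint.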

The generated spectrograms and waveforms will be saved at `ckpt/dog/inference_result`.
```bash
CUDA_VISIBLE_DEVICES=7 python test.py \
-c ckpt/dog/opts.yml \
aux_zero True \
checkpoint_path ckpt/dog/checkpoint_041000 \
save_dir ckpt/dog/inference_result \
wavenet_path /path/to/wavenet_dog.pth
```

If you want to train your own WaveNet model, you can use [WaveNet repository][wavenet_repository].
```bash
git clone https://github.com/r9y9/wavenet_vocoder && cd wavenet_vocoder
git checkout 2092a64
```

Enjoy your experiments!


# Other Info
[[back to top](#generating-visually-aligned-sound-from-videos)]

## Citation


Please cite the following paper if you find REGNET useful for your research:
```
@article{chen2020regnet,
  author  = {Peihao Chen and Yang Zhang and Mingkui Tan and Hongdong Xiao and Deng Huang and Chuang Gan},
  title   = {Generating Visually Aligned Sound from Videos},
  journal = {IEEE Transactions on Image Processing},
  year    = {2020},
}
```

## Contact
For any questions, please file an issue or contact:
```
Peihao Chen: phchencs@gmail.com
Hongdong Xiao: xiaohongdonghd@gmail.com
```

[REGNET]:https://arxiv.org/abs/2008.00820
[audioset]:https://research.google.com/audioset/index.html
[VEGAS_link]:http://bvision11.cs.unc.edu/bigpen/yipin/visual2sound_webpage/VEGAS.zip
[pytorch]:https://github.com/pytorch/pytorch
[wavenet]:https://arxiv.org/abs/1609.03499
[wavenet_repository]:https://github.com/r9y9/wavenet_vocoder
[opencv]:https://github.com/opencv/opencv
[dense_flow]:https://github.com/yjxiong/dense_flow
[VEGAS]: http://bvision11.cs.unc.edu/bigpen/yipin/visual2sound_webpage/visual2sound.html
[visual_to_sound]: https://arxiv.org/abs/1712.01393
[TSN]: https://github.com/yjxiong/temporal-segment-networks
[VAS]: https://drive.google.com/file/d/14birixmH7vwIWKxCHI0MIWCcZyohF59g/view?usp=sharing
[TSN_docker]: https://hub.docker.com/r/bitxiong/tsn/tags
[demo]: https://youtu.be/fI_h5mZG7bg
Recorder.py

import os
import shutil
import time
import sys


class Recorder(object):
    def __init__(self, snapshot_pref, exclude_dirs=None, max_file_size=10):
        """
        :param snapshot_pref: the directory in which to save the backups
        :param exclude_dirs: directory names to exclude, e.g. ["results", "data"]
        :param max_file_size: the maximum size of a file to back up, in MB
        """
        date = time.strftime('%Y%m%d%H%M%S', time.localtime(time.time()))
        if not os.path.isdir(snapshot_pref):
            os.makedirs(snapshot_pref, exist_ok=True)
        self.save_path = snapshot_pref
        self.log_file = os.path.join(self.save_path, "log.txt")
        self.readme = os.path.join(self.save_path, "README.md")
        self.opt_file = os.path.join(self.save_path, "opt.log")
        self.code_path = os.path.join(self.save_path, "code_{}/".format(date))
        self.exclude_dirs = exclude_dirs or []  # guard against None
        self.max_file_size = max_file_size
        if os.path.isfile(self.readme):
            os.remove(self.readme)
        if not os.path.isdir(self.code_path):
            os.mkdir(self.code_path)
        self.copy_code(dst=self.code_path)
        self.tee_stdout(self.log_file)
        print("|===>Backups will be saved at", self.save_path)

    def copy_code(self, src="./", dst="./code/"):
        start_time = time.time()
        file_abs_list = []
        src_abs = os.path.abspath(src)
        for root, dirs, files in os.walk(src_abs):
            # Skip any path that contains one of the excluded directory names.
            exclude_flag = any(exclude_dir in root for exclude_dir in self.exclude_dirs)
            if not exclude_flag:
                for name in files:
                    file_abs_list.append(os.path.join(root, name))

        for file_abs in file_abs_list:
            # Back up everything below the size limit except compiled .pyc files.
            if os.path.getsize(file_abs) / 1024 / 1024 < self.max_file_size and not file_abs.endswith(".pyc"):
                src_file = file_abs
                dst_file = dst + file_abs.replace(src_abs, "")
                if not os.path.exists(os.path.dirname(dst_file)):
                    os.makedirs(os.path.dirname(dst_file))
                try:
                    shutil.copy2(src=src_file, dst=dst_file)
                except OSError as e:
                    print("copy file error:", e)
        print("|===>Backups using time: %.3f s" % (time.time() - start_time))

    def tee_stdout(self, log_path):
        log_file = open(log_path, 'a', 1)  # line-buffered append
        stdout = sys.stdout

        class Tee:
            """Write to both the log file and the original stdout."""

            def write(self, string):
                log_file.write(string)
                stdout.write(string)

            def flush(self):
                log_file.flush()
                stdout.flush()

        sys.stdout = Tee()

    def writeopt(self, opt):
        with open(self.opt_file, "w") as f:
            for k, v in opt.__dict__.items():
                f.write(str(k) + ": " + str(v) + "\n")

    def writelog(self, input_data):
        with open(self.log_file, 'a+') as txt_file:
            txt_file.write(str(input_data) + "\n")

    def writereadme(self, input_data):
        with open(self.readme, 'a+') as txt_file:
            txt_file.write(str(input_data) + "\n")

    def gennetwork(self, var):
        # Note: self.graph is not created in __init__; it must be attached
        # externally before this method or savenetwork is used.
        self.graph.draw(var=var)

    def savenetwork(self):
        self.graph.save(file_name=os.path.join(self.save_path, "network.svg"))

    """def writeweights(self, input_data, block_id, layer_id, epoch_id):
        txt_path = self.weight_folder + "conv_weight_" + str(epoch_id) + ".log"
        txt_file = open(txt_path, 'a+')
        write_str = "%d\t%d\t%d\t" % (epoch_id, block_id, layer_id)
        for x in input_data:
            write_str += str(x) + "\t"
        txt_file.write(write_str + "\n")

    def drawhist(self):
        drawer = DrawHistogram(txt_folder=self.weight_folder, fig_folder=self.weight_fig_folder)
        drawer.draw()"""
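
A minimal usage sketch (hypothetical paths; `writeopt` accepts any object with a `__dict__`, e.g. an `argparse.Namespace`):

```python
from argparse import Namespace

from Recorder import Recorder

# Back up the current code tree into ckpt/demo/code_<timestamp>/ and tee
# stdout into ckpt/demo/log.txt; "ckpt" and "data" are skipped so the
# backup does not recurse into checkpoints or datasets.
recorder = Recorder("ckpt/demo", exclude_dirs=["ckpt", "data"])
recorder.writeopt(Namespace(lr=0.0002, batch_size=64))  # dump options to opt.log
recorder.writelog("epoch 0: training started")          # append a line to log.txt
```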
config.py

from yacs.config import CfgNode as CN

_C = CN()
_C.epochs = 1000
_C.num_epoch_save = 10
_C.seed = 123
_C.dynamic_loss_scaling = True
_C.dist_backend = "nccl"
_C.dist_url = "tcp://localhost:54321"
_C.cudnn_enabled = True
_C.cudnn_benchmark = False
_C.save_dir = 'ckpt/dog'
_C.checkpoint_path = ''
_C.epoch_count = 0
_C.exclude_dirs = ['ckpt', 'data']
_C.training_files = 'filelists/dog_train.txt'
_C.test_files = 'filelists/dog_test.txt'
_C.rgb_feature_dir = "data/features/dog/feature_rgb_bninception_dim1024_21.5fps"
_C.flow_feature_dir = "data/features/dog/feature_flow_bninception_dim1024_21.5fps"
_C.mel_dir = "data/features/dog/melspec_10s_22050hz"
_C.video_samples = 215
_C.audio_samples = 10
_C.mel_samples = 860
_C.visual_dim = 2048
_C.n_mel_channels = 80

# Encoder parameters
_C.random_z_dim = 512
_C.encoder_n_lstm = 2
_C.encoder_embedding_dim = 2048
_C.encoder_kernel_size = 5
_C.encoder_n_convolutions = 3

# Auxiliary parameters
_C.auxiliary_type = "lstm_last"
_C.auxiliary_dim = 256
_C.auxiliary_sample_rate = 32
_C.mode_input = ""
_C.aux_zero = False

# Decoder parameters
_C.decoder_conv_dim = 1024

# Mel-post processing network parameters
_C.postnet_embedding_dim = 512
_C.postnet_kernel_size = 5
_C.postnet_n_convolutions = 5

_C.loss_type = "MSE"
_C.weight_decay = 1e-6
_C.grad_clip_thresh = 1.0
_C.batch_size = 64
_C.lr = 0.0002
_C.beta1 = 0.5
_C.continue_train = False
_C.lambda_Oriloss = 10000.0
_C.lambda_Silenceloss = 0
_C.niter = 100
_C.D_interval = 1
_C.wo_G_GAN = False
_C.wavenet_path = ""
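
The `-c ckpt/dog/opts.yml` flag and the trailing `key value` pairs in the README map onto yacs' merge methods. A hypothetical sketch of that pattern (the actual `train.py` in this commit is not shown here, so treat the details as an assumption):

```python
from config import _C

# Work on a copy so the module-level defaults stay untouched.
cfg = _C.clone()
cfg.merge_from_file("ckpt/dog/opts.yml")      # corresponds to: -c ckpt/dog/opts.yml
cfg.merge_from_list(["save_dir", "ckpt/dog",  # corresponds to trailing key/value args
                     "auxiliary_dim", "64"])
cfg.freeze()                                  # make the config read-only
print(cfg.save_dir, cfg.auxiliary_dim)
```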
criterion.py

from torch import nn


class RegnetLoss(nn.Module):
    def __init__(self, loss_type):
        super(RegnetLoss, self).__init__()
        self.loss_type = loss_type
        print("Loss type: {}".format(self.loss_type))

    def forward(self, model_output, targets):
        mel_target = targets
        mel_target.requires_grad = False
        mel_out, mel_out_postnet = model_output

        if self.loss_type == "MSE":
            loss_fn = nn.MSELoss()
        elif self.loss_type == "L1Loss":
            loss_fn = nn.L1Loss()
        else:
            raise ValueError("Unsupported loss type: {}".format(self.loss_type))

        # Penalize both the raw decoder output and the postnet-refined output.
        mel_loss = loss_fn(mel_out, mel_target) + \
                   loss_fn(mel_out_postnet, mel_target)

        return mel_loss
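
A quick sanity check with random tensors (shapes are illustrative; `config.py` above uses 80 mel channels and 860 mel frames):

```python
import torch

from criterion import RegnetLoss

criterion = RegnetLoss("MSE")
B, n_mel, T = 4, 80, 860                    # batch, mel channels, mel frames
mel_out = torch.randn(B, n_mel, T)          # raw decoder output
mel_out_postnet = torch.randn(B, n_mel, T)  # postnet-refined output
target = torch.randn(B, n_mel, T)

loss = criterion((mel_out, mel_out_postnet), target)
print(loss.item())
```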