# Generating Visually Aligned Sound from Videos

This is the official PyTorch implementation of the TIP paper "[Generating Visually Aligned Sound from Videos][REGNET]" and the corresponding Visually Aligned Sound (VAS) dataset.

Demo videos containing sound generation results can be found [here][demo].

# Contents
----

* [Usage Guide](#usage-guide)
  * [Getting Started](#getting-started)
    * [Installation](#installation)
    * [Download Datasets](#download-datasets)
    * [Data Preprocessing](#data-preprocessing)
  * [Training REGNET](#training-regnet)
  * [Generating Sound](#generating-sound)
* [Other Info](#other-info)
  * [Citation](#citation)
  * [Contact](#contact)

----

# Usage Guide

## Getting Started
[[back to top](#generating-visually-aligned-sound-from-videos)]

### Installation

Clone this repository into a directory. We refer to that directory as *`REGNET_ROOT`*.

```bash
git clone https://github.com/PeihaoChen/regnet
cd regnet
```
Create a new Conda environment.
```bash
conda create -n regnet python=3.7.1
conda activate regnet
```
Install [PyTorch][pytorch] and other dependencies.
```bash
conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=10.0
conda install ffmpeg -n regnet -c conda-forge
pip install -r requirements.txt
```
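As a quick optional sanity check (not part of the original instructions), you can confirm that the expected PyTorch build is installed and that CUDA is visible:
```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
# expected output on a CUDA 10.0 machine: 1.2.0 True
```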

### Download Datasets

In our paper, we collect 8 sound types (Dog, Fireworks, Drum, Baby from [VEGAS][vegas] and Gun, Sneeze, Cough, Hammer from [AudioSet][audioset]) to build our [Visually Aligned Sound (VAS)][VAS] dataset.
Please first download the VAS dataset and unzip the data into the *`$REGNET_ROOT/data/`* folder.

For each sound type from AudioSet, we download all of its videos from YouTube and clean the data on Amazon Mechanical Turk (AMT) in the same way as [VEGAS][visual_to_sound].

```bash
unzip ./data/VAS.zip -d ./data
```

### Data Preprocessing

Run `data_preprocess.sh` to preprocess the data and extract RGB and optical-flow features.

Note: the script we provide for computing optical flow is easy to run but resource-consuming and slow. We strongly recommend referring to the [TSN repository][TSN] and its prebuilt [docker image][TSN_docker] (our paper also uses this solution) to speed up optical-flow extraction and to strictly reproduce the results.
```bash
source data_preprocess.sh
```

## Training REGNET

Train REGNET from scratch. The results will be saved to `ckpt/dog`.

```bash
CUDA_VISIBLE_DEVICES=7 python train.py \
save_dir ckpt/dog \
auxiliary_dim 64 \
rgb_feature_dir data/features/dog/feature_rgb_bninception_dim1024_21.5fps \
flow_feature_dir data/features/dog/feature_flow_bninception_dim1024_21.5fps \
mel_dir data/features/dog/melspec_10s_22050hz \
checkpoint_path ''
```

If training stops unexpectedly, you can resume from the saved options and the latest checkpoint.
```bash
CUDA_VISIBLE_DEVICES=7 python train.py \
-c ckpt/dog/opts.yml \
checkpoint_path ckpt/dog/checkpoint_018081
```

## Generating Sound

During inference, REGNET generates a visually aligned spectrogram and then uses [WaveNet][wavenet] as a vocoder to synthesize the waveform from that spectrogram. You should first download our trained WaveNet models for the different sound categories (
[Dog](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/dog_checkpoint_step000200000_ema.pth),
[Fireworks](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/fireworks_checkpoint_step000267000_ema.pth),
[Drum](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/drum_checkpoint_step000160000_ema.pth),
[Baby](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/baby_checkpoint_step000470000_ema.pth),
[Gun](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/gun_checkpoint_step000152000_ema.pth),
[Sneeze](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/sneeze_checkpoint_step000071000_ema.pth),
[Cough](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/cough_checkpoint_step000079000_ema.pth),
[Hammer](https://github.com/PeihaoChen/regnet/releases/download/WaveNet_model/hammer_checkpoint_step000137000_ema.pth)
).

The generated spectrograms and waveforms will be saved at `ckpt/dog/inference_result`.
```bash
CUDA_VISIBLE_DEVICES=7 python test.py \
-c ckpt/dog/opts.yml \
aux_zero True \
checkpoint_path ckpt/dog/checkpoint_041000 \
save_dir ckpt/dog/inference_result \
wavenet_path /path/to/wavenet_dog.pth
```

If you want to train your own WaveNet model, you can use the [WaveNet repository][wavenet_repository].
```bash
git clone https://github.com/r9y9/wavenet_vocoder && cd wavenet_vocoder
git checkout 2092a64
```

Enjoy your experiments!

# Other Info
[[back to top](#generating-visually-aligned-sound-from-videos)]

## Citation

Please cite the following paper if you find REGNET useful in your research.
```
@article{chen2020regnet,
  author  = {Peihao Chen and Yang Zhang and Mingkui Tan and Hongdong Xiao and Deng Huang and Chuang Gan},
  title   = {Generating Visually Aligned Sound from Videos},
  journal = {IEEE Transactions on Image Processing},
  year    = {2020},
}
```

## Contact
For any questions, please file an issue or contact:
```
Peihao Chen: phchencs@gmail.com
Hongdong Xiao: xiaohongdonghd@gmail.com
```

[REGNET]: https://arxiv.org/abs/2008.00820
[audioset]: https://research.google.com/audioset/index.html
[VEGAS_link]: http://bvision11.cs.unc.edu/bigpen/yipin/visual2sound_webpage/VEGAS.zip
[pytorch]: https://github.com/pytorch/pytorch
[wavenet]: https://arxiv.org/abs/1609.03499
[wavenet_repository]: https://github.com/r9y9/wavenet_vocoder
[opencv]: https://github.com/opencv/opencv
[dense_flow]: https://github.com/yjxiong/dense_flow
[VEGAS]: http://bvision11.cs.unc.edu/bigpen/yipin/visual2sound_webpage/visual2sound.html
[visual_to_sound]: https://arxiv.org/abs/1712.01393
[TSN]: https://github.com/yjxiong/temporal-segment-networks
[VAS]: https://drive.google.com/file/d/14birixmH7vwIWKxCHI0MIWCcZyohF59g/view?usp=sharing
[TSN_docker]: https://hub.docker.com/r/bitxiong/tsn/tags
[demo]: https://youtu.be/fI_h5mZG7bg
import os
import shutil
import sys
import time


class Recorder(object):
    def __init__(self, snapshot_pref, exclude_dirs=None, max_file_size=10):
        """
        :param snapshot_pref: the directory in which to save the backups
        :param exclude_dirs: directory names to exclude, e.g. ["results", "data"]
        :param max_file_size: the maximum size of a file to back up, in MB
        """
        date = time.strftime('%Y%m%d%H%M%S', time.localtime(time.time()))
        if not os.path.isdir(snapshot_pref):
            os.makedirs(snapshot_pref, exist_ok=True)
        self.save_path = snapshot_pref
        self.log_file = os.path.join(self.save_path, "log.txt")
        self.readme = os.path.join(self.save_path, "README.md")
        self.opt_file = os.path.join(self.save_path, "opt.log")
        self.code_path = os.path.join(self.save_path, "code_{}/".format(date))
        self.exclude_dirs = exclude_dirs if exclude_dirs is not None else []
        self.max_file_size = max_file_size
        if os.path.isfile(self.readme):
            os.remove(self.readme)
        if not os.path.isdir(self.code_path):
            os.mkdir(self.code_path)
        self.copy_code(dst=self.code_path)
        self.tee_stdout(self.log_file)
        print("|===>Backups will be saved at", self.save_path)

    def copy_code(self, src="./", dst="./code/"):
        """Back up the source tree under `dst`, skipping excluded directories,
        compiled .pyc files, and files larger than `max_file_size` MB."""
        start_time = time.time()
        file_abs_list = []
        src_abs = os.path.abspath(src)
        for root, dirs, files in os.walk(src_abs):
            exclude_flag = any(root.find(exclude_dir) >= 0
                               for exclude_dir in self.exclude_dirs)
            if not exclude_flag:
                for name in files:
                    file_abs_list.append(os.path.join(root, name))

        for file_abs in file_abs_list:
            size_mb = os.path.getsize(file_abs) / 1024 / 1024
            if size_mb < self.max_file_size and not file_abs.endswith(".pyc"):
                dst_file = dst + file_abs.replace(src_abs, "")
                if not os.path.exists(os.path.dirname(dst_file)):
                    os.makedirs(os.path.dirname(dst_file))
                try:
                    shutil.copy2(src=file_abs, dst=dst_file)
                except OSError as err:
                    print("copy file error:", err)
        print("|===>Backups using time: %.3f s" % (time.time() - start_time))

    def tee_stdout(self, log_path):
        """Mirror everything written to stdout into `log_path`."""
        log_file = open(log_path, 'a', 1)  # line-buffered
        stdout = sys.stdout

        class Tee:

            def write(self, string):
                log_file.write(string)
                stdout.write(string)

            def flush(self):
                log_file.flush()
                stdout.flush()

        sys.stdout = Tee()

    def writeopt(self, opt):
        with open(self.opt_file, "w") as f:
            for k, v in opt.__dict__.items():
                f.write(str(k) + ": " + str(v) + "\n")

    def writelog(self, input_data):
        with open(self.log_file, 'a+') as txt_file:
            txt_file.write(str(input_data) + "\n")

    def writereadme(self, input_data):
        with open(self.readme, 'a+') as txt_file:
            txt_file.write(str(input_data) + "\n")

    # Note: the two methods below rely on a `self.graph` drawer that is not
    # defined in this file.
    def gennetwork(self, var):
        self.graph.draw(var=var)

    def savenetwork(self):
        self.graph.save(file_name=self.save_path + "network.svg")

    """def writeweights(self, input_data, block_id, layer_id, epoch_id):
        txt_path = self.weight_folder + "conv_weight_" + str(epoch_id) + ".log"
        txt_file = open(txt_path, 'a+')
        write_str = "%d\t%d\t%d\t" % (epoch_id, block_id, layer_id)
        for x in input_data:
            write_str += str(x) + "\t"
        txt_file.write(write_str+"\n")
    def drawhist(self):
        drawer = DrawHistogram(txt_folder=self.weight_folder, fig_folder=self.weight_fig_folder)
        drawer.draw()"""
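
A minimal usage sketch for the `Recorder` class above; the directory names are illustrative, not prescribed by the repository:

```python
# Hypothetical usage: snapshot the code tree and tee stdout into ckpt/dog/log.txt.
recorder = Recorder("ckpt/dog", exclude_dirs=["ckpt", "data"], max_file_size=10)
recorder.writelog("training started")           # appended to log.txt
print("this line is also mirrored to log.txt")  # stdout is tee'd from here on
```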
from yacs.config import CfgNode as CN

_C = CN()
_C.epochs = 1000
_C.num_epoch_save = 10
_C.seed = 123
_C.dynamic_loss_scaling = True
_C.dist_backend = "nccl"
_C.dist_url = "tcp://localhost:54321"
_C.cudnn_enabled = True
_C.cudnn_benchmark = False
_C.save_dir = 'ckpt/dog'
_C.checkpoint_path = ''
_C.epoch_count = 0
_C.exclude_dirs = ['ckpt', 'data']
_C.training_files = 'filelists/dog_train.txt'
_C.test_files = 'filelists/dog_test.txt'
_C.rgb_feature_dir = "data/features/dog/feature_rgb_bninception_dim1024_21.5fps"
_C.flow_feature_dir = "data/features/dog/feature_flow_bninception_dim1024_21.5fps"
_C.mel_dir = "data/features/dog/melspec_10s_22050hz"
_C.video_samples = 215
_C.audio_samples = 10
_C.mel_samples = 860
_C.visual_dim = 2048
_C.n_mel_channels = 80

# Encoder parameters
_C.random_z_dim = 512
_C.encoder_n_lstm = 2
_C.encoder_embedding_dim = 2048
_C.encoder_kernel_size = 5
_C.encoder_n_convolutions = 3

# Auxiliary parameters
_C.auxiliary_type = "lstm_last"
_C.auxiliary_dim = 256
_C.auxiliary_sample_rate = 32
_C.mode_input = ""
_C.aux_zero = False

# Decoder parameters
_C.decoder_conv_dim = 1024

# Mel post-processing network parameters
_C.postnet_embedding_dim = 512
_C.postnet_kernel_size = 5
_C.postnet_n_convolutions = 5

_C.loss_type = "MSE"
_C.weight_decay = 1e-6
_C.grad_clip_thresh = 1.0
_C.batch_size = 64
_C.lr = 0.0002
_C.beta1 = 0.5
_C.continue_train = False
_C.lambda_Oriloss = 10000.0
_C.lambda_Silenceloss = 0
_C.niter = 100
_C.D_interval = 1
_C.wo_G_GAN = False
_C.wavenet_path = ""
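
A minimal sketch of how a yacs config node like this is typically consumed, mirroring the `-c opts.yml` plus `key value` override pattern in the README's `train.py` commands; the file and path names here are assumptions for illustration:

```python
from config import _C  # assumes the config node above lives in config.py

cfg = _C.clone()
cfg.merge_from_file("ckpt/dog/opts.yml")  # reload options saved by a previous run
cfg.merge_from_list(["auxiliary_dim", "64", "save_dir", "ckpt/dog"])  # CLI-style overrides
cfg.freeze()  # make the config read-only before training
print(cfg.auxiliary_dim, cfg.save_dir)
```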
from torch import nn


class RegnetLoss(nn.Module):
    """Spectrogram reconstruction loss, applied to both the decoder output
    and the postnet-refined output."""

    def __init__(self, loss_type):
        super(RegnetLoss, self).__init__()
        self.loss_type = loss_type
        print("Loss type: {}".format(self.loss_type))

    def forward(self, model_output, targets):
        mel_target = targets
        mel_target.requires_grad = False
        mel_out, mel_out_postnet = model_output

        if self.loss_type == "MSE":
            loss_fn = nn.MSELoss()
        elif self.loss_type == "L1Loss":
            loss_fn = nn.L1Loss()
        else:
            raise ValueError("Unsupported loss type: {}".format(self.loss_type))

        mel_loss = loss_fn(mel_out, mel_target) + \
            loss_fn(mel_out_postnet, mel_target)

        return mel_loss
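
A small self-contained check of `RegnetLoss` with dummy tensors; the shapes follow the config above (80 mel channels, 860 mel frames), and the batch size is illustrative:

```python
import torch

criterion = RegnetLoss("MSE")

# Dummy batch shaped (batch, n_mel_channels, mel_samples).
mel_out = torch.randn(4, 80, 860, requires_grad=True)
mel_out_postnet = torch.randn(4, 80, 860, requires_grad=True)
mel_target = torch.randn(4, 80, 860)

loss = criterion((mel_out, mel_out_postnet), mel_target)
loss.backward()  # gradients flow to both decoder and postnet outputs
print(loss.item())
```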