
Unsupervised video summarization with deep reinforcement learning (AAAI'18)



KaiyangZhou/pytorch-vsumm-reinforce


pytorch-vsumm-reinforce

This repo contains the PyTorch implementation of the AAAI'18 paper: Deep Reinforcement Learning for Unsupervised Video Summarization with Diversity-Representativeness Reward. The original Theano implementation can be found here.


The main requirements are PyTorch (v0.4.0) and Python 2.7. Some dependencies that may not be installed on your machine are tabulate and h5py; please install any other missing dependencies as well.

Get started

  1. Download preprocessed datasets
git clone https://github.com/KaiyangZhou/pytorch-vsumm-reinforce
cd pytorch-vsumm-reinforce
# download datasets.tar.gz (173.5MB)
wget http://www.eecs.qmul.ac.uk/~kz303/vsumm-reinforce/datasets.tar.gz
tar -xvzf datasets.tar.gz

Update: The QMUL server is no longer accessible. Download the datasets from this Google Drive link instead.

  2. Make splits
python create_split.py -d datasets/eccv16_dataset_summe_google_pool5.h5 --save-dir datasets --save-name summe_splits  --num-splits 5

This randomly splits the dataset 5 times and saves the splits to a JSON file (datasets/summe_splits.json).
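To make the splitting step concrete, here is a minimal Python 3 sketch of what a random train/test split could look like. It is not the code in create_split.py: the `train_keys`/`test_keys` field names, the 80/20 ratio and the seed are assumptions chosen for illustration.

```python
import json
import random


def make_random_splits(keys, num_splits=5, train_fraction=0.8, seed=0):
    """Randomly partition video keys into train/test sets, num_splits times.

    Assumption: each split is stored as a dict with hypothetical
    'train_keys' and 'test_keys' fields; the 0.8 ratio and the seed
    are illustrative defaults, not values taken from create_split.py.
    """
    rng = random.Random(seed)  # seeded so the splits are reproducible
    keys = sorted(keys)
    num_train = int(round(len(keys) * train_fraction))
    splits = []
    for _ in range(num_splits):
        shuffled = list(keys)
        rng.shuffle(shuffled)
        splits.append({
            "train_keys": sorted(shuffled[:num_train]),
            "test_keys": sorted(shuffled[num_train:]),
        })
    return splits


if __name__ == "__main__":
    # SumMe has 25 videos; these key names are illustrative.
    video_keys = ["video_{}".format(i) for i in range(1, 26)]
    splits = make_random_splits(video_keys, num_splits=5)
    print(json.dumps(splits[0], indent=2))
```

Every split covers all videos exactly once, so each video appears in the test set of some splits and the training set of others.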

The training and testing code is in main.py. To see all available arguments, run python main.py -h.

How to train

python main.py -d datasets/eccv16_dataset_summe_google_pool5.h5 -s datasets/summe_splits.json -m summe --gpu 0 --save-dir log/summe-split0 --split-id 0 --verbose
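Training optimizes a frame-selection policy with a policy-gradient (REINFORCE) objective. The pure-Python 3 sketch below only illustrates the arithmetic of that objective with a moving-average baseline; it is not the code in main.py, which works on PyTorch tensors so the loss can be backpropagated. The function names, the 0.9 momentum and the toy reward are hypothetical.

```python
import math
import random


def sample_actions(probs, rng):
    """Sample a binary keep (1) / drop (0) action per frame from Bernoulli(p)."""
    return [1 if rng.random() < p else 0 for p in probs]


def bernoulli_log_prob(probs, actions):
    """Sum of log-probabilities of the sampled actions (assumes 0 < p < 1)."""
    total = 0.0
    for p, a in zip(probs, actions):
        total += math.log(p) if a == 1 else math.log(1.0 - p)
    return total


def reinforce_loss(probs, episodes, baseline):
    """REINFORCE loss averaged over episodes: -(R - b) * log pi(actions).

    episodes: list of (actions, reward) pairs sampled from the same probs.
    Subtracting the baseline b reduces the variance of the gradient.
    """
    loss = 0.0
    for actions, reward in episodes:
        loss += -(reward - baseline) * bernoulli_log_prob(probs, actions)
    return loss / len(episodes)


def update_baseline(baseline, rewards, momentum=0.9):
    """Moving-average baseline (the 0.9 momentum is an assumed value)."""
    return momentum * baseline + (1.0 - momentum) * (sum(rewards) / len(rewards))


if __name__ == "__main__":
    rng = random.Random(0)
    probs = [0.2, 0.7, 0.5, 0.9]  # per-frame selection probabilities from the policy
    baseline = 0.0
    episodes = []
    for _ in range(5):
        actions = sample_actions(probs, rng)
        reward = sum(actions) / float(len(actions))  # toy reward, NOT the paper's reward
        episodes.append((actions, reward))
    print("loss:", reinforce_loss(probs, episodes, baseline))
    baseline = update_baseline(baseline, [r for _, r in episodes])
    print("new baseline:", baseline)
```

Minimizing this loss raises the log-probability of action sequences whose reward beats the baseline and lowers it for the rest.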

How to test

python main.py -d datasets/eccv16_dataset_summe_google_pool5.h5 -s datasets/summe_splits.json -m summe --gpu 0 --save-dir log/summe-split0 --split-id 0 --evaluate --resume path_to_your_model.pth.tar --verbose --save-results

If --save-results is set, the output is saved to result.h5 in the directory given by --save-dir. To plot the predicted scores against the ground-truth scores (gtscore), simply run

python visualize_results.py -p path_to/result.h5
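At test time, machine summaries are scored against human-annotated summaries. The sketch below shows a common way to compute a frame-level F-score and aggregate it over several user summaries: taking the maximum over users is the usual convention for SumMe and the average for TVSum. We assume the -m flag selects between these protocols; check main.py for the exact behaviour, including how shot-level selections are turned into frame-level summaries before scoring.

```python
def summary_f_score(machine, user):
    """F-score (in %) between two equal-length binary frame-level summaries."""
    overlap = sum(m * u for m, u in zip(machine, user))
    if overlap == 0:
        return 0.0
    precision = overlap / float(sum(machine))
    recall = overlap / float(sum(user))
    return 100.0 * 2 * precision * recall / (precision + recall)


def evaluate_against_users(machine, user_summaries, metric="max"):
    """Compare one machine summary with every user summary.

    metric='max' keeps the best match (commonly used for SumMe);
    metric='avg' averages over users (commonly used for TVSum).
    """
    scores = [summary_f_score(machine, u) for u in user_summaries]
    if metric == "max":
        return max(scores)
    if metric == "avg":
        return sum(scores) / len(scores)
    raise ValueError("metric must be 'max' or 'avg'")


if __name__ == "__main__":
    machine = [1, 1, 0, 0]
    users = [[1, 0, 0, 0], [1, 1, 0, 0]]
    print("max:", evaluate_against_users(machine, users, "max"))
    print("avg:", evaluate_against_users(machine, users, "avg"))
```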

Plot

We provide code to plot the reward obtained at each epoch. Use parse_log.py to plot the average reward:

python parse_log.py -p path_to/log_train.txt

The plotted image looks like this:

[Figure: overall_reward]
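The plotted values come from the diversity-representativeness reward named in the paper title. Below is a pure-Python 3 sketch of that reward as the paper describes it: diversity is the mean pairwise cosine dissimilarity among selected frames, and representativeness is exp of minus the mean distance from each frame to its nearest selected frame; the total reward is their sum. The sketch omits the paper's rule that treats temporally distant frames as maximally dissimilar, and the repository's implementation may differ in details such as the distance used, so treat it as illustrative only.

```python
import math


def _cosine_dissimilarity(a, b):
    """1 - cosine similarity (assumes non-zero feature vectors)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def _euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def diversity_reward(features, picks):
    """Mean pairwise cosine dissimilarity among the selected frames."""
    if len(picks) < 2:
        return 0.0
    total = 0.0
    for i in picks:
        for j in picks:
            if i != j:
                total += _cosine_dissimilarity(features[i], features[j])
    return total / (len(picks) * (len(picks) - 1))


def representativeness_reward(features, picks):
    """exp(-mean distance from every frame to its nearest selected frame)."""
    if not picks:
        return 0.0
    mean_min = sum(
        min(_euclidean(f, features[j]) for j in picks) for f in features
    ) / len(features)
    return math.exp(-mean_min)


def total_reward(features, picks):
    """Diversity plus representativeness, as in the paper's formulation."""
    return diversity_reward(features, picks) + representativeness_reward(features, picks)


if __name__ == "__main__":
    # Toy 2-D "features"; real inputs are per-frame CNN features.
    feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    print(total_reward(feats, [0, 1]), total_reward(feats, [0]))
```

Selecting two identical frames earns no diversity, and leaving a distinct frame unselected lowers representativeness, which is how the reward favours summaries that are both varied and covering.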

To plot the epoch-reward curve for a specific video, run

python parse_json.py -p path_to/rewards.json -i 0

You will obtain images like these:

[Figures: epoch_reward curves for three example videos]

To visualize the epoch-reward curves for all training videos, try parse_json.sh and modify it to suit your needs.
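If you want the per-video numbers without plotting, a small reader like the Python 3 sketch below can print them. It assumes rewards.json maps each training-video key to a list of per-epoch rewards and that -i selects a video by position in the sorted key list; both are assumptions, so check parse_json.py for the actual format and ordering.

```python
import json


def pick_reward_curve(rewards, index):
    """Return (video_key, per-epoch rewards) for the index-th video.

    Keys are sorted so the index is deterministic; note that string
    sorting puts 'video_10' before 'video_2'.
    """
    keys = sorted(rewards.keys())
    key = keys[index]
    return key, rewards[key]


def load_reward_curve(path, index):
    """Load rewards.json (assumed format: {video_key: [reward per epoch]})."""
    with open(path) as f:
        return pick_reward_curve(json.load(f), index)


if __name__ == "__main__":
    # Made-up values for illustration only.
    demo = json.loads('{"video_1": [0.80, 0.85, 0.88], "video_2": [0.70, 0.74]}')
    key, curve = pick_reward_curve(demo, 0)
    for epoch, reward in enumerate(curve, start=1):
        print("{}\tepoch {}\treward {:.4f}".format(key, epoch, reward))
```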

Visualize summary

Use summary2video.py to convert the binary machine_summary into an actual summary video. This requires a directory containing the video's frames. The script writes the selected summary frames to a video at a frame rate you control with --fps. Use the following command to generate an .mp4 video:

python summary2video.py -p path_to/result.h5 -d path_to/video_frames -i 0 --fps 30 --save-dir log --save-name summary.mp4

Please remember to specify the naming format of your video frames on this line.
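The naming format matters because the script has to map each selected frame index to a file on disk. A minimal Python 3 sketch of that mapping follows; the zero-padded six-digit names starting at 1 (`name_format`, `start_index`) are hypothetical and must be adapted to how your frames were extracted.

```python
import os


def summary_frame_paths(machine_summary, frame_dir, name_format="{:06d}.jpg", start_index=1):
    """List, in temporal order, the frame files kept by a binary summary.

    name_format and start_index are hypothetical defaults; they must match
    the file names produced when the frames were extracted.
    """
    return [
        os.path.join(frame_dir, name_format.format(i + start_index))
        for i, keep in enumerate(machine_summary)
        if keep
    ]


def summary_duration_seconds(num_frames, fps):
    """Length of the written video: kept frames divided by the chosen --fps."""
    return num_frames / float(fps)


if __name__ == "__main__":
    paths = summary_frame_paths([0, 1, 1, 0, 1], "frames")
    print(paths)
    print(summary_duration_seconds(len(paths), 30), "seconds at 30 fps")
```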

How to use your own data

We preprocess the data by extracting image features from each video and saving them to an h5 file. The file format looks like this. You can then create splits with create_split.py. To train the policy network on the entire dataset, simply set train_keys = dataset.keys(). Here is the code where the dataset is initialized. If you have any problems, feel free to contact me by email or open an issue.
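As a rough illustration of the preprocessing step, the Python 3 sketch below subsamples a video's frames, stores one feature vector per kept frame, and writes the result to one h5 group per video. The feature extractor is a hypothetical callable, the sampling step of 15 frames is an assumption, and although the field names follow keys used in the provided h5 files, verify them against the file format before relying on them. Note also that on Python 3, h5py's keys() returns a view, so list(dataset.keys()) is the safer way to collect all training keys.

```python
def build_video_record(frames, extract_feature, step=15):
    """Subsample frames and extract one feature vector per kept frame.

    frames: sequence of decoded frames (any type your extractor accepts).
    extract_feature: hypothetical callable mapping a frame to a feature
        vector (in practice, e.g., a CNN's pool5 output).
    step: keep every step-th frame (an assumed sampling rate).
    """
    picks = list(range(0, len(frames), step))
    features = [extract_feature(frames[i]) for i in picks]
    return {"features": features, "picks": picks, "n_frames": len(frames)}


def write_video_record(h5_file, video_key, record):
    """Write one record into an already-open h5py.File, one group per video."""
    group = h5_file.create_group(video_key)
    for name, value in record.items():
        group.create_dataset(name, data=value)


# Usage (requires h5py; my_frames and my_extractor are hypothetical):
#   import h5py
#   with h5py.File("my_dataset.h5", "w") as h5:
#       write_video_record(h5, "video_1", build_video_record(my_frames, my_extractor))

if __name__ == "__main__":
    record = build_video_record(list(range(100)), lambda frame: [float(frame)])
    print(record["picks"], record["n_frames"])
```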

Citation

@article{zhou2017reinforcevsumm, 
   title={Deep Reinforcement Learning for Unsupervised Video Summarization with Diversity-Representativeness Reward},
   author={Zhou, Kaiyang and Qiao, Yu and Xiang, Tao}, 
   journal={arXiv:1801.00054}, 
   year={2017} 
}