diff --git a/examples/coco_caption/lstm_language_model.prototxt b/examples/coco_caption/lstm_language_model.prototxt
new file mode 100644
index 00000000000..3cf4f6a686f
--- /dev/null
+++ b/examples/coco_caption/lstm_language_model.prototxt
@@ -0,0 +1,150 @@
+name: "lstm_language_model"
+layer {
+  name: "data"
+  type: "HDF5Data"
+  top: "cont_sentence"
+  top: "input_sentence"
+  top: "target_sentence"
+  include { phase: TRAIN }
+  hdf5_data_param {
+    source: "./examples/coco_caption/h5_data/buffer_100/train_unaligned_batches/hdf5_chunk_list.txt"
+    batch_size: 20
+  }
+}
+layer {
+  name: "data"
+  type: "HDF5Data"
+  top: "cont_sentence"
+  top: "input_sentence"
+  top: "target_sentence"
+  include {
+    phase: TEST
+    stage: "test-on-train"
+  }
+  hdf5_data_param {
+    source: "./examples/coco_caption/h5_data/buffer_100/train_unaligned_batches/hdf5_chunk_list.txt"
+    batch_size: 20
+  }
+}
+layer {
+  name: "data"
+  type: "HDF5Data"
+  top: "cont_sentence"
+  top: "input_sentence"
+  top: "target_sentence"
+  include {
+    phase: TEST
+    stage: "test-on-val"
+  }
+  hdf5_data_param {
+    source: "./examples/coco_caption/h5_data/buffer_100/val_unaligned_batches/hdf5_chunk_list.txt"
+    batch_size: 20
+  }
+}
+layer {
+  name: "embedding"
+  type: "Embed"
+  bottom: "input_sentence"
+  top: "embedded_input_sentence"
+  param {
+    lr_mult: 1
+  }
+  embed_param {
+    bias_term: false
+    input_dim: 8801  # = vocab_size + 1 (+1 for EOS)
+    num_output: 1000
+    weight_filler {
+      type: "uniform"
+      min: -0.08
+      max: 0.08
+    }
+  }
+}
+layer {
+  name: "embed-drop"
+  type: "Dropout"
+  bottom: "embedded_input_sentence"
+  top: "embedded_input_sentence"
+  dropout_param { dropout_ratio: 0.5 }
+  include { stage: "embed-drop" }
+}
+layer {
+  name: "lstm1"
+  type: "LSTM"
+  bottom: "embedded_input_sentence"
+  bottom: "cont_sentence"
+  top: "lstm1"
+  recurrent_param {
+    num_output: 1000
+    weight_filler {
+      type: "uniform"
+      min: -0.08
+      max: 0.08
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+  }
+}
+layer {
+  name: "lstm-drop"
+  type: "Dropout"
+  bottom: "lstm1"
+  top: "lstm1"
+  dropout_param { dropout_ratio: 0.5 }
+  include { stage: "lstm-drop" }
+}
+layer {
+  name: "predict"
+  type: "InnerProduct"
+  bottom: "lstm1"
+  top: "predict"
+  param {
+    lr_mult: 1
+    decay_mult: 1
+  }
+  param {
+    lr_mult: 2
+    decay_mult: 0
+  }
+  inner_product_param {
+    num_output: 8801  # = vocab_size + 1 (+1 for EOS)
+    weight_filler {
+      type: "uniform"
+      min: -0.08
+      max: 0.08
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+    axis: 2
+  }
+}
+layer {
+  name: "cross_entropy_loss"
+  type: "SoftmaxWithLoss"
+  bottom: "predict"
+  bottom: "target_sentence"
+  top: "cross_entropy_loss"
+  loss_weight: 20
+  loss_param {
+    ignore_label: -1
+  }
+  softmax_param {
+    axis: 2
+  }
+}
+layer {
+  name: "accuracy"
+  type: "Accuracy"
+  bottom: "predict"
+  bottom: "target_sentence"
+  top: "accuracy"
+  include { phase: TEST }
+  accuracy_param {
+    axis: 2
+    ignore_label: -1
+  }
+}
diff --git a/examples/coco_caption/lstm_lm_solver.prototxt b/examples/coco_caption/lstm_lm_solver.prototxt
new file mode 100644
index 00000000000..09f61b4cc04
--- /dev/null
+++ b/examples/coco_caption/lstm_lm_solver.prototxt
@@ -0,0 +1,21 @@
+net: "./examples/coco_caption/lstm_language_model.prototxt"
+train_state: { stage: 'embed-drop' stage: 'lstm-drop' }
+test_iter: 25
+test_state: { stage: 'test-on-train' }
+test_iter: 25
+test_state: { stage: 'test-on-val' }
+test_interval: 100
+base_lr: 0.1
+lr_policy: "step"
+gamma: 0.5
+stepsize: 20000
+display: 1
+max_iter: 110000
+momentum: 0.9
+weight_decay: 0.0000
+snapshot: 5000
+snapshot_prefix: "./examples/coco_caption/lstm_lm"
+solver_mode: GPU
+random_seed: 1701
+average_loss: 100
+clip_gradients: 10
diff --git a/examples/coco_caption/train_language_model.sh b/examples/coco_caption/train_language_model.sh
new file mode 100755
index 00000000000..6e8a8c47b37
--- /dev/null
+++ b/examples/coco_caption/train_language_model.sh
@@ -0,0 +1,14 @@
+#!/usr/bin/env bash
+
+GPU_ID=0
+DATA_DIR=./examples/coco_caption/h5_data/
+if [ ! -d "$DATA_DIR" ]; then
+  echo "Data directory not found: $DATA_DIR"
+  echo "First, download the COCO dataset (follow instructions in data/coco)"
+  echo "Then, run ./examples/coco_caption/coco_to_hdf5_data.py to create the Caffe input data"
+  exit 1
+fi
+
+./build/tools/caffe train \
+    -solver ./examples/coco_caption/lstm_lm_solver.prototxt \
+    -gpu "$GPU_ID"