Skip to content

Commit

Permalink
Prototxts + script for training COCO caption language model
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffdonahue committed Feb 17, 2015
1 parent 33f5d74 commit 8d78878
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 0 deletions.
149 changes: 149 additions & 0 deletions examples/coco_caption/lstm_language_model.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
name: "lstm_language_model"
layer {
name: "data"
type: "HDF5Data"
top: "cont_sentence"
top: "input_sentence"
top: "target_sentence"
include { phase: TRAIN }
hdf5_data_param {
source: "./examples/coco_caption/h5_data/buffer_100/train_unaligned_batches/hdf5_chunk_list.txt"
batch_size: 20
}
}
layer {
name: "data"
type: "HDF5Data"
top: "cont_sentence"
top: "input_sentence"
top: "target_sentence"
include {
phase: TEST
stage: "test-on-train"
}
hdf5_data_param {
source: "./examples/coco_caption/h5_data/buffer_100/train_unaligned_batches/hdf5_chunk_list.txt"
batch_size: 20
}
}
layer {
name: "data"
type: "HDF5Data"
top: "cont_sentence"
top: "input_sentence"
top: "target_sentence"
include {
phase: TEST
stage: "test-on-val"
}
hdf5_data_param {
source: "./examples/coco_caption/h5_data/buffer_100/val_unaligned_batches/hdf5_chunk_list.txt"
batch_size: 20
}
}
layer {
name: "embedding"
type: "Embed"
bottom: "input_sentence"
top: "embedded_input_sentence"
param {
lr_mult: 1
}
embed_param {
bias_term: false
input_dim: 8801 # = vocab_size + 1 (for EOS)
num_output: 1000
weight_filler {
type: "uniform"
min: -0.08
max: 0.08
}
}
}
layer {
name: "embed-drop"
type: "Dropout"
bottom: "embedded_input_sentence"
top: "embedded_input_sentence"
dropout_param { dropout_ratio: 0.5 }
include { stage: "embed-drop" }
}
layer {
name: "lstm1"
type: "LSTM"
bottom: "embedded_input_sentence"
bottom: "cont_sentence"
top: "lstm1"
recurrent_param {
num_output: 1000
weight_filler {
type: "uniform"
min: -0.08
max: 0.08
}
bias_filler {
type: "constant"
value: 0
}
}
}
layer {
name: "lstm-drop"
type: "Dropout"
bottom: "lstm1"
top: "lstm1"
dropout_param { dropout_ratio: 0.5 }
include { stage: "lstm-drop" }
}
layer {
name: "predict"
type: "InnerProduct"
bottom: "lstm1"
top: "predict"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param {
num_output: 8801 # = vocab_size + 1 (+1 for EOS)
weight_filler {
type: "uniform"
min: -0.08
max: 0.08
}
bias_filler {
type: "constant"
value: 0
}
axis: 2
}
}
layer {
name: "cross_entropy_loss"
type: "SoftmaxWithLoss"
bottom: "predict"
bottom: "target_sentence"
top: "cross_entropy_loss"
loss_weight: 20
loss_param {
ignore_label: -1
}
softmax_param {
axis: 2
}
}
layer {
name: "accuracy"
type: "Accuracy"
bottom: "predict"
bottom: "target_sentence"
top: "accuracy"
include { phase: TEST }
loss_param {
ignore_label: -1
}
}
21 changes: 21 additions & 0 deletions examples/coco_caption/lstm_lm_solver.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
net: "./examples/coco_caption/lstm_language_model.prototxt"
train_state: { stage: 'embed-drop' stage: 'lstm-drop' }
test_iter: 25
test_state: { stage: 'test-on-train' }
test_iter: 25
test_state: { stage: 'test-on-val' }
test_interval: 100
base_lr: 0.1
lr_policy: "step"
gamma: 0.5
stepsize: 20000
display: 1
max_iter: 110000
momentum: 0.9
weight_decay: 0.0000
snapshot: 5000
snapshot_prefix: "./examples/coco_caption/lstm_lm"
solver_mode: GPU
random_seed: 1701
average_loss: 100
clip_gradients: 10
14 changes: 14 additions & 0 deletions examples/coco_caption/train_language_model.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/usr/bin/env bash
# Train the COCO caption LSTM language model.
# Prerequisites: the COCO dataset downloaded (see data/coco) and converted
# to HDF5 via ./examples/coco_caption/coco_to_hdf5_data.py.
# Run from the Caffe repository root.

set -e  # abort immediately if any command fails

GPU_ID=0
DATA_DIR=./examples/coco_caption/h5_data/

# Bail out early with setup instructions if the HDF5 data is missing.
# "$DATA_DIR" is quoted to prevent word splitting/globbing (ShellCheck SC2086).
if [ ! -d "$DATA_DIR" ]; then
  echo "Data directory not found: $DATA_DIR"
  echo "First, download the COCO dataset (follow instructions in data/coco)"
  echo "Then, run ./examples/coco_caption/coco_to_hdf5_data.py to create the Caffe input data"
  exit 1
fi

./build/tools/caffe train \
    -solver ./examples/coco_caption/lstm_lm_solver.prototxt \
    -gpu "$GPU_ID"

0 comments on commit 8d78878

Please sign in to comment.