-
Notifications
You must be signed in to change notification settings - Fork 25
/
decode.sh
48 lines (42 loc) · 1.54 KB
/
decode.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#! /bin/bash
# Copyright (c) Meta Platforms, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Runs src/vsp_llm_decode.py (config: s2s_decode) on the `test` subset using
# the trained model at MODEL_PATH and writes results under OUT_PATH.
#
# The script takes no command-line arguments: edit TASK_LANG and the paths
# below before running it.

# Language direction (e.g. 'en' for the VSR task / 'en-es' for the En-to-Es
# VST task).
# Deliberately NOT named LANG: LANG is the POSIX locale variable, and assigning
# to it would change the locale seen by this shell and by the python child.
TASK_LANG=en

# Set paths. ROOT is the parent of the directory that contains this script.
# NOTE: `readlink -m` is a GNU coreutils option.
ROOT=$(dirname "$(dirname "$(readlink -fm "$0")")")
MODEL_SRC="${ROOT}/src"
LLM_PATH="${ROOT}/checkpoints/Llama-2-7b-hf"              # path to llama checkpoint
DATA_ROOT="${MODEL_SRC}/dataset"                          # path to test dataset dir
MODEL_PATH="${ROOT}/checkpoints/checkpoint_finetune.pt"   # path to trained model
OUT_PATH="${ROOT}/decode"                                 # output path to save

# Derive task-specific variables from the language direction:
#   'src-tgt' -> VST task, BLEU evaluation on,  data in <DATA_ROOT>/vst/<src>/<tgt>
#   'lang'    -> VSR task, BLEU evaluation off, data in <DATA_ROOT>/vsr/<lang>
case "${TASK_LANG}" in
  *-*)
    TASK="vst"
    # Split on the first '-'; anything after it ends up in TGT.
    IFS='-' read -r SRC TGT <<< "${TASK_LANG}"
    USE_BLEU=true
    DATA_PATH="${DATA_ROOT}/${TASK}/${SRC}/${TGT}"
    ;;
  *)
    TASK="vsr"
    TGT="${TASK_LANG}"
    USE_BLEU=false
    DATA_PATH="${DATA_ROOT}/${TASK}/${TASK_LANG}"
    ;;
esac

# Start decoding.
# Prepend the bundled fairseq to PYTHONPATH; only add the ':' separator when
# PYTHONPATH is already non-empty so no empty search-path entry is created.
export PYTHONPATH="${ROOT}/fairseq${PYTHONPATH:+:${PYTHONPATH}}"

# Arguments are collected in an array so that every path stays a single word
# even if the checkout location contains spaces or glob characters.
decode_args=(
  --config-dir "${MODEL_SRC}/conf"
  --config-name s2s_decode
  "common.user_dir=${MODEL_SRC}"
  dataset.gen_subset=test
  "override.data=${DATA_PATH}"
  "override.label_dir=${DATA_PATH}"
  generation.beam=20
  generation.lenpen=0
  dataset.max_tokens=3000
  "override.eval_bleu=${USE_BLEU}"
  "override.llm_ckpt_path=${LLM_PATH}"
  "common_eval.path=${MODEL_PATH}"
  "common_eval.results_path=${OUT_PATH}/${TASK}/${TASK_LANG}"
)

# Only GPU index 0 is exposed to the process. The script's exit status is the
# exit status of this python command.
CUDA_VISIBLE_DEVICES=0 python -B "${MODEL_SRC}/vsp_llm_decode.py" "${decode_args[@]}"