From 7020e63c947692256536645eac18cf55bb8313b7 Mon Sep 17 00:00:00 2001 From: t1101675 Date: Sun, 6 Aug 2023 06:55:51 -0700 Subject: [PATCH] fix retriever training bug --- README.md | 2 +- data_utils/retriever_datasets.py | 3 ++- scripts/retriever/train.sh | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 534e904..6ddaf99 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ M H ≤ ℏ c 3 8 π G k B T u {\displaystyle M_{\mathrm {H} }\leq {\frac {\hbar -
Question Ansering +
Question Answering ``` ########## Query ########## diff --git a/data_utils/retriever_datasets.py b/data_utils/retriever_datasets.py index 29f2141..9f1bd30 100644 --- a/data_utils/retriever_datasets.py +++ b/data_utils/retriever_datasets.py @@ -139,7 +139,8 @@ def load_data(self, path): "easy_neg_contexts": easy_neg_contexts, "hard_neg_contexts": hard_neg_contexts, "label": line["label"], - "neg_labels": line["neg_labels"] + "easy_neg_labels": line["easy_neg_labels"], + "hard_neg_labels": line["hard_neg_labels"], }) return data diff --git a/scripts/retriever/train.sh b/scripts/retriever/train.sh index 6d639c0..4cc57ed 100644 --- a/scripts/retriever/train.sh +++ b/scripts/retriever/train.sh @@ -10,7 +10,7 @@ MASTER_PORT=${3-2010} MODEL_NAME="roberta-base" MODEL_DIR="${WORKING_DIR}/checkpoints/${MODEL_NAME}/" # data -DATA_NAME="${4-TRAIN/p1_en1_hn1_s42}" +DATA_NAME="${4-TRAIN/p1_en1_hn4_s42}" DATA_DIR="${WORKING_DIR}/retriever_data/${DATA_NAME}/merge" # hp BATCH_SIZE=64 @@ -39,6 +39,8 @@ OPTS+=" --gradient-accumulation-steps ${GRAD_ACC}" OPTS+=" --eval-batch-size ${EVAL_BATCH_SIZE}" OPTS+=" --epochs ${EPOCHS}" OPTS+=" --max-length 256" +OPTS+=" --save-interval -1" +OPTS+=" --eval-interval -1" # runtime OPTS+=" --do-train" OPTS+=" --save ${SAVE_PATH}"