Skip to content

Commit

Permalink
fix retriever training bug
Browse files Browse the repository at this point in the history
  • Loading branch information
t1101675 committed Aug 6, 2023
1 parent a263302 commit 7020e63
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ M H ≤ ℏ c 3 8 π G k B T u {\displaystyle M_{\mathrm {H} }\leq {\frac {\hbar
</details>


<details><summary><b>Question Ansering</b></summary>
<details><summary><b>Question Answering</b></summary>

```
########## Query ##########
Expand Down
3 changes: 2 additions & 1 deletion data_utils/retriever_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def load_data(self, path):
"easy_neg_contexts": easy_neg_contexts,
"hard_neg_contexts": hard_neg_contexts,
"label": line["label"],
"neg_labels": line["neg_labels"]
"easy_neg_labels": line["easy_neg_labels"],
"hard_neg_labels": line["hard_neg_labels"],
})

return data
Expand Down
4 changes: 3 additions & 1 deletion scripts/retriever/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ MASTER_PORT=${3-2010}
MODEL_NAME="roberta-base"
MODEL_DIR="${WORKING_DIR}/checkpoints/${MODEL_NAME}/"
# data
DATA_NAME="${4-TRAIN/p1_en1_hn1_s42}"
DATA_NAME="${4-TRAIN/p1_en1_hn4_s42}"
DATA_DIR="${WORKING_DIR}/retriever_data/${DATA_NAME}/merge"
# hp
BATCH_SIZE=64
Expand Down Expand Up @@ -39,6 +39,8 @@ OPTS+=" --gradient-accumulation-steps ${GRAD_ACC}"
OPTS+=" --eval-batch-size ${EVAL_BATCH_SIZE}"
OPTS+=" --epochs ${EPOCHS}"
OPTS+=" --max-length 256"
OPTS+=" --save-interval -1"
OPTS+=" --eval-interval -1"
# runtime
OPTS+=" --do-train"
OPTS+=" --save ${SAVE_PATH}"
Expand Down

0 comments on commit 7020e63

Please sign in to comment.