Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Update retrievers.py
Browse files Browse the repository at this point in the history
Raise AssertionError when `NaN` values appear.
  • Loading branch information
jianguoz authored Aug 4, 2021
1 parent 3b90417 commit d7dff60
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions parlai/agents/rag/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,12 +641,11 @@ def index_retrieve(
# recompute exact FAISS scores
scores = torch.bmm(query.unsqueeze(1), vectors.transpose(1, 2)).squeeze(1)
if torch.isnan(scores).sum().item():
logging.error(
'\n[ Document scores are NaN; please look into the built index. ]\n'
'[ If using a compressed index, try building an exact index: ]\n'
'[ $ python index_dense_embeddings --indexer-type exact... ]'
raise AssertionError(
'\n[ Document scores are NaN; please make sure the passages file does not have repeated entries.]\n'
'[ Also set --num-shards to small values during generating dense embeddings: ]\n'
'[ e.g., when --shard-id is 0, --num-shards should be a value < int(len(rows of the passages file)/(num_of_maximum_passages_you_want_to_retrieve-1)).]'
)
scores.fill_(1)
ids = torch.tensor([[int(s) for s in ss] for ss in ids])

return ids, scores
Expand Down

0 comments on commit d7dff60

Please sign in to comment.