diff --git a/parlai/agents/rag/retrievers.py b/parlai/agents/rag/retrievers.py index bfff0fb77ba..5ec0aa596c3 100644 --- a/parlai/agents/rag/retrievers.py +++ b/parlai/agents/rag/retrievers.py @@ -641,12 +641,11 @@ def index_retrieve( # recompute exact FAISS scores scores = torch.bmm(query.unsqueeze(1), vectors.transpose(1, 2)).squeeze(1) if torch.isnan(scores).sum().item(): - logging.error( - '\n[ Document scores are NaN; please look into the built index. ]\n' - '[ If using a compressed index, try building an exact index: ]\n' - '[ $ python index_dense_embeddings --indexer-type exact... ]' + raise AssertionError( + '\n[ Document scores are NaN; please make sure the passages file does not have repeated entries.]\n' + '[ Also set --num-shards to small values during generating dense embeddings: ]\n' + '[ e.g., when --shard-id is 0, --num-shards should be a value < int(len(rows of the passages file)/(num_of_maximum_passages_you_want_to_retrieve-1)).]' ) - scores.fill_(1) ids = torch.tensor([[int(s) for s in ss] for ss in ids]) return ids, scores