This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Hallucination models with TF-IDF #4422

Closed
Youmna-H opened this issue Mar 14, 2022 · 5 comments
@Youmna-H

Hi,
I am trying to fine-tune the models published in "Retrieval Augmentation Reduces Hallucination in Conversation" (e.g., zoo:hallucination/bart_rag_sequence/model) using TF-IDF instead of DPR for retrieval. I am using the train_model command with the arguments --init-model zoo:hallucination/bart_rag_sequence/model --tfidf-model-path my_tfidf_model. However, I get an exception at:
parlai/agents/rag/rag.py, line 507, in load_state_dict
self.model.load_state_dict(state_dict)
When I comment out this line, training runs, but the model then generates only garbage, which means something has gone badly wrong. Any thoughts on how to solve this?

The paper mentions that one of the models was trained with TF-IDF to compare it with DPR; however, the TF-IDF model is not one of the publicly available ones. @klshuster, could you advise on this?

Thanks

@klshuster klshuster self-assigned this Mar 14, 2022
@klshuster
Contributor

Can you please paste the full stack trace, as well as the command?

@Youmna-H
Author

Thank you Kurt.
The command is:
parlai train_model --model rag --task mytask --rag-model-type sequence --rag-retriever-type tfidf --generation-model bart --init-model zoo:hallucination/bart_rag_sequence/model --batchsize 2 --fp16-impl mem_efficient --fp16 True --hnsw-indexer-store-n 512 --optimizer mem_eff_adam --model-file data/models/mytask/bart_rag_sequence_tfidf --compressed-indexer-factory None --indexer-type forced --path-to-index data/models/mytask/all_embeddings_seq.index --path-to-dpr-passages data/models/mytask/mytask_parlai.tsv --beam-size 3 --beam-min-length 20 --beam-block-ngram 3 --inference beam --force-fp16-tokens True --learningrate 1e-05 --truncate 512 --text-truncate 512 --label-truncate 128 --lr-scheduler-patience 1 --warmup-updates 0 --validation-metric-mode min --validation-every-n-epochs 0.25 --validation-max-exs 1000 --validation-metric ppl --validation-patience 5 --save-after-valid True

"mytask" is a knowledge-based task that I've created that is very similar to the wizard of wikipedia, but it trained on a different dialogue dataset and has a knowledge-base different from wikipedia. The stack trace is:

Traceback (most recent call last):
File "/ParlAI/parlai/agents/rag/rag.py", line 488, in load_state_dict
self.model.load_state_dict(state_dict)
File "/parlai_venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1498, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for RagModel:
Unexpected key(s) in state_dict: "retriever.query_encoder.embeddings.weight", "retriever.query_encoder.position_embeddings.weight", "retriever.query_encoder.norm_embeddings.weight", "retriever.query_encoder.norm_embeddings.bias", "retriever.query_encoder.segment_embeddings.weight", "retriever.query_encoder.layers.0.attention.q_lin.weight", "retriever.query_encoder.layers.0.attention.q_lin.bias", "retriever.query_encoder.layers.0.attention.k_lin.weight", "retriever.query_encoder.layers.0.attention.k_lin.bias", "retriever.query_encoder.layers.0.attention.v_lin.weight", "retriever.query_encoder.layers.0.attention.v_lin.bias", "retriever.query_encoder.layers.0.attention.out_lin.weight", "retriever.query_encoder.layers.0.attention.out_lin.bias", "retriever.query_encoder.layers.0.norm1.weight", "retriever.query_encoder.layers.0.norm1.bias", "retriever.query_encoder.layers.0.ffn.lin1.weight", "retriever.query_encoder.layers.0.ffn.lin1.bias", "retriever.query_encoder.layers.0.ffn.lin2.weight", "retriever.query_encoder.layers.0.ffn.lin2.bias", "retriever.query_encoder.layers.0.norm2.weight", "retriever.query_encoder.layers.0.norm2.bias", "retriever.query_encoder.layers.1.attention.q_lin.weight", "retriever.query_encoder.layers.1.attention.q_lin.bias", "retriever.query_encoder.layers.1.attention.k_lin.weight", "retriever.query_encoder.layers.1.attention.k_lin.bias", "retriever.query_encoder.layers.1.attention.v_lin.weight", "retriever.query_encoder.layers.1.attention.v_lin.bias", "retriever.query_encoder.layers.1.attention.out_lin.weight", "retriever.query_encoder.layers.1.attention.out_lin.bias", "retriever.query_encoder.layers.1.norm1.weight", "retriever.query_encoder.layers.1.norm1.bias", "retriever.query_encoder.layers.1.ffn.lin1.weight", "retriever.query_encoder.layers.1.ffn.lin1.bias", "retriever.query_encoder.layers.1.ffn.lin2.weight", "retriever.query_encoder.layers.1.ffn.lin2.bias", "retriever.query_encoder.layers.1.norm2.weight", 
"retriever.query_encoder.layers.1.norm2.bias", "retriever.query_encoder.layers.2.attention.q_lin.weight", "retriever.query_encoder.layers.2.attention.q_lin.bias", "retriever.query_encoder.layers.2.attention.k_lin.weight", "retriever.query_encoder.layers.2.attention.k_lin.bias", "retriever.query_encoder.layers.2.attention.v_lin.weight", "retriever.query_encoder.layers.2.attention.v_lin.bias", "retriever.query_encoder.layers.2.attention.out_lin.weight", "retriever.query_encoder.layers.2.attention.out_lin.bias", "retriever.query_encoder.layers.2.norm1.weight", "retriever.query_encoder.layers.2.norm1.bias", "retriever.query_encoder.layers.2.ffn.lin1.weight", "retriever.query_encoder.layers.2.ffn.lin1.bias", "retriever.query_encoder.layers.2.ffn.lin2.weight", "retriever.query_encoder.layers.2.ffn.lin2.bias", "retriever.query_encoder.layers.2.norm2.weight", "retriever.query_encoder.layers.2.norm2.bias", "retriever.query_encoder.layers.3.attention.q_lin.weight", "retriever.query_encoder.layers.3.attention.q_lin.bias", "retriever.query_encoder.layers.3.attention.k_lin.weight", "retriever.query_encoder.layers.3.attention.k_lin.bias", "retriever.query_encoder.layers.3.attention.v_lin.weight", "retriever.query_encoder.layers.3.attention.v_lin.bias", "retriever.query_encoder.layers.3.attention.out_lin.weight", "retriever.query_encoder.layers.3.attention.out_lin.bias", "retriever.query_encoder.layers.3.norm1.weight", "retriever.query_encoder.layers.3.norm1.bias", "retriever.query_encoder.layers.3.ffn.lin1.weight", "retriever.query_encoder.layers.3.ffn.lin1.bias", "retriever.query_encoder.layers.3.ffn.lin2.weight", "retriever.query_encoder.layers.3.ffn.lin2.bias", "retriever.query_encoder.layers.3.norm2.weight", "retriever.query_encoder.layers.3.norm2.bias", "retriever.query_encoder.layers.4.attention.q_lin.weight", "retriever.query_encoder.layers.4.attention.q_lin.bias", "retriever.query_encoder.layers.4.attention.k_lin.weight", 
"retriever.query_encoder.layers.4.attention.k_lin.bias", "retriever.query_encoder.layers.4.attention.v_lin.weight", "retriever.query_encoder.layers.4.attention.v_lin.bias", "retriever.query_encoder.layers.4.attention.out_lin.weight", "retriever.query_encoder.layers.4.attention.out_lin.bias", "retriever.query_encoder.layers.4.norm1.weight", "retriever.query_encoder.layers.4.norm1.bias", "retriever.query_encoder.layers.4.ffn.lin1.weight", "retriever.query_encoder.layers.4.ffn.lin1.bias", "retriever.query_encoder.layers.4.ffn.lin2.weight", "retriever.query_encoder.layers.4.ffn.lin2.bias", "retriever.query_encoder.layers.4.norm2.weight", "retriever.query_encoder.layers.4.norm2.bias", "retriever.query_encoder.layers.5.attention.q_lin.weight", "retriever.query_encoder.layers.5.attention.q_lin.bias", "retriever.query_encoder.layers.5.attention.k_lin.weight", "retriever.query_encoder.layers.5.attention.k_lin.bias", "retriever.query_encoder.layers.5.attention.v_lin.weight", "retriever.query_encoder.layers.5.attention.v_lin.bias", "retriever.query_encoder.layers.5.attention.out_lin.weight", "retriever.query_encoder.layers.5.attention.out_lin.bias", "retriever.query_encoder.layers.5.norm1.weight", "retriever.query_encoder.layers.5.norm1.bias", "retriever.query_encoder.layers.5.ffn.lin1.weight", "retriever.query_encoder.layers.5.ffn.lin1.bias", "retriever.query_encoder.layers.5.ffn.lin2.weight", "retriever.query_encoder.layers.5.ffn.lin2.bias", "retriever.query_encoder.layers.5.norm2.weight", "retriever.query_encoder.layers.5.norm2.bias", "retriever.query_encoder.layers.6.attention.q_lin.weight", "retriever.query_encoder.layers.6.attention.q_lin.bias", "retriever.query_encoder.layers.6.attention.k_lin.weight", "retriever.query_encoder.layers.6.attention.k_lin.bias", "retriever.query_encoder.layers.6.attention.v_lin.weight", "retriever.query_encoder.layers.6.attention.v_lin.bias", "retriever.query_encoder.layers.6.attention.out_lin.weight", 
"retriever.query_encoder.layers.6.attention.out_lin.bias", "retriever.query_encoder.layers.6.norm1.weight", "retriever.query_encoder.layers.6.norm1.bias", "retriever.query_encoder.layers.6.ffn.lin1.weight", "retriever.query_encoder.layers.6.ffn.lin1.bias", "retriever.query_encoder.layers.6.ffn.lin2.weight", "retriever.query_encoder.layers.6.ffn.lin2.bias", "retriever.query_encoder.layers.6.norm2.weight", "retriever.query_encoder.layers.6.norm2.bias", "retriever.query_encoder.layers.7.attention.q_lin.weight", "retriever.query_encoder.layers.7.attention.q_lin.bias", "retriever.query_encoder.layers.7.attention.k_lin.weight", "retriever.query_encoder.layers.7.attention.k_lin.bias", "retriever.query_encoder.layers.7.attention.v_lin.weight", "retriever.query_encoder.layers.7.attention.v_lin.bias", "retriever.query_encoder.layers.7.attention.out_lin.weight", "retriever.query_encoder.layers.7.attention.out_lin.bias", "retriever.query_encoder.layers.7.norm1.weight", "retriever.query_encoder.layers.7.norm1.bias", "retriever.query_encoder.layers.7.ffn.lin1.weight", "retriever.query_encoder.layers.7.ffn.lin1.bias", "retriever.query_encoder.layers.7.ffn.lin2.weight", "retriever.query_encoder.layers.7.ffn.lin2.bias", "retriever.query_encoder.layers.7.norm2.weight", "retriever.query_encoder.layers.7.norm2.bias", "retriever.query_encoder.layers.8.attention.q_lin.weight", "retriever.query_encoder.layers.8.attention.q_lin.bias", "retriever.query_encoder.layers.8.attention.k_lin.weight", "retriever.query_encoder.layers.8.attention.k_lin.bias", "retriever.query_encoder.layers.8.attention.v_lin.weight", "retriever.query_encoder.layers.8.attention.v_lin.bias", "retriever.query_encoder.layers.8.attention.out_lin.weight", "retriever.query_encoder.layers.8.attention.out_lin.bias", "retriever.query_encoder.layers.8.norm1.weight", "retriever.query_encoder.layers.8.norm1.bias", "retriever.query_encoder.layers.8.ffn.lin1.weight", "retriever.query_encoder.layers.8.ffn.lin1.bias", 
"retriever.query_encoder.layers.8.ffn.lin2.weight", "retriever.query_encoder.layers.8.ffn.lin2.bias", "retriever.query_encoder.layers.8.norm2.weight", "retriever.query_encoder.layers.8.norm2.bias", "retriever.query_encoder.layers.9.attention.q_lin.weight", "retriever.query_encoder.layers.9.attention.q_lin.bias", "retriever.query_encoder.layers.9.attention.k_lin.weight", "retriever.query_encoder.layers.9.attention.k_lin.bias", "retriever.query_encoder.layers.9.attention.v_lin.weight", "retriever.query_encoder.layers.9.attention.v_lin.bias", "retriever.query_encoder.layers.9.attention.out_lin.weight", "retriever.query_encoder.layers.9.attention.out_lin.bias", "retriever.query_encoder.layers.9.norm1.weight", "retriever.query_encoder.layers.9.norm1.bias", "retriever.query_encoder.layers.9.ffn.lin1.weight", "retriever.query_encoder.layers.9.ffn.lin1.bias", "retriever.query_encoder.layers.9.ffn.lin2.weight", "retriever.query_encoder.layers.9.ffn.lin2.bias", "retriever.query_encoder.layers.9.norm2.weight", "retriever.query_encoder.layers.9.norm2.bias", "retriever.query_encoder.layers.10.attention.q_lin.weight", "retriever.query_encoder.layers.10.attention.q_lin.bias", "retriever.query_encoder.layers.10.attention.k_lin.weight", "retriever.query_encoder.layers.10.attention.k_lin.bias", "retriever.query_encoder.layers.10.attention.v_lin.weight", "retriever.query_encoder.layers.10.attention.v_lin.bias", "retriever.query_encoder.layers.10.attention.out_lin.weight", "retriever.query_encoder.layers.10.attention.out_lin.bias", "retriever.query_encoder.layers.10.norm1.weight", "retriever.query_encoder.layers.10.norm1.bias", "retriever.query_encoder.layers.10.ffn.lin1.weight", "retriever.query_encoder.layers.10.ffn.lin1.bias", "retriever.query_encoder.layers.10.ffn.lin2.weight", "retriever.query_encoder.layers.10.ffn.lin2.bias", "retriever.query_encoder.layers.10.norm2.weight", "retriever.query_encoder.layers.10.norm2.bias", 
"retriever.query_encoder.layers.11.attention.q_lin.weight", "retriever.query_encoder.layers.11.attention.q_lin.bias", "retriever.query_encoder.layers.11.attention.k_lin.weight", "retriever.query_encoder.layers.11.attention.k_lin.bias", "retriever.query_encoder.layers.11.attention.v_lin.weight", "retriever.query_encoder.layers.11.attention.v_lin.bias", "retriever.query_encoder.layers.11.attention.out_lin.weight", "retriever.query_encoder.layers.11.attention.out_lin.bias", "retriever.query_encoder.layers.11.norm1.weight", "retriever.query_encoder.layers.11.norm1.bias", "retriever.query_encoder.layers.11.ffn.lin1.weight", "retriever.query_encoder.layers.11.ffn.lin1.bias", "retriever.query_encoder.layers.11.ffn.lin2.weight", "retriever.query_encoder.layers.11.ffn.lin2.bias", "retriever.query_encoder.layers.11.norm2.weight", "retriever.query_encoder.layers.11.norm2.bias".

@klshuster
Contributor

Thanks, I've put a fix up in #4436

@Youmna-H
Author

Thank you so much Kurt!

@klshuster
Contributor

Going to close for now; please reopen if you run into further issues here.
