Code for the paper "AdaCAD: Adaptively Decoding to Balance Conflicts between Contextual and Parametric Knowledge" by Han Wang, Archiki Prasad, Elias Stengel-Eskin, and Mohit Bansal.
You can install all required packages by running the following command:
pip install -r requirements.txt
We provide three sample input files, nq_swap_2_-1.jsonl, nq_synth_2_-1.jsonl, and tofu_1.5_-0.5.jsonl, in the data folder. The details are described in data/README.md.
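To peek at a sample file before running anything, here is a minimal sketch; it assumes nothing beyond the JSONL format, and the actual field names are whatever data/README.md documents:

import json

# Print the fields of the first record in a sample input file.
with open("data/nq_swap_2_-1.jsonl") as f:
    example = json.loads(f.readline())
print(list(example.keys()))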
HF_TOKEN=your_huggingface_token # User Access Token to authenticate to the Hub.
HF_HUB_CACHE=your_cache_path # where repositories from the Hub will be cached locally (models, datasets and spaces).
bash run_qa.sh /path/to/your/input/file
As an example, run the following command:
bash run_qa.sh data/nq_swap_2_-1.jsonl
We explain the arguments in run_qa.sh as follows:
- GLOBALLEN: the maximum sequence length of the model.
- MAXCTXLEN: the maximum input context length.
- GENLEN: the maximum generation length; it should satisfy GENLEN = GLOBALLEN - MAXCTXLEN.
- SEED: random seed.
- DEVICE: the GPU device ids, for example, 0,1.
- TOPP: top-p sampling; set to 0.0 for greedy decoding.
- GPUS: number of GPUs.
- FLAG: whether to load the model with int4 quantization (see the sketch after this list).
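As a hedged illustration of what FLAG corresponds to, int4 loading in Hugging Face transformers typically looks like the following; the model id is a placeholder, and the repository's actual loading code may differ:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Hypothetical sketch of int4 loading; not the repository's exact code.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model id
    quantization_config=bnb_config,
    device_map="auto",
)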
Note: remember to use your own Hugging Face token and set your local cache path.
HF_TOKEN=your_huggingface_token # User Access Token to authenticate to the Hub.
HF_HUB_CACHE=your_cache_path # where repositories from the Hub will be cached locally (models, datasets and spaces).
bash run_summ.sh /path/to/your/input/file
As an example, run the following command:
bash run_summ.sh data/tofu_1.5_-0.5.jsonl
The arguments are the same as those in run_qa.sh, except for a new argument THRESHOLD, which sets a threshold on alpha as a warmup for long-form generation.
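One plausible reading of this warmup (hypothetical; the exact rule is in run_summ.sh) is to zero out alpha when the JSD between the two distributions falls below THRESHOLD, so that near-zero conflict leaves the distribution untouched, using the get_jsd helper from the snippet below:

# Hypothetical thresholding of alpha; see run_summ.sh for the actual rule.
alpha = get_jsd(logits1, logits2)
if alpha < THRESHOLD:
    alpha = 0.0
new_logits1 = (1 + alpha) * logits1 - alpha * logits2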
You can use the following code snippet to compute the JSD value and then adjust the output probability distribution during decoding.
import torch
import torch.nn.functional as F

def get_jsd(p, q):
    # Convert logits to probability distributions.
    p = F.softmax(p, dim=-1)
    q = F.softmax(q, dim=-1)
    p, q = p.view(-1, p.size(-1)), q.view(-1, q.size(-1))
    # m is the log of the mixture distribution; clamp to avoid log(0).
    if ((p + q) == 0).any():
        m = (0.5 * (p + q)).clamp_min(1e-9).log()
    else:
        m = (0.5 * (p + q)).log()
    # Clamp p and q away from zero so the KL terms stay finite.
    if torch.any(p <= 0):
        p = p.clamp_min(1e-9)
    if torch.any(q <= 0):
        q = q.clamp_min(1e-9)
    # JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m).
    return 0.5 * (F.kl_div(m, p, reduction='batchmean', log_target=False)
                  + F.kl_div(m, q, reduction='batchmean', log_target=False))
# logits1: output logits of the input with context
# logits2: output logits of the input without context
alpha = get_jsd(logits1, logits2)
new_logits1 = (1 + alpha) * logits1 - alpha * logits2
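Putting it together, a self-contained sketch of one decoding step; the model id and prompts are placeholders, not the paper's setup:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder; the paper evaluates larger models
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

with_context = "Context: The capital of France is Paris. Question: What is the capital of France? Answer:"
without_context = "Question: What is the capital of France? Answer:"

with torch.no_grad():
    # Next-token logits for the input with and without the context.
    logits1 = model(**tokenizer(with_context, return_tensors="pt")).logits[:, -1, :]
    logits2 = model(**tokenizer(without_context, return_tensors="pt")).logits[:, -1, :]

alpha = get_jsd(logits1, logits2)
new_logits1 = (1 + alpha) * logits1 - alpha * logits2
next_token_id = new_logits1.argmax(dim=-1)  # greedy step on the adjusted logits
print(tokenizer.decode(next_token_id))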
We sincerely thank the authors of CAD for their public code release.
@article{wang2024adacad,
title={AdaCAD: Adaptively Decoding to Balance Conflicts between Contextual and Parametric Knowledge},
author={Han Wang and Archiki Prasad and Elias Stengel-Eskin and Mohit Bansal},
year={2024},
journal={arXiv preprint arXiv:2409.07394}
}