I used the code from your example, and the generate method of llama2-chat-7B-TruthX produces output that differs from the base llama2-chat-7B. However, I have a problem when I try to extract the logits (and token probabilities) of llama2-chat-7B-TruthX: the probabilities are almost identical to those of the base Llama. Since the generated outputs of the two models are so different, I would expect the difference in token probabilities between the base and the TruthX model to be significant as well. Could you help me with that?
Here is the code I use to extract the token probabilities and save them to a file.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# comment / uncomment to get probs of TruthX / base model
llama2chat = "Llama-2-7b-chat-TruthX" # downloaded locally from 'https://huggingface.co/ICTNLP/Llama-2-7b-chat-TruthX'
# llama2chat = "daryl149/llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(
    llama2chat, trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    llama2chat,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    trust_remote_code=True,
).cuda()
question = "What are the benefits of eating an apple a day?"
# using TruthfulQA prompt
from llm import PROF_PRIMER as TRUTHFULQA_PROMPT
encoded_inputs = tokenizer(TRUTHFULQA_PROMPT.format(question), return_tensors="pt")[
    "input_ids"
]
# note: this next line overrides the TruthfulQA-prompted inputs with the bare question
encoded_inputs = tokenizer(question, return_tensors="pt")["input_ids"]
outputs = model.generate(encoded_inputs.cuda(), max_new_tokens=4000)[0, encoded_inputs.shape[-1] :]
outputs_text = (
    tokenizer.decode(outputs, skip_special_tokens=True).split("Q:")[0].strip()
)
print(outputs_text)
# save probs over tokens
with torch.no_grad():
    logits = model(encoded_inputs.cuda()).logits.cpu().type(torch.float32)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    torch.save(probs, f"{llama2chat.replace('/', '_')}.pt")
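This is roughly how I compare the two saved tensors afterwards (a minimal sketch: the file names are the ones the script above writes for the two model choices, and I assume both runs used exactly the same input ids):
import torch

# minimal comparison sketch: both tensors have shape (1, seq_len, vocab_size)
# and were saved by the script above on the same prompt
base_probs = torch.load("daryl149_llama-2-7b-chat-hf.pt")
truthx_probs = torch.load("Llama-2-7b-chat-TruthX.pt")

# largest change in any single token probability
max_abs_diff = (base_probs - truthx_probs).abs().max().item()

# per-position KL divergence D_KL(truthx || base), then averaged over positions
kl_per_position = torch.nn.functional.kl_div(
    base_probs.clamp_min(1e-12).log(),  # input: log-probs of the base model
    truthx_probs,                       # target: probs of the TruthX model
    reduction="none",
).sum(dim=-1)

print(f"max |delta p| = {max_abs_diff:.6f}, mean KL = {kl_per_position.mean().item():.6f}")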