forked from collin-burns/discovering_latent_knowledge
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgenerate.py
38 lines (32 loc) · 2.17 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from utils import get_parser, load_model, get_all_hidden_states, save_generations, get_all_hidden_states_context_both
from dataset import get_dataloader
def main(args):
    """Generate hidden states for a dataset and save them with the labels.

    Loads the model and tokenizer, builds the dataloader, extracts the
    negative and positive hidden states (plus the extra ``neg_non`` /
    ``pos_non`` hidden states when ``args.context_both`` is truthy), and
    writes every result through ``save_generations``.

    Args:
        args: Parsed command-line namespace (built by ``get_parser``). This
            function reads ``model_name``, ``cache_dir``, ``parallelize``,
            ``device``, ``dataset_name``, ``split``, ``batch_size``,
            ``num_examples``, ``context_num``, ``corrupt_prob``,
            ``context_both``, ``use_decoder``, ``layer``, ``all_layers`` and
            ``token_idx``. The whole namespace is also forwarded to
            ``save_generations``.
    """
    # Set up the model and data
    print("Loading model")
    model, tokenizer, model_type = load_model(
        args.model_name, args.cache_dir, args.parallelize, args.device
    )

    print("Loading dataloader")
    # NOTE: an earlier version of this call also passed args.prompt_idx; the
    # current call does not, and passes the context_num / corrupt_prob /
    # context_both options instead.
    dataloader = get_dataloader(
        args.dataset_name,
        args.split,
        tokenizer,
        batch_size=args.batch_size,
        num_examples=args.num_examples,
        context_num=args.context_num,
        corrupt_prob=args.corrupt_prob,
        context_both=args.context_both,
        model_type=model_type,
        use_decoder=args.use_decoder,
        device=args.device,
    )

    # Both extraction helpers are called with the same keyword arguments, so
    # build them once.
    extraction_kwargs = dict(
        layer=args.layer,
        all_layers=args.all_layers,
        token_idx=args.token_idx,
        model_type=model_type,
        use_decoder=args.use_decoder,
    )

    # Get the hidden states and labels
    print("Generating hidden states")
    extra_outputs = []
    if args.context_both:
        print("generate both")
        neg_hs, pos_hs, neg_non_hs, pos_non_hs, y = get_all_hidden_states_context_both(
            model, dataloader, **extraction_kwargs
        )
        # Only this branch produces the two additional arrays; queue them so
        # they are written first, in the same order as before.
        extra_outputs = [
            (neg_non_hs, "neg_non_hidden_states"),
            (pos_non_hs, "pos_non_hidden_states"),
        ]
    else:
        neg_hs, pos_hs, y = get_all_hidden_states(
            model, dataloader, **extraction_kwargs
        )

    # Save the hidden states and labels
    print("Saving hidden states")
    outputs = extra_outputs + [
        (neg_hs, "negative_hidden_states"),
        (pos_hs, "positive_hidden_states"),
        (y, "labels"),
    ]
    for data, generation_type in outputs:
        save_generations(data, args, generation_type=generation_type)
if __name__ == "__main__":
    # Script entry point: parse the command line with the project's parser
    # and hand the resulting namespace straight to main().
    main(get_parser().parse_args())