-
Notifications
You must be signed in to change notification settings - Fork 108
/
test_fatezero_dataset.py
52 lines (40 loc) · 2.67 KB
/
test_fatezero_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from test_fatezero import *
from glob import glob
import copy
def _result_logdir(edit_config, dataset_config, p2p_config_index, time_string, data_sample):
    """Build the output directory for one (p2p config, data sample) run.

    Every ``config`` substring of the edit-config path is replaced with
    ``result`` and its ``.yml``/``.yaml`` extension is removed. The p2p-config
    index, the dataset-config file stem and the sweep timestamp are appended,
    and the data-sample name becomes the final path component.
    """
    edit_stem = edit_config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')
    # splitext drops whatever extension is present; a fixed ``[:-5]`` slice
    # only works for a five-character ``.yaml`` suffix.
    dataset_stem = os.path.splitext(os.path.basename(dataset_config))[0]
    return f'{edit_stem}_config_{p2p_config_index}_{dataset_stem}_{time_string}/{data_sample}'


@click.command()
@click.option("--edit_config", type=str, default="config/supp/style/0313_style_edit_warp_640.yaml")
@click.option("--dataset_config", type=str, default="data/supp_edit_dataset/dataset_prompt.yaml")
def run(edit_config, dataset_config):
    """Run every p2p config of EDIT_CONFIG on every sample of DATASET_CONFIG.

    Each top-level key of the dataset config is one data sample holding a
    source ``prompt``, a list of ``target`` prompts and optional
    ``eq_params``. For every (sample, p2p config) pair a fresh copy of the
    edit config is specialised for that sample and passed to ``test``.
    """
    edit_cfg = OmegaConf.load(edit_config)
    dataset_cfg = OmegaConf.load(dataset_config)
    # Go through all data samples in a deterministic (sorted) order.
    data_sample_list = sorted(dataset_cfg.keys())
    print(f'Datasample to evaluate: {data_sample_list}')
    # One timestamp for the whole sweep, so all runs share the same suffix.
    dataset_time_string = get_time_string()
    for data_sample in data_sample_list:
        print(f'Evaluate {data_sample}')
        sample_cfg = dataset_cfg[data_sample]
        for p2p_config_index, p2p_config in edit_cfg['editing_config']['p2p_config'].items():
            # Deep copy so the loaded edit config stays pristine across runs.
            edit_config_now = copy.deepcopy(edit_cfg)

            # The sample entry, minus the prompt-editing keys, is the dataset config.
            edit_config_now['dataset_config'] = copy.deepcopy(sample_cfg)
            edit_config_now['dataset_config'].pop('target')
            if 'eq_params' in edit_config_now['dataset_config']:
                edit_config_now['dataset_config'].pop('eq_params')

            # Editing prompts: the source prompt first, then every target prompt.
            edit_config_now['editing_config']['editing_prompts'] = copy.deepcopy(
                [sample_cfg['prompt']] + OmegaConf.to_object(sample_cfg['target']))

            # One p2p config per editing prompt, each an independent deep copy.
            # Writing eq_params into a copy therefore never touches `p2p_config`.
            # That node belongs to the loaded edit config and is reused for
            # every later sample, so mutating it would leak eq_params forward.
            num_prompts = len(edit_config_now['editing_config']['editing_prompts'])
            p2p_config_now = {i: copy.deepcopy(p2p_config) for i in range(num_prompts)}
            if 'eq_params' in sample_cfg:
                for prompt_p2p_config in p2p_config_now.values():
                    prompt_p2p_config['eq_params'] = sample_cfg['eq_params']
            edit_config_now['editing_config']['p2p_config'] = p2p_config_now
            edit_config_now['editing_config']['source_prompt'] = sample_cfg['prompt']

            logdir = _result_logdir(edit_config, dataset_config, p2p_config_index,
                                    dataset_time_string, data_sample)
            edit_config_now['logdir'] = logdir
            print(f'Saving at {logdir}')
            test(config=edit_config, **edit_config_now)
if __name__ == '__main__':
    # `run` is a click command; invoked bare, click parses the CLI options itself.
    run()