@@ -18,6 +18,7 @@ def run_eval(target_path, model_path, replica, out_dir, device):
18
18
model .load_state_dict (torch .load (model_file , map_location = device ))
19
19
else :
20
20
cost_time = utils .load_tf_ckpt (model , model_file )
21
+ model .to (device )
21
22
print (f'Load tf model cost time: { cost_time } ' )
22
23
23
24
num_examples = 0
@@ -66,6 +67,7 @@ def run_eval(target_path, model_path, replica, out_dir, device):
66
67
crop_x = torch .tensor ([crop_x ]).to (device )
67
68
crop_y = torch .tensor ([crop_y ]).to (device )
68
69
out = model (x_2d , crop_x , crop_y )
70
+ out = {k :t .cpu () for k ,t in out .items ()}
69
71
70
72
contact_probs = out ['contact_probs' ][0 ,
71
73
prepad_y :crop_size_y - postpad_y ,
@@ -84,19 +86,19 @@ def run_eval(target_path, model_path, replica, out_dir, device):
84
86
85
87
if 'secstruct_probs' in out :
86
88
sec_x = out ['secstruct_probs' ][0 , prepad_x :crop_size_x - postpad_x ].numpy ()
87
- sec_y = out ['secstruct_probs' ][0 , crop_size_x + prepad_y :crop_size_x + crop_size_y - postpad_y ].detach (). numpy ()
89
+ sec_y = out ['secstruct_probs' ][0 , crop_size_x + prepad_y :crop_size_x + crop_size_y - postpad_y ].numpy ()
88
90
sec_accum [ic :ic + sec_x .shape [0 ]] += sec_x
89
91
sec_accum [jc :jc + sec_y .shape [0 ]] += sec_y
90
92
91
93
if 'torsion_probs' in out :
92
94
tor_x = out ['torsion_probs' ][0 , prepad_x :crop_size_x - postpad_x ].numpy ()
93
- tor_y = out ['torsion_probs' ][0 , crop_size_x + prepad_y :crop_size_x + crop_size_y - postpad_y ].detach (). numpy ()
95
+ tor_y = out ['torsion_probs' ][0 , crop_size_x + prepad_y :crop_size_x + crop_size_y - postpad_y ].numpy ()
94
96
tor_accum [ic :ic + tor_x .shape [0 ]] += tor_x
95
97
tor_accum [jc :jc + tor_y .shape [0 ]] += tor_y
96
98
97
99
if 'asa_output' in out :
98
100
asa_x = out ['asa_output' ][0 , prepad_x :crop_size_x - postpad_x ].numpy ()
99
- asa_y = out ['asa_output' ][0 , crop_size_x + prepad_y :crop_size_x + crop_size_y - postpad_y ].detach (). numpy ()
101
+ asa_y = out ['asa_output' ][0 , crop_size_x + prepad_y :crop_size_x + crop_size_y - postpad_y ].numpy ()
100
102
asa_accum [ic :ic + asa_x .shape [0 ]] += np .squeeze (asa_x , 1 )
101
103
asa_accum [jc :jc + asa_y .shape [0 ]] += np .squeeze (asa_y , 1 )
102
104
0 commit comments