Skip to content

Commit 29bd20b

Browse files
committed
bugfix & typo
1 parent 85176f4 commit 29bd20b

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ optional arguments:
5050
```
5151
For example:
5252
```bash
53-
python alphafold.py -i test_data/T1019s2.pkl -o T1019s2_out -t D -t 0
53+
python alphafold.py -i test_data/T1019s2.pkl -o T1019s2_out -t D -r 0
5454
```
5555
This uses the replica `0` of `Distogram` models to predict the distogram probs of the input data.
5656

alphafold.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def run_eval(target_path, model_path, replica, out_dir, device):
1818
model.load_state_dict(torch.load(model_file, map_location=device))
1919
else:
2020
cost_time = utils.load_tf_ckpt(model, model_file)
21+
model.to(device)
2122
print(f'Load tf model cost time: {cost_time}')
2223

2324
num_examples = 0
@@ -66,6 +67,7 @@ def run_eval(target_path, model_path, replica, out_dir, device):
6667
crop_x = torch.tensor([crop_x]).to(device)
6768
crop_y = torch.tensor([crop_y]).to(device)
6869
out = model(x_2d, crop_x, crop_y)
70+
out = {k:t.cpu() for k,t in out.items()}
6971

7072
contact_probs = out['contact_probs'][0,
7173
prepad_y:crop_size_y - postpad_y,
@@ -84,19 +86,19 @@ def run_eval(target_path, model_path, replica, out_dir, device):
8486

8587
if 'secstruct_probs' in out:
8688
sec_x = out['secstruct_probs'][0, prepad_x:crop_size_x - postpad_x].numpy()
87-
sec_y = out['secstruct_probs'][0, crop_size_x + prepad_y:crop_size_x + crop_size_y - postpad_y].detach().numpy()
89+
sec_y = out['secstruct_probs'][0, crop_size_x + prepad_y:crop_size_x + crop_size_y - postpad_y].numpy()
8890
sec_accum[ic:ic + sec_x.shape[0]] += sec_x
8991
sec_accum[jc:jc + sec_y.shape[0]] += sec_y
9092

9193
if 'torsion_probs' in out:
9294
tor_x = out['torsion_probs'][0, prepad_x:crop_size_x - postpad_x].numpy()
93-
tor_y = out['torsion_probs'][0, crop_size_x + prepad_y:crop_size_x + crop_size_y - postpad_y].detach().numpy()
95+
tor_y = out['torsion_probs'][0, crop_size_x + prepad_y:crop_size_x + crop_size_y - postpad_y].numpy()
9496
tor_accum[ic:ic + tor_x.shape[0]] += tor_x
9597
tor_accum[jc:jc + tor_y.shape[0]] += tor_y
9698

9799
if 'asa_output' in out:
98100
asa_x = out['asa_output'][0, prepad_x:crop_size_x - postpad_x].numpy()
99-
asa_y = out['asa_output'][0, crop_size_x + prepad_y:crop_size_x + crop_size_y - postpad_y].detach().numpy()
101+
asa_y = out['asa_output'][0, crop_size_x + prepad_y:crop_size_x + crop_size_y - postpad_y].numpy()
100102
asa_accum[ic:ic + asa_x.shape[0]] += np.squeeze(asa_x, 1)
101103
asa_accum[jc:jc + asa_y.shape[0]] += np.squeeze(asa_y, 1)
102104

network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def build_crops_biases(self, bias_size, raw_biases, crop_x, crop_y):
264264
crop_size_x = (crop_x[:, 1] - crop_x[:, 0]).max()
265265
crop_size_y = (crop_y[:, 1] - crop_y[:, 0]).max()
266266

267-
increment = torch.unsqueeze(-torch.arange(0, crop_size_y), 0)
267+
increment = torch.unsqueeze(-torch.arange(0, crop_size_y), 0).to(crop_x.device)
268268
row_offsets = start_diag + increment
269269
row_offsets += padded_bias_size - 1
270270

0 commit comments

Comments
 (0)