
Merge pull request #350 from aqlaboratory/fix-msastack-test-error
Fixes cuda/float wrapper error in unit tests
jnwei authored Sep 21, 2023
2 parents 2134cc0 + 73ff40b commit 60d0b15
Showing 3 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion scripts/install_third_party_dependencies.sh
@@ -46,4 +46,4 @@ echo "Downloading AlphaFold parameters..."
 bash scripts/download_alphafold_params.sh openfold/resources
 
 # Decompress test data
-gunzip tests/test_data/sample_feats.pickle.gz
+gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats.pickle
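The new command decompresses to a separate output file instead of replacing the archive in place, so tests/test_data/sample_feats.pickle.gz survives and the script can be re-run. For illustration only (the script itself calls gunzip), an equivalent using Python's standard library might look like this:

import gzip
import shutil

# Equivalent of `gunzip -c src.gz > dst`: stream-decompress while
# leaving the original .gz archive untouched.
with gzip.open("tests/test_data/sample_feats.pickle.gz", "rb") as src:
    with open("tests/test_data/sample_feats.pickle", "wb") as dst:
        shutil.copyfileobj(src, dst)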
4 changes: 2 additions & 2 deletions tests/test_evoformer.py
@@ -206,7 +206,7 @@ def test_shape(self):
                 n_res,
             ),
             device="cuda",
-        )
+        ).float()
         pair_mask = torch.randint(
             0,
             2,
@@ -216,7 +216,7 @@ def test_shape(self):
                 n_res,
             ),
             device="cuda",
-        )
+        ).float()
 
         shape_z_before = z.shape

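These hunks address the float half of the error named in the commit message: torch.randint returns an integer tensor by default, while the MSA and pair masks consumed by the attention code are expected to be floating point, so the test now appends .float(). A minimal sketch of the difference, with an illustrative shape and CPU tensors so it runs without a GPU:

import torch

# torch.randint defaults to an integer dtype (torch.int64).
mask = torch.randint(0, 2, (1, 8, 16))
print(mask.dtype)   # torch.int64

# Appending .float() gives the float32 mask the attention modules expect.
mask = torch.randint(0, 2, (1, 8, 16)).float()
print(mask.dtype)   # torch.float32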
22 changes: 11 additions & 11 deletions tests/test_model.py
@@ -47,27 +47,27 @@ def test_dry_run(self):
         c.model.evoformer_stack.blocks_per_ckpt = None  # don't want to set up
                                                         # deepspeed for this test
 
-        model = AlphaFold(c)
+        model = AlphaFold(c).cuda()
         model.eval()
 
         batch = {}
-        tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,))
+        tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,)).cuda()
         batch["target_feat"] = nn.functional.one_hot(
             tf, c.model.input_embedder.tf_dim
-        ).float()
-        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
-        batch["residue_index"] = torch.arange(n_res)
-        batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
+        ).float().cuda()
+        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1).cuda()
+        batch["residue_index"] = torch.arange(n_res).cuda()
+        batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim)).cuda()
         t_feats = random_template_feats(n_templ, n_res)
-        batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
+        batch.update({k: torch.tensor(v).cuda() for k, v in t_feats.items()})
         extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
-        batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
+        batch.update({k: torch.tensor(v).cuda() for k, v in extra_feats.items()})
         batch["msa_mask"] = torch.randint(
             low=0, high=2, size=(n_seq, n_res)
-        ).float()
-        batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
+        ).float().cuda()
+        batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float().cuda()
         batch.update(data_transforms.make_atom14_masks(batch))
-        batch["no_recycling_iters"] = torch.tensor(2.)
+        batch["no_recycling_iters"] = torch.tensor(2.).cuda()
 
         add_recycling_dims = lambda t: (
             t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
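This hunk addresses the cuda half of the error: the model is now built with .cuda() and every tensor in the test batch is moved to the GPU, so the dry run no longer mixes CPU and CUDA tensors. As an aside, the same per-tensor move can be done in one place with a dict comprehension; batch_to_device below is a hypothetical helper for illustration, not something the test defines:

import torch

# Hypothetical helper: move every tensor in a feature dict to one device,
# leaving non-tensor entries untouched.
def batch_to_device(batch, device):
    return {
        k: v.to(device) if torch.is_tensor(v) else v
        for k, v in batch.items()
    }

# Usage sketch, assuming `model` and a CPU-resident `batch` already exist:
# device = next(model.parameters()).device
# batch = batch_to_device(batch, device)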
