Skip to content

Commit

Permalink
Make seeding more consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
gahdritz authored Aug 9, 2023
1 parent 410e182 commit 39d0ef4
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions openfold/data/data_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import itertools
from functools import reduce, wraps
from operator import add
import random

import numpy as np
import torch
Expand Down Expand Up @@ -183,12 +184,11 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
@curry1
def sample_msa(protein, max_seq, keep_extra, seed=None):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
if(seed is None):
seed = random.randint(0, 2147483647)
num_seq = protein["msa"].shape[0]
g = torch.Generator(device=protein["msa"].device)
if seed is not None:
g.manual_seed(seed)
else:
g.seed()
g.manual_seed(seed)
shuffled = torch.randperm(num_seq - 1, generator=g) + 1
index_order = torch.cat(
(torch.tensor([0], device=shuffled.device), shuffled),
Expand Down

0 comments on commit 39d0ef4

Please sign in to comment.