feat(paper): download, add FP count and extract test_small
guillaume-gricourt committed Aug 22, 2023
1 parent 60da149 commit 9987ee8
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions src/paper/dataset/download.py
@@ -38,7 +38,7 @@ def sanitize(
             continue # already there
         if len(smi) > int(max_molecular_weight / 5): # Cheap skip
             continue
-        mol, smi = SanitizeMolecule(Chem.MolFromSmiles(smi))
+        mol, smi = SanitizeMolecule(Chem.MolFromSmiles(smi)) #, formalCharge=True)
         if mol is None:
             continue
         if Chem.Descriptors.ExactMolWt(mol) > max_molecular_weight:
@@ -56,12 +56,12 @@ def sanitize(
 # Compute signature in various formats
 def filter(smi, radius, verbose=False):
     if "." in smi: #
-        return "", "", "", "", None, "", ""
+        return "", "", "", "", None, "", "", ""
     if "*" in smi: #
-        return "", "", "", "", None, "", ""
+        return "", "", "", "", None, "", "", ""
     if "[" in smi: # cannot process [*] without kekulization
         if "@" not in smi:
-            return "", "", "", "", None, "", ""
+            return "", "", "", "", None, "", "", ""
     smiles = smi
     Alphabet = SignatureAlphabet(neighbors=False, radius=radius, nBits=0)
     sig1, mol, smi = SignatureFromSmiles(smi, Alphabet, verbose=False)
@@ -72,12 +72,13 @@ def filter(smi, radius, verbose=False):
     Alphabet = SignatureAlphabet(neighbors=True, radius=radius, nBits=2048)
     sig4, mol, smi = SignatureFromSmiles(smi, Alphabet, verbose=False)
     if sig1 == "" or sig2 == "" or sig3 == "" or sig4 == "":
-        return "", "", "", "", None, "", ""
+        return "", "", "", "", None, "", "", ""

     mol = AllChem.MolFromSmiles(smiles)
     fpgen = AllChem.GetMorganGenerator(radius=radius, fpSize=2048)
     fp = fpgen.GetFingerprint(mol) # returns a bit vector (value 1 or 0)
-    return sig1, sig2, sig3, sig4, mol, smi, "".join([str(x) for x in fp.ToList()])
+    fp_count = fpgen.GetCountFingerprint(mol)
+    return sig1, sig2, sig3, sig4, mol, smi, "".join([str(x) for x in fp.ToList()]), "".join([str(x) for x in fp_count.ToList()])


 if __name__ == "__main__":
@@ -130,6 +131,7 @@ def filter(smi, radius, verbose=False):
     fdataset_train = os.path.join(args.output_directory_str, "dataset.train")
     fdataset_valid = os.path.join(args.output_directory_str, "dataset.valid")
     fdataset_test = os.path.join(args.output_directory_str, "dataset.test")
+    fdataset_test_small = os.path.join(args.output_directory_str, "dataset.test.small")
     falphabet = os.path.join(args.output_directory_str, "sig_alphabet.npz")

     # Create output directory
@@ -177,10 +179,10 @@ def filter(smi, radius, verbose=False):
     print(f"Number of smiles: {len(Smiles)}")

     # Get to business
-    H = ["SMILES", "SIG", "SIG-NEIGH", "SIG-NBIT", "SIG-NEIGH-NBIT", "ECFP4"]
+    H = ["SMILES", "SIG", "SIG-NEIGH", "SIG-NBIT", "SIG-NEIGH-NBIT", "ECFP4", "ECFP4_COUNT"]
     D, i = {}, 0
     for I in range(len(Smiles)):
-        sig1, sig2, sig3, sig4, mol, smi, fp = filter(
+        sig1, sig2, sig3, sig4, mol, smi, fp, fp_count = filter(
             Smiles[i], radius=args.parameters_radius_int
         )
         # TD WARNING: some smiles are filtered out but it should objectively not be the case
@@ -198,7 +200,7 @@ def filter(smi, radius, verbose=False):
             print(Smiles[i])
             i += 1
             continue
-        D[I] = [smi, sig1, sig2, sig3, sig4, fp]
+        D[I] = [smi, sig1, sig2, sig3, sig4, fp, fp_count]
         i, I = i + 1, I + 1
         if I == args.parameters_max_dataset_size_int:
             break
@@ -212,6 +214,7 @@ def filter(smi, radius, verbose=False):
         not os.path.isfile(fdataset_train + ".csv")
         or not os.path.isfile(fdataset_valid + ".csv")
         or not os.path.isfile(fdataset_test + ".csv")
+        or not os.path.isfile(fdataset_test_small + ".csv")
     ):
         H, D = read_csv(fdataset)
         np.random.shuffle(D)
@@ -235,20 +238,24 @@ def filter(smi, radius, verbose=False):
         train_data = D[:train_size]
         valid_data = D[train_size: train_size + valid_size]
         test_data = D[train_size + valid_size:]
+        test_small_data = D[train_size + valid_size: train_size + valid_size + 1000]
         print(D.shape[0], train_data.shape[0], valid_data.shape[0], test_data.shape[0])
         assert (
             train_data.shape[0] + valid_data.shape[0] + test_data.shape[0] == D.shape[0]
         )
         assert train_data.shape[0] == train_size
         assert valid_data.shape[0] == valid_size
         assert test_data.shape[0] == test_size
+        assert test_small_data.shape[0] == 1000

         df_train = pd.DataFrame(data=train_data, columns=H)
         df_train.to_csv(fdataset_train + ".csv", index=False)
         df_valid = pd.DataFrame(data=valid_data, columns=H)
         df_valid.to_csv(fdataset_valid + ".csv", index=False)
         df_test = pd.DataFrame(data=test_data, columns=H)
         df_test.to_csv(fdataset_test + ".csv", index=False)
+        df_test_small = pd.DataFrame(data=test_small_data, columns=H)
+        df_test_small.to_csv(fdataset_test_small + ".csv", index=False)

     # Alphabet Signature
     print("Build Signature alphabet")
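For reference, the new ECFP4_COUNT column stores the same Morgan fingerprint as ECFP4 but with per-slot occurrence counts instead of 0/1 values, serialized by joining the list returned by ToList() into one string. Below is a minimal standalone sketch of that serialization; it is not part of the commit, and the example molecule and radius=2 are arbitrary:

    from rdkit import Chem
    from rdkit.Chem import AllChem

    smi = "CCO"  # arbitrary example molecule, not taken from the dataset
    mol = Chem.MolFromSmiles(smi)

    fpgen = AllChem.GetMorganGenerator(radius=2, fpSize=2048)
    fp = fpgen.GetFingerprint(mol)             # bit vector: each slot is 0 or 1
    fp_count = fpgen.GetCountFingerprint(mol)  # count vector: each slot is an occurrence count

    # Same serialization as filter(): concatenate one value per slot, no separator
    ecfp4 = "".join(str(x) for x in fp.ToList())
    ecfp4_count = "".join(str(x) for x in fp_count.ToList())

    print("bits set:", fp.GetNumOnBits(), "total count:", fp_count.GetTotalVal())

One consequence of this serialization: a count greater than 9 contributes more than one character to the joined string, so ECFP4_COUNT entries are not guaranteed to be exactly fpSize characters long, unlike ECFP4.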

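The new dataset.test.small split can be read back to check the added column and the 1000-row size asserted above. A usage sketch under the assumptions visible in the diff ("output" is a placeholder for args.output_directory_str; dtype=str keeps the fingerprint strings from being parsed as numbers):

    import pandas as pd

    # "output" stands in for args.output_directory_str
    df_small = pd.read_csv("output/dataset.test.small.csv", dtype=str)

    assert list(df_small.columns) == [
        "SMILES", "SIG", "SIG-NEIGH", "SIG-NBIT", "SIG-NEIGH-NBIT", "ECFP4", "ECFP4_COUNT"
    ]
    assert len(df_small) == 1000  # first 1000 rows of the shuffled test split

    # ECFP4 is a fixed-width 0/1 string; recover the bit vector of the first row
    bits = [int(c) for c in df_small.loc[0, "ECFP4"]]
    print(len(bits), sum(bits))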