Skip to content

Commit

Permalink
refactor(download): update output name for the signature alphabet file
Browse files Browse the repository at this point in the history
  • Loading branch information
tduigou committed Aug 4, 2023
1 parent 6ceb79a commit a2ef08c
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/paper/dataset/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def filter(smi, radius, verbose=False):
fdataset_train = os.path.join(args.output_directory_str, "dataset.train")
fdataset_valid = os.path.join(args.output_directory_str, "dataset.valid")
fdataset_test = os.path.join(args.output_directory_str, "dataset.test")
falphabet = os.path.join(args.output_directory_str, "alphabet.npz")
falphabet = os.path.join(args.output_directory_str, "sig_alphabet.npz")

# Create output directory
if not os.path.isdir(args.output_directory_str):
Expand Down Expand Up @@ -184,16 +184,17 @@ def filter(smi, radius, verbose=False):
):
H, D = read_csv(fdataset)
np.random.shuffle(D)

Smiles = np.asarray(list(set(D[:, 1])))

total_size = D.shape[0]
valid_size = round(args.parameters_valid_percent_float * total_size / 100.0)
test_size = round(args.parameters_test_percent_float * total_size / 100.0)
train_size = total_size - valid_size - test_size
print("Total size:", total_size, "Train size:", train_size, "Valid size:", valid_size, "Test size:", test_size)
train_data = D[:train_size]
valid_data = D[train_size: train_size+valid_size]
test_data = D[train_size+valid_size:]
train_data = D[: train_size]
valid_data = D[train_size: train_size + valid_size]
test_data = D[train_size + valid_size:]
print(D.shape[0], train_data.shape[0], valid_data.shape[0], test_data.shape[0])
assert train_data.shape[0] + valid_data.shape[0] + test_data.shape[0] == D.shape[0]
assert train_data.shape[0] == train_size
Expand All @@ -208,10 +209,8 @@ def filter(smi, radius, verbose=False):
df_test.to_csv(fdataset_test+".csv", index=False)

# Alphabet Signature
print("Build Alphabet")
print("Build Signature alphabet")
df = pd.read_csv(fdataset+".csv")
# Smiles = np.asarray(list(set(D[:, 0])))
# print(Smiles)
Alphabet = SignatureAlphabet(
radius=args.parameters_radius_int, nBits=0, neighbors=False, allHsExplicit=False
)
Expand Down

0 comments on commit a2ef08c

Please sign in to comment.