-
Notifications
You must be signed in to change notification settings - Fork 149
/
classification_inference.py
112 lines (97 loc) · 5.14 KB
/
classification_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import os
import pandas as pd
import torch
from dgllife.data import UnlabeledSMILES
from dgllife.utils import MolToBigraph
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import mkdir_p, collate_molgraphs_unlabeled, load_model, predict, init_featurizer
def main(args):
    """Predict labels for a list of SMILES strings and write them to a CSV file.

    The function builds a graph dataset from ``args['smiles']``, restores the
    model weights stored in ``<train_result_path>/model.pth``, runs the model
    batch by batch without gradient tracking and finally writes
    ``<inference_result_path>/prediction.csv``. The CSV has a
    ``canonical_smiles`` column followed by one column per task.

    Parameters
    ----------
    args : dict
        Configuration. The keys read directly here are ``node_featurizer``,
        ``edge_featurizer``, ``smiles``, ``batch_size``, ``num_workers``,
        ``device``, ``train_result_path``, ``soft_classification``,
        ``task_names``, ``n_tasks`` (only when ``task_names`` is None) and
        ``inference_result_path``. ``load_model`` and ``predict`` receive the
        whole dict and may read further keys.

        ``task_names`` may be None, a comma-separated string, or an already
        split sequence of names. As a side effect, ``args['task_names']`` is
        replaced by the resolved list of names.
    """
    mol_to_g = MolToBigraph(add_self_loop=True,
                            node_featurizer=args['node_featurizer'],
                            edge_featurizer=args['edge_featurizer'])
    dataset = UnlabeledSMILES(args['smiles'], mol_to_graph=mol_to_g)
    dataloader = DataLoader(dataset, batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs_unlabeled,
                            num_workers=args['num_workers'])

    model = load_model(args).to(args['device'])
    # NOTE: torch.load unpickles the file, so only checkpoints from a trusted
    # source should be loaded here.
    checkpoint = torch.load(os.path.join(args['train_result_path'], 'model.pth'),
                            map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    smiles_list = []
    predictions = []

    with torch.no_grad():
        for batch_smiles, bg in tqdm(dataloader, desc="Iteration"):
            smiles_list.extend(batch_smiles)
            batch_pred = predict(args, model, bg)
            if not args['soft_classification']:
                # NOTE(review): 0.5 is a probability cut-off. If `predict`
                # returns raw logits (not visible from this file), a sigmoid
                # or a 0.0 threshold would be needed -- confirm against utils.
                batch_pred = (batch_pred >= 0.5).float()
            predictions.append(batch_pred.detach().cpu())

    # Resolve the column names. Accepting an already split sequence keeps the
    # function callable more than once with the same ``args`` dict.
    task_names = args['task_names']
    if task_names is None:
        task_names = ['task_{:d}'.format(t) for t in range(1, args['n_tasks'] + 1)]
    elif isinstance(task_names, str):
        task_names = task_names.split(',')
    else:
        task_names = list(task_names)
    args['task_names'] = task_names

    if predictions:
        predictions = torch.cat(predictions, dim=0)
    else:
        # torch.cat rejects an empty list; when the loader produced no batch,
        # fall back to a (0, n_tasks) tensor so a header-only CSV is written.
        predictions = torch.zeros((0, len(task_names)))

    output_data = {'canonical_smiles': smiles_list}
    for task_id, task_name in enumerate(task_names):
        output_data[task_name] = predictions[:, task_id]

    df = pd.DataFrame(output_data)
    df.to_csv(os.path.join(args['inference_result_path'], 'prediction.csv'), index=False)
if __name__ == '__main__':
    # Command-line entry point: parse the options, merge in the configuration
    # saved during training, read the SMILES strings to predict on, then hand
    # everything to main().
    from argparse import ArgumentParser

    parser = ArgumentParser('Inference for Multi-label Binary Classification')
    parser.add_argument('-f', '--file-path', type=str, required=True,
                        help='Path to a .csv/.txt file of SMILES strings')
    parser.add_argument('-sc', '--smiles-column', type=str,
                        help='Header for the SMILES column in the CSV file, can be '
                             'omitted if the input file is a .txt file or the .csv '
                             'file only has one column of SMILES strings')
    parser.add_argument('-tp', '--train-result-path', type=str, default='classification_results',
                        help='Path to the saved training results, which will be used for '
                             'loading the trained model and related configurations')
    parser.add_argument('-ip', '--inference-result-path', type=str,
                        default='classification_inference_results',
                        help='Path to save the inference results')
    # The fallback names below match what main() generates ('task_{:d}').
    parser.add_argument('-t', '--task-names', default=None, type=str,
                        help='Task names for saving model predictions in the CSV file to output, '
                             'which should be the same as the ones used for training. If not '
                             'specified, we will simply use task_1, task_2, ...')
    parser.add_argument('-s', '--soft-classification', action='store_true', default=False,
                        help='By default we will perform hard classification with binary labels. '
                             'This flag allows performing soft classification instead.')
    parser.add_argument('-nw', '--num-workers', type=int, default=1,
                        help='Number of processes for data loading (default: 1)')
    args = vars(parser.parse_args())

    # Check the training directory before anything is read from it, so a wrong
    # path is reported with a clear message instead of a failure on open().
    if not os.path.exists(args['train_result_path']):
        raise FileNotFoundError('The path to the saved training results does not exist.')

    # Load configuration. Keys present in configure.json overwrite
    # command-line values of the same name.
    with open(os.path.join(args['train_result_path'], 'configure.json'), 'r') as f:
        args.update(json.load(f))

    args['device'] = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    file_path = args['file_path']
    if file_path.endswith(('.csv', '.csv.gz')):
        df = pd.read_csv(file_path)
        if args['smiles_column'] is not None:
            smiles = df[args['smiles_column']].tolist()
        else:
            # Raise explicitly rather than assert: asserts vanish under -O.
            if len(df.columns) != 1:
                raise ValueError('The CSV file has more than 1 columns and '
                                 '-sc (smiles-column) needs to be specified.')
            smiles = df[df.columns[0]].tolist()
    elif file_path.endswith('.txt'):
        from dgllife.utils import load_smiles_from_txt
        smiles = load_smiles_from_txt(file_path)
    else:
        raise ValueError('Expect the input data file to be a .csv or a .txt file, '
                         'got {}'.format(file_path))
    args['smiles'] = smiles
    args = init_featurizer(args)

    # Handle directories
    mkdir_p(args['inference_result_path'])

    main(args)