Skip to content

Commit a30f278

Browse files
Steijaert, Marvin [JRDBE Non-J&J]stsouko
authored andcommitted
Pull request #7: Feature/split model instances
Merge in ASX-JFUG/chytorch from feature/split_model_instances to dev * commit 'fa2e1a664500072a8759726bb57a6c7a3f9f0fcb': minor improvements add unittest and make sagemaker-ready create copy of encoder for MC inference
2 parents 2a4d538 + fa2e1a6 commit a30f278

File tree

3 files changed

+124
-12
lines changed

3 files changed

+124
-12
lines changed

chytorch/zoo/autodl/inference_v1/gtpp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,16 @@
2727
from functools import cached_property
2828
from gtpp_client import ComputedPropertyDescription, Table, TableColumn, TableRow, ComputeExJob
2929
from gtpp_client.util import ContextStopwatch
30-
from kserve_gtpp.model_base import ModelBase, L as BASE_LOGGER, BadDataException
3130
from os import getenv
3231
from os.path import abspath
3332
from torch import set_num_threads, no_grad
3433
from torch.utils.data import DataLoader
3534
from typing import List
3635
from ..model import ChIMP, RotaryChIMP
37-
36+
if getenv("GTPP_DEPLOYMENT_FLAVOR", "kserve").lower() == "sagemaker":
37+
from sagemaker_gtpp.model_base import ModelBase, L as BASE_LOGGER, BadDataException
38+
else :
39+
from kserve_gtpp.model_base import ModelBase, L as BASE_LOGGER, BadDataException
3840

3941
prompt_shift = 120
4042
set_num_threads(int(getenv('CPU_COUNT', 6)))
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from concurrent.futures import ThreadPoolExecutor
2+
from copy import deepcopy
3+
import functools
4+
import gtpp_client.util as cu
5+
from gtpp_client.models.compute_ex_job import ComputeExJob
6+
import numpy as np
7+
from os import getenv
8+
import time
9+
from typing import List
10+
import unittest
11+
from warnings import warn
12+
from .gtpp import ChIMPServe
13+
if getenv("GTPP_DEPLOYMENT_FLAVOR", "kserve").lower() == "sagemaker":
14+
from sagemaker_gtpp.model_base import ModelBase
15+
import sagemaker_gtpp.test_base as test_base
16+
from sagemaker_gtpp.test_base import MicroserviceException
17+
import sagemaker_gtpp.test_util as tu
18+
else :
19+
from kserve_gtpp.model_base import ModelBase
20+
import kserve_gtpp.test_base as test_base
21+
from kserve_gtpp.test_base import MicroserviceException
22+
import kserve_gtpp.test_util as tu
23+
24+
class MyTestCase(test_base.TestBase):
25+
26+
def __init__(self, *args, **kwargs):
27+
super().__init__(*args, **kwargs)
28+
29+
def get_model(self) -> ModelBase:
30+
"""
31+
This method MUST be implemented. It should return an instance of the model class.
32+
You may place additional model-specific code here (for example, to initialize
33+
the model in a certain non-standard way). get_model() is executed once for each test,
34+
that is, each test from the test suite uses a new instance of the model class.
35+
:return: an instance of the model class that derives from ModelBase.
36+
"""
37+
model = ChIMPServe()
38+
# ignore uq colums for non-deterministic uncertainty quantification results
39+
self._test_stability = self._test_stability_without_uq_columns
40+
return model
41+
42+
def _test_stability_without_uq_columns(self, codes: List[str]):
43+
# modified variant of TestBase._test_stability() that ignores the
44+
# uncertainty columns
45+
test_smiles = self.STABILITY_TEST_SMILES
46+
result_table = None
47+
print(f'Running n_inner_loop_repeats={self.STABILITY_NUM_INNER_LOOP_REPEATS}')
48+
for i in range(self.STABILITY_NUM_INNER_LOOP_REPEATS):
49+
smiles_chunks = list(tu.gen_random_chunks(test_smiles))
50+
with cu.ContextStopwatch(f'num. chunks={len(smiles_chunks)}'):
51+
# normally, the number of concurrently executing threads
52+
# will be equal to the number of chunks test_smiles list is split into
53+
options = self.STABILITY_COMPUTE_OPTIONS.copy()
54+
with ThreadPoolExecutor(max_workers=self.STABILITY_MAX_THREADS) as executor:
55+
ret = executor.map(functools.partial(self.run_model_compute,
56+
property_codes=codes, options=options), smiles_chunks)
57+
complete_table_original = cu.TableMerger.append_multiple_tables(ret)
58+
complete_table = self._only_keep_every_third_column(complete_table_original)
59+
if result_table is None:
60+
result_table = complete_table
61+
# Note: in all currently deployed models, there should be 1-to-1 mapping between
62+
# the input chemical structures and output table rows.
63+
# This may change in the future (the gtpp API itself has no such restriction).
64+
self.assertEqual(len(complete_table_original.rows), len(test_smiles), 'Num. output rows is not the same as '
65+
'the size of the input list')
66+
print('First iteration done, got sample result table for comparison.')
67+
else:
68+
try:
69+
tu.assert_tables_equal(self, result_table, complete_table)
70+
print(f'Passed with num. chunks={len(smiles_chunks)}')
71+
except Exception:
72+
print(f'FAILED with num. chunks={len(smiles_chunks)}')
73+
raise
74+
print(f'Iteration #{i + 1}: success.')
75+
76+
print(f'_test_stability passed')
77+
pass
78+
79+
def _only_keep_every_third_column(self, table):
80+
table_out = deepcopy(table)
81+
table_out.columns = [c for i,c in enumerate(table.columns) if not i%3]
82+
for r in table_out.rows:
83+
r.values = [x for i,x in enumerate(r.values) if not i%3]
84+
return table_out
85+
86+
def test_uncertainty_columns(self):
87+
test_smiles = ['C[C@H](N)C(=O)O', 'CC(O)C(=O)O', 'OCC(O)CO', 'Oc1ccccc1',
88+
'Nc1ccncc1', 'C#CC(C)(O)CC', '', 'ClCCCl', 'O=C1CCC(=O)N1', 'O=CCCCC=O',
89+
'N[C@@H]1CONC1=O', 'S1C=CSC1=C'] + self.get_10_valid_smiles()
90+
prop_desciptions = self.get_properties()
91+
all_prop_codes = [pd.code for pd in prop_desciptions]
92+
res = self.run_model_compute(smiles_list = test_smiles,
93+
property_codes=all_prop_codes)
94+
95+
# assert that lcl<=main<=ucl for all properties and all test compounds
96+
n_columns = len(res.columns)
97+
pred_main = [np.array([x.values[i] for x in res.rows if None not in x.values]) for i in range(0,n_columns, 3)]
98+
pred_lcl = [np.array([x.values[i] for x in res.rows if None not in x.values]) for i in range(1,n_columns, 3)]
99+
pred_ucl = [np.array([x.values[i] for x in res.rows if None not in x.values]) for i in range(2,n_columns, 3)]
100+
self.assertTrue(all(all(x<=y) for x,y in zip(pred_lcl, pred_main)))
101+
self.assertTrue(all(all(x<=y) for x,y in zip(pred_main, pred_ucl)))
102+
103+
if __name__ == '__main__':
104+
unittest.main()

chytorch/zoo/autodl/model.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
#
2323
from chytorch.nn import (MoleculeRotaryEncoder, MoleculeEncoder, ConditionedMaskedSlicer,
2424
CensoredLoss, MaskedNaNLoss, MultiTaskLoss, LossDispatcher)
25+
from copy import deepcopy
2526
from functools import reduce
2627
from lightning.pytorch import LightningModule
2728
from operator import add
28-
from threading import Lock
2929
from torch import bfloat16, zeros, zeros_like, stack, empty, minimum, maximum, where, ones, tensor
3030
from torch.nn import Linear, BCEWithLogitsLoss, SmoothL1Loss
3131
from torch.nn.functional import embedding
@@ -89,23 +89,29 @@ def __init__(self, d_model: int = 256, max_tokens: int = 10_000,
8989
self.lr_mode = lr_mode
9090
self.betas = betas
9191
self.weight_decay = weight_decay
92+
self.__copied_encoder = None
9293

94+
@property
95+
def _mc_encoder(self):
96+
# creates a copy of the encoder in train mode
97+
if not self.__copied_encoder:
98+
self.__copied_encoder = deepcopy(self.encoder)
99+
for x in self.__copied_encoder.layers:
100+
x.train()
101+
return self.__copied_encoder
102+
93103
def predict(self, batch):
94104
# each mol in minibatch must be with the same prompt size.
95105
# I'm not going to check it here
96106
prompt_size = (batch.atoms[0] > self.prompt_shift).sum()
97107
prompt_batch = (batch.atoms[:, :prompt_size], batch.neighbors[:, :prompt_size], batch.distances[:, :prompt_size])
98108

99109
cache = self.build_cache(batch)
100-
with Lock(): # make sure MC runs are not going in parallel
101-
mid = self.head(self.encoder(batch, cache=cache)[:, :prompt_size]).flatten()
102-
for x in self.encoder.layers:
103-
x.train()
104-
mc = stack([
105-
self.head(self.encoder(prompt_batch, cache=cache, cache_direction='left')).flatten()
106-
for _ in range(self.monte_carlo_runs)
107-
])
108-
self.eval()
110+
mid = self.head(self.encoder(batch, cache=cache)[:, :prompt_size]).flatten()
111+
mc = stack([
112+
self.head(self._mc_encoder(prompt_batch, cache=cache, cache_direction='left')).flatten()
113+
for _ in range(self.monte_carlo_runs)
114+
])
109115

110116
# get quantiles
111117
low = mc.quantile(self.monte_carlo_quantile, dim=0)

0 commit comments

Comments
 (0)