|
| 1 | +from concurrent.futures import ThreadPoolExecutor |
| 2 | +from copy import deepcopy |
| 3 | +import functools |
| 4 | +import gtpp_client.util as cu |
| 5 | +from gtpp_client.models.compute_ex_job import ComputeExJob |
| 6 | +import numpy as np |
| 7 | +from os import getenv |
| 8 | +import time |
| 9 | +from typing import List |
| 10 | +import unittest |
| 11 | +from warnings import warn |
| 12 | +from .gtpp import ChIMPServe |
| 13 | +if getenv("GTPP_DEPLOYMENT_FLAVOR", "kserve").lower() == "sagemaker": |
| 14 | + from sagemaker_gtpp.model_base import ModelBase |
| 15 | + import sagemaker_gtpp.test_base as test_base |
| 16 | + from sagemaker_gtpp.test_base import MicroserviceException |
| 17 | + import sagemaker_gtpp.test_util as tu |
| 18 | +else : |
| 19 | + from kserve_gtpp.model_base import ModelBase |
| 20 | + import kserve_gtpp.test_base as test_base |
| 21 | + from kserve_gtpp.test_base import MicroserviceException |
| 22 | + import kserve_gtpp.test_util as tu |
| 23 | + |
class MyTestCase(test_base.TestBase):
    """Test suite for the ``ChIMPServe`` model.

    The tests in this class rely on the result table storing each property as
    a group of three adjacent columns: the main prediction, followed by its
    lower (lcl) and its upper (ucl) confidence limit.
    """

    # Width of one property group in the result table: main, lcl, ucl.
    _COLUMNS_PER_PROPERTY = 3

    def get_model(self) -> ModelBase:
        """
        This method MUST be implemented. It should return an instance of the model class.
        You may place additional model-specific code here (for example, to initialize
        the model in a certain non-standard way). get_model() is executed once for each test,
        that is, each test from the test suite uses a new instance of the model class.

        Side effect: rebinds ``self._test_stability`` to the variant that
        skips the uncertainty columns.

        :return: an instance of the model class that derives from ModelBase.
        """
        model = ChIMPServe()
        # Ignore the uq columns: the uncertainty quantification results are
        # non-deterministic, so they cannot be compared between runs.
        self._test_stability = self._test_stability_without_uq_columns
        return model

    def _test_stability_without_uq_columns(self, codes: List[str]) -> None:
        """Check that results do not depend on how the input is chunked.

        Modified variant of ``TestBase._test_stability()`` that ignores the
        uncertainty columns. The result of the first iteration serves as the
        reference table; every later iteration must reproduce it.

        :param codes: property codes to request from the model.
        :raises Exception: whatever ``tu.assert_tables_equal`` raises on a
            mismatch; it is re-raised after the chunk count has been printed.
        """
        test_smiles = self.STABILITY_TEST_SMILES
        result_table = None
        print(f'Running n_inner_loop_repeats={self.STABILITY_NUM_INNER_LOOP_REPEATS}')
        for i in range(self.STABILITY_NUM_INNER_LOOP_REPEATS):
            smiles_chunks = list(tu.gen_random_chunks(test_smiles))
            with cu.ContextStopwatch(f'num. chunks={len(smiles_chunks)}'):
                # normally, the number of concurrently executing threads
                # will be equal to the number of chunks test_smiles list is split into
                options = self.STABILITY_COMPUTE_OPTIONS.copy()
                compute_chunk = functools.partial(self.run_model_compute,
                                                  property_codes=codes, options=options)
                with ThreadPoolExecutor(max_workers=self.STABILITY_MAX_THREADS) as executor:
                    ret = executor.map(compute_chunk, smiles_chunks)
                # Leaving the executor context waits for all submitted chunks;
                # the lazy map() iterator is drained by the merger.
                complete_table_original = cu.TableMerger.append_multiple_tables(ret)
            complete_table = self._only_keep_every_third_column(complete_table_original)
            if result_table is None:
                result_table = complete_table
                # Note: in all currently deployed models, there should be 1-to-1 mapping between
                # the input chemical structures and output table rows.
                # This may change in the future (the gtpp API itself has no such restriction).
                self.assertEqual(len(complete_table_original.rows), len(test_smiles),
                                 'Num. output rows is not the same as '
                                 'the size of the input list')
                print('First iteration done, got sample result table for comparison.')
            else:
                try:
                    tu.assert_tables_equal(self, result_table, complete_table)
                    print(f'Passed with num. chunks={len(smiles_chunks)}')
                except Exception:
                    # Report the failing chunk count, then let the failure propagate.
                    print(f'FAILED with num. chunks={len(smiles_chunks)}')
                    raise
            print(f'Iteration #{i + 1}: success.')

        print('_test_stability passed')

    def _only_keep_every_third_column(self, table):
        """Return a deep copy of ``table`` reduced to the main prediction columns.

        Only the first column of every property group (indices 0, 3, 6, ...)
        is kept, both in ``columns`` and in the ``values`` of every row.
        The input table is left unmodified.

        :param table: object with a ``columns`` sequence and a ``rows``
            sequence whose items carry a ``values`` sequence.
        :return: the reduced copy.
        """
        step = self._COLUMNS_PER_PROPERTY
        table_out = deepcopy(table)
        # Filter the copied columns (not the originals) so that the returned
        # table shares no column objects with the input table.
        table_out.columns = [c for i, c in enumerate(table_out.columns) if i % step == 0]
        for row in table_out.rows:
            row.values = [x for i, x in enumerate(row.values) if i % step == 0]
        return table_out

    def test_uncertainty_columns(self) -> None:
        """Assert that lcl <= main <= ucl for all properties and all test compounds.

        Rows containing any ``None`` value are excluded from the comparison
        (the input list deliberately contains an empty SMILES string).
        """
        test_smiles = ['C[C@H](N)C(=O)O', 'CC(O)C(=O)O', 'OCC(O)CO', 'Oc1ccccc1',
                       'Nc1ccncc1', 'C#CC(C)(O)CC', '', 'ClCCCl', 'O=C1CCC(=O)N1', 'O=CCCCC=O',
                       'N[C@@H]1CONC1=O', 'S1C=CSC1=C'] + self.get_10_valid_smiles()
        prop_descriptions = self.get_properties()
        all_prop_codes = [prop.code for prop in prop_descriptions]
        res = self.run_model_compute(smiles_list=test_smiles,
                                     property_codes=all_prop_codes)

        group = self._COLUMNS_PER_PROPERTY
        n_columns = len(res.columns)
        # A partial trailing group would otherwise be skipped without notice.
        self.assertEqual(n_columns % group, 0,
                         f'Expected a multiple of {group} result columns, got {n_columns}')

        complete_rows = [row.values for row in res.rows if None not in row.values]
        # Without this check the comparisons below would pass vacuously.
        self.assertTrue(complete_rows, 'No result row without missing values to check')

        for start in range(0, n_columns, group):
            pred_main = np.array([values[start] for values in complete_rows])
            pred_lcl = np.array([values[start + 1] for values in complete_rows])
            pred_ucl = np.array([values[start + 2] for values in complete_rows])
            # One sub-test per property so that a failure names the column group.
            with self.subTest(first_column_index=start):
                self.assertTrue(np.all(pred_lcl <= pred_main),
                                f'lcl > main in column group starting at index {start}')
                self.assertTrue(np.all(pred_main <= pred_ucl),
                                f'main > ucl in column group starting at index {start}')
| 102 | + |
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments