From 2bc08f775437667b087b76a5870943720c3692c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?arthur=20b=C3=B6=C3=B6k?= Date: Sat, 29 Jul 2023 14:03:37 -0700 Subject: [PATCH] added transformer for knowledge data -> numpy matrix --- dr_claude/datamodels.py | 116 +++++++++++++++++++++++++++++++++++--- dr_claude/mcts/base.py | 44 ++++++++++++++- poetry.lock | 120 +++++++++++++++++++++++++++++++++++++++- pyproject.toml | 1 + 4 files changed, 268 insertions(+), 13 deletions(-) diff --git a/dr_claude/datamodels.py b/dr_claude/datamodels.py index 360dee3..58feba3 100644 --- a/dr_claude/datamodels.py +++ b/dr_claude/datamodels.py @@ -1,33 +1,63 @@ import collections -from typing import Dict, List, Tuple - +import itertools +from typing import ( + Dict, + List, + NamedTuple, + Protocol, + Set, + Tuple, + Type, + TypeVar, + cast, + runtime_checkable, +) + +import numpy as np +import pandas as pd import pydantic -class UMLSMixin: +@runtime_checkable +class HasUMLS(Protocol): """ UMLS Mixin """ umls_code: str - def __hash__(self) -> int: + +HasUMLSClass = TypeVar("HasUMLSClass", bound=HasUMLS) + + +def set_umls_methods(cls: Type[HasUMLSClass]) -> Type[HasUMLSClass]: + def __hash__(self: HasUMLSClass) -> int: return hash(self.umls_code) - def __eq__(self, other: object) -> bool: - return isinstance(other, self.__class__) and self.umls_code == other.umls_code + def __eq__(self: HasUMLSClass, other: object) -> bool: + return isinstance(other, HasUMLS) and self.umls_code == other.umls_code + cls.__hash__ = __hash__ + cls.__eq__ = __eq__ + return cls -class Symptom(UMLSMixin, pydantic.BaseModel): + +@set_umls_methods +class Symptom(pydantic.BaseModel): """ Symptom """ name: str umls_code: str + noise_rate: float = 0.03 + + class Config: + frozen = True -class Condition(UMLSMixin, pydantic.BaseModel): +@set_umls_methods +class Condition(pydantic.BaseModel): """ Condition """ @@ -35,7 +65,11 @@ class Condition(UMLSMixin, pydantic.BaseModel): name: str umls_code: str + class Config: + frozen = True + +@set_umls_methods class WeightedSymptom(Symptom): """ Weight @@ -43,6 +77,70 @@ class WeightedSymptom(Symptom): weight: float # between 0,1 + class Config: + frozen = True + class DiseaseSymptomKnowledgeBase(pydantic.BaseModel): - pairs: Dict[Condition, List[WeightedSymptom]] + condition_symptoms: Dict[Condition, List[WeightedSymptom]] + + +""" +Helper methods +""" + + +class MatrixIndex(NamedTuple): + rows: Dict[Symptom, int] + columns: Dict[Condition, int] + + +class MonotonicCounter: + """ + A counter that increments and returns a new value each time it is called + """ + + def __init__(self, start: int = 0): + self._count = start + + def __call__(self) -> int: + c = self._count + self._count += 1 + return c + + +class SymptomTransformer: + @staticmethod + def to_symptom(symptom: WeightedSymptom) -> Symptom: + return Symptom(**symptom.dict()) + + +class DiseaseSymptomKnowledgeBaseTransformer: + @staticmethod + def to_numpy(kb: DiseaseSymptomKnowledgeBase) -> Tuple[np.ndarray, MatrixIndex]: + """ + Returns a numpy array of the weights of each symptom for each condition + """ + + ## init symptoms + all_symptoms = itertools.chain.from_iterable(kb.condition_symptoms.values()) + symptom_idx: Dict[Symptom, int] = collections.defaultdict(MonotonicCounter()) + [symptom_idx[s] for s in map(SymptomTransformer.to_symptom, all_symptoms)] + + ## init conditions + disease_idx: Dict[Condition, int] = collections.defaultdict(MonotonicCounter()) + [disease_idx[condition] for condition in kb.condition_symptoms.keys()] + + ## the antagonist + probas = np.zeros((len(symptom_idx), len(disease_idx))) + + ## fill noise vals + for symptom, index in symptom_idx.items(): + probas[index, :] = symptom.noise_rate + + ## fill known probas + for condition, symptoms in kb.condition_symptoms.items(): + for symptom in symptoms: + probas[symptom_idx[symptom], disease_idx[condition]] = symptom.weight + + return (probas, MatrixIndex(symptom_idx, disease_idx)) diff --git a/dr_claude/mcts/base.py b/dr_claude/mcts/base.py index 6f9fcf4..e31e0d2 100644 --- a/dr_claude/mcts/base.py +++ b/dr_claude/mcts/base.py @@ -1,3 +1,41 @@ -class MCTS: - def __init__(self) -> None: - pass +from typing import List, Optional +from dr_claude import datamodels +import math +import random + + +if __name__ == "__main__": + db = datamodels.DiseaseSymptomKnowledgeBase( + condition_symptoms={ + datamodels.Condition(name="COVID-19", umls_code="C0000001"): [ + datamodels.WeightedSymptom( + name="Fever", + umls_code="C0000002", + weight=0.5, + noise_rate=0.2, + ), + datamodels.WeightedSymptom( + name="Cough", + umls_code="C0000003", + weight=0.5, + noise_rate=0.1, + ), + ], + datamodels.Condition(name="Common Cold", umls_code="C0000004"): [ + datamodels.WeightedSymptom( + name="Fever", + umls_code="C0000002", + weight=0.5, + noise_rate=0.05, + ), + datamodels.WeightedSymptom( + name="Runny nose", + umls_code="C0000004", + weight=0.5, + noise_rate=0.01, + ), + ], + } + ) + + dataframe, index = datamodels.DiseaseSymptomKnowledgeBaseTransformer.to_numpy(db) diff --git a/poetry.lock b/poetry.lock index 118eead..4b50667 100644 --- a/poetry.lock +++ b/poetry.lock @@ -934,6 +934,73 @@ files = [ {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, ] +[[package]] +name = "pandas" +version = "2.0.3" +description = "Powerful data structures for data analysis, time series, and statistics" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"}, + {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"}, + {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce0c6f76a0f1ba361551f3e6dceaff06bde7514a374aa43e33b588ec10420183"}, + {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba619e410a21d8c387a1ea6e8a0e49bb42216474436245718d7f2e88a2f8d7c0"}, + {file = "pandas-2.0.3-cp310-cp310-win32.whl", hash = "sha256:3ef285093b4fe5058eefd756100a367f27029913760773c8bf1d2d8bebe5d210"}, + {file = "pandas-2.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:9ee1a69328d5c36c98d8e74db06f4ad518a1840e8ccb94a4ba86920986bb617e"}, + {file = "pandas-2.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b084b91d8d66ab19f5bb3256cbd5ea661848338301940e17f4492b2ce0801fe8"}, + {file = "pandas-2.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:37673e3bdf1551b95bf5d4ce372b37770f9529743d2498032439371fc7b7eb26"}, + {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9cb1e14fdb546396b7e1b923ffaeeac24e4cedd14266c3497216dd4448e4f2d"}, + {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9cd88488cceb7635aebb84809d087468eb33551097d600c6dad13602029c2df"}, + {file = "pandas-2.0.3-cp311-cp311-win32.whl", hash = "sha256:694888a81198786f0e164ee3a581df7d505024fbb1f15202fc7db88a71d84ebd"}, + {file = "pandas-2.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:6a21ab5c89dcbd57f78d0ae16630b090eec626360085a4148693def5452d8a6b"}, + {file = "pandas-2.0.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9e4da0d45e7f34c069fe4d522359df7d23badf83abc1d1cef398895822d11061"}, + {file = "pandas-2.0.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:32fca2ee1b0d93dd71d979726b12b61faa06aeb93cf77468776287f41ff8fdc5"}, + {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:258d3624b3ae734490e4d63c430256e716f488c4fcb7c8e9bde2d3aa46c29089"}, + {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eae3dc34fa1aa7772dd3fc60270d13ced7346fcbcfee017d3132ec625e23bb0"}, + {file = "pandas-2.0.3-cp38-cp38-win32.whl", hash = "sha256:f3421a7afb1a43f7e38e82e844e2bca9a6d793d66c1a7f9f0ff39a795bbc5e02"}, + {file = "pandas-2.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:69d7f3884c95da3a31ef82b7618af5710dba95bb885ffab339aad925c3e8ce78"}, + {file = "pandas-2.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5247fb1ba347c1261cbbf0fcfba4a3121fbb4029d95d9ef4dc45406620b25c8b"}, + {file = "pandas-2.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:81af086f4543c9d8bb128328b5d32e9986e0c84d3ee673a2ac6fb57fd14f755e"}, + {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1994c789bf12a7c5098277fb43836ce090f1073858c10f9220998ac74f37c69b"}, + {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ec591c48e29226bcbb316e0c1e9423622bc7a4eaf1ef7c3c9fa1a3981f89641"}, + {file = "pandas-2.0.3-cp39-cp39-win32.whl", hash = "sha256:04dbdbaf2e4d46ca8da896e1805bc04eb85caa9a82e259e8eed00254d5e0c682"}, + {file = "pandas-2.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:1168574b036cd8b93abc746171c9b4f1b83467438a5e45909fed645cf8692dbc"}, + {file = "pandas-2.0.3.tar.gz", hash = "sha256:c02f372a88e0d17f36d3093a644c73cfc1788e876a7c4bcb4020a77512e2043c"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, +] +python-dateutil = ">=2.8.2" +pytz = ">=2020.1" +tzdata = ">=2022.1" + +[package.extras] +all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.08.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"] +aws = ["s3fs (>=2021.08.0)"] +clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"] +compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"] +computation = ["scipy (>=1.7.1)", "xarray (>=0.21.0)"] +excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pyxlsb (>=1.0.8)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)"] +feather = ["pyarrow (>=7.0.0)"] +fss = ["fsspec (>=2021.07.0)"] +gcp = ["gcsfs (>=2021.07.0)", "pandas-gbq (>=0.15.0)"] +hdf5 = ["tables (>=3.6.1)"] +html = ["beautifulsoup4 (>=4.9.3)", "html5lib (>=1.1)", "lxml (>=4.6.3)"] +mysql = ["SQLAlchemy (>=1.4.16)", "pymysql (>=1.0.2)"] +output-formatting = ["jinja2 (>=3.0.0)", "tabulate (>=0.8.9)"] +parquet = ["pyarrow (>=7.0.0)"] +performance = ["bottleneck (>=1.3.2)", "numba (>=0.53.1)", "numexpr (>=2.7.1)"] +plot = ["matplotlib (>=3.6.1)"] +postgresql = ["SQLAlchemy (>=1.4.16)", "psycopg2 (>=2.8.6)"] +spss = ["pyreadstat (>=1.1.2)"] +sql-other = ["SQLAlchemy (>=1.4.16)"] +test = ["hypothesis (>=6.34.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"] +xml = ["lxml (>=4.6.3)"] + [[package]] name = "pydantic" version = "1.10.12" @@ -987,6 +1054,33 @@ typing-extensions = ">=4.2.0" dotenv = ["python-dotenv (>=0.10.4)"] email = ["email-validator (>=1.0.3)"] +[[package]] +name = "python-dateutil" +version = "2.8.2" +description = "Extensions to the standard Python datetime module" +category = "main" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, +] + +[package.dependencies] +six = ">=1.5" + +[[package]] +name = "pytz" +version = "2023.3" +description = "World timezone definitions, modern and historical" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "pytz-2023.3-py2.py3-none-any.whl", hash = "sha256:a151b3abb88eda1d4e34a9814df37de2a80e301e68ba0fd856fb9b46bfbbbffb"}, + {file = "pytz-2023.3.tar.gz", hash = "sha256:1d8ce29db189191fb55338ee6d0387d82ab59f3d00eac103412d64e0ebd0c588"}, +] + [[package]] name = "pyyaml" version = "6.0.1" @@ -1218,6 +1312,18 @@ tensorflow = ["tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "numpy (>=1.21.6)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)"] torch = ["torch (>=1.10)"] +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + [[package]] name = "sqlalchemy" version = "2.0.19" @@ -1541,6 +1647,18 @@ files = [ mypy-extensions = ">=0.3.0" typing-extensions = ">=3.7.4" +[[package]] +name = "tzdata" +version = "2023.3" +description = "Provider of IANA time zone data" +category = "main" +optional = false +python-versions = ">=2" +files = [ + {file = "tzdata-2023.3-py2.py3-none-any.whl", hash = "sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda"}, + {file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"}, +] + [[package]] name = "urllib3" version = "2.0.4" @@ -1650,4 +1768,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "5b75dba208fc76f87292ee8d5431361632c215a2592d44c3f6594f7731fb59c4" +content-hash = "8e11f4d6654776ef8f9bb92232a3cc6570a1014957e46f4a14a860bf0d50daa5" diff --git a/pyproject.toml b/pyproject.toml index 6ebbb9d..c210a63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ transformers = "^4.31.0" langchain = "^0.0.247" pydantic = ">=1,<2" torch = "^2.0.1" +pandas = "^2.0.3" [build-system] requires = ["poetry-core"]