Skip to content

Commit

Permalink
Merge pull request #3 from fadynakhla/arthur/numpy
Browse files Browse the repository at this point in the history
added transformer for knowledge data -> numpy matrix
  • Loading branch information
ArthurBook authored Jul 29, 2023
2 parents e7cde67 + 3ca2ba0 commit f658533
Show file tree
Hide file tree
Showing 3 changed files with 281 additions and 17 deletions.
112 changes: 104 additions & 8 deletions dr_claude/datamodels.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,60 @@
from typing import Dict, List

import collections
import itertools
from typing import (
Dict,
List,
NamedTuple,
Protocol,
Tuple,
Type,
TypeVar,
runtime_checkable,
)

import numpy as np
import pydantic


class UMLSMixin:
@runtime_checkable
class HasUMLS(Protocol):
    """
    Structural type for any object that carries a UMLS code.

    The protocol is runtime-checkable, so ``isinstance(obj, HasUMLS)`` is
    true for every object exposing a ``umls_code`` attribute; the check
    looks at attribute presence only, not at the attribute's type.
    """

    # UMLS concept identifier of the object.
    umls_code: str

def __hash__(self) -> int:

HasUMLSClass = TypeVar("HasUMLSClass", bound=HasUMLS)


def set_umls_methods(cls: Type[HasUMLSClass]) -> Type[HasUMLSClass]:
    """
    Class decorator that makes instances hash and compare by ``umls_code``.

    ``__hash__`` and ``__eq__`` are assigned directly onto ``cls`` (the class
    is modified in place) and the same class object is returned, so the
    function can be applied as ``@set_umls_methods``.

    Args:
        cls: Class to patch; its instances must expose ``umls_code``.

    Returns:
        ``cls`` itself, now hashing and comparing on ``umls_code``.
    """

    def __hash__(self: HasUMLSClass) -> int:
        return hash(self.umls_code)

    def __eq__(self: HasUMLSClass, other: object) -> bool:
        # The check is against the HasUMLS protocol, not ``self.__class__``,
        # so any two objects exposing the same ``umls_code`` compare equal
        # regardless of their concrete classes.
        # NOTE(review): this also makes objects of unrelated model types
        # equal when their codes coincide -- confirm that is intended.
        return isinstance(other, HasUMLS) and self.umls_code == other.umls_code

    cls.__hash__ = __hash__
    cls.__eq__ = __eq__
    return cls


class Symptom(UMLSMixin, pydantic.BaseModel):
@set_umls_methods
class Symptom(pydantic.BaseModel):
    """
    A symptom identified by its UMLS code.

    The ``set_umls_methods`` decorator makes instances hash and compare on
    ``umls_code`` alone, so ``name`` and ``noise_rate`` do not take part in
    equality.
    """

    # Human-readable label.
    name: str
    # UMLS concept identifier; the only field used for hashing and equality.
    umls_code: str
    # Value used for this symptom when no explicit weight is supplied;
    # defaults to 0.03.
    noise_rate: float = 0.03

    class Config:
        # Pydantic setting that makes instances immutable after creation.
        frozen = True

class Condition(UMLSMixin, pydantic.BaseModel):

@set_umls_methods
class Condition(pydantic.BaseModel):
"""
Condition
"""
Expand All @@ -37,14 +65,82 @@ def __hash__(self) -> int:
name: str
umls_code: str

class Config:
frozen = True


@set_umls_methods
class WeightedSymptom(Symptom):
    """
    A Symptom paired with a weight.

    The decorator is applied again on this subclass so that its instances
    also hash and compare on ``umls_code``.
    """

    # Documented as lying in [0, 1]; the range is not validated here.
    weight: float

    class Config:
        # Pydantic setting that makes instances immutable after creation.
        frozen = True


class DiseaseSymptomKnowledgeBase(pydantic.BaseModel):
    """
    Knowledge base relating conditions to their weighted symptoms.
    """

    # Maps each condition to the list of weighted symptoms recorded for it.
    condition_symptoms: Dict[Condition, List[WeightedSymptom]]


"""
Helper methods
"""


class MatrixIndex(NamedTuple):
    """
    Lookup tables giving the matrix position of each symptom and condition.
    """

    # Symptom -> row number.
    rows: Dict[Symptom, int]
    # Condition -> column number.
    columns: Dict[Condition, int]


class MonotonicCounter:
    """
    Callable integer sequence.

    Every call returns the current value and then advances the internal
    state by one, so successive calls yield ``start``, ``start + 1``, ...
    """

    def __init__(self, start: int = 0):
        # Next value to hand out.
        self._count = start

    def __call__(self) -> int:
        current, self._count = self._count, self._count + 1
        return current


class SymptomTransformer:
    """
    Conversions between the symptom model types.
    """

    @staticmethod
    def to_symptom(symptom: WeightedSymptom) -> Symptom:
        """
        Build a plain Symptom from the field values of a WeightedSymptom.

        Args:
            symptom: The weighted symptom to convert.

        Returns:
            A new Symptom constructed from ``symptom.dict()``.
        """
        # The dumped fields also include ``weight``, which Symptom does not
        # declare -- presumably discarded by pydantic's default handling of
        # extra fields; confirm if the model Config ever changes.
        field_values = symptom.dict()
        return Symptom(**field_values)


class DiseaseSymptomKnowledgeBaseTransformer:
    """
    Converters from a DiseaseSymptomKnowledgeBase to array form.
    """

    @staticmethod
    def to_numpy(kb: DiseaseSymptomKnowledgeBase) -> Tuple[np.ndarray, MatrixIndex]:
        """
        Build the symptom-by-condition weight matrix of a knowledge base.

        Rows correspond to distinct symptoms and columns to conditions, both
        numbered in order of first appearance in ``kb.condition_symptoms``.
        Every cell starts at its row symptom's ``noise_rate`` and is then
        overwritten with the symptom's ``weight`` for each (symptom,
        condition) pair listed in the knowledge base.

        Args:
            kb: The knowledge base to convert.

        Returns:
            A tuple ``(probas, index)`` where ``probas`` has shape
            ``(number of symptoms, number of conditions)`` and ``index``
            maps each Symptom / Condition to its row / column. The index
            dicts are plain dicts, so looking up an unknown key raises
            ``KeyError`` instead of silently allocating a position that
            does not exist in the matrix.
        """
        # Number each distinct symptom on first sight. When the same symptom
        # is listed under several conditions the first occurrence is the one
        # kept as key, so its noise_rate is the one used below.
        symptom_idx: Dict[Symptom, int] = {}
        for weighted in itertools.chain.from_iterable(kb.condition_symptoms.values()):
            symptom_idx.setdefault(
                SymptomTransformer.to_symptom(weighted), len(symptom_idx)
            )

        # Dict keys are already unique, so a plain enumeration suffices.
        disease_idx: Dict[Condition, int] = {
            condition: column for column, condition in enumerate(kb.condition_symptoms)
        }

        probas = np.zeros((len(symptom_idx), len(disease_idx)))

        # Baseline: each symptom's noise rate across all conditions.
        for symptom, row in symptom_idx.items():
            probas[row, :] = symptom.noise_rate

        # Known associations override the baseline. The lookup uses the
        # WeightedSymptom directly; it finds the Symptom key because both
        # hash and compare on umls_code.
        for condition, symptoms in kb.condition_symptoms.items():
            column = disease_idx[condition]
            for symptom in symptoms:
                probas[symptom_idx[symptom], column] = symptom.weight

        return (probas, MatrixIndex(symptom_idx, disease_idx))
44 changes: 41 additions & 3 deletions dr_claude/mcts/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,41 @@
class MCTS:
    """
    Placeholder class; no behaviour is implemented yet.
    """

    def __init__(self) -> None:
        """Create an instance without setting any state."""
from typing import List, Optional
from dr_claude import datamodels
import math
import random


if __name__ == "__main__":
    # Manual smoke test: build a two-condition knowledge base and convert it
    # to its matrix form. All UMLS codes below are dummy placeholders.
    # NOTE(review): "Fever" is listed with a different noise_rate under each
    # condition; to_numpy keys symptoms by UMLS code, so only the first
    # occurrence's noise_rate ends up in the matrix -- confirm intended.
    db = datamodels.DiseaseSymptomKnowledgeBase(
        condition_symptoms={
            datamodels.Condition(name="COVID-19", umls_code="C0000001"): [
                datamodels.WeightedSymptom(
                    name="Fever",
                    umls_code="C0000002",
                    weight=0.5,
                    noise_rate=0.2,
                ),
                datamodels.WeightedSymptom(
                    name="Cough",
                    umls_code="C0000003",
                    weight=0.5,
                    noise_rate=0.1,
                ),
            ],
            datamodels.Condition(name="Common Cold", umls_code="C0000004"): [
                datamodels.WeightedSymptom(
                    name="Fever",
                    umls_code="C0000002",
                    weight=0.5,
                    noise_rate=0.05,
                ),
                datamodels.WeightedSymptom(
                    name="Runny nose",
                    # Distinct from the "Common Cold" code: models compare
                    # equal by UMLS code, so sharing one would make this
                    # symptom equal to that condition.
                    umls_code="C0000005",
                    weight=0.5,
                    noise_rate=0.01,
                ),
            ],
        }
    )

    # to_numpy returns a NumPy array plus its row/column index.
    probas, index = datamodels.DiseaseSymptomKnowledgeBaseTransformer.to_numpy(db)
Loading

0 comments on commit f658533

Please sign in to comment.