Skip to content

Commit

Permalink
Merge pull request #19 from fadynakhla/wian/fixing
Browse files Browse the repository at this point in the history
some fixes
  • Loading branch information
fadynakhla authored Jul 30, 2023
2 parents 2a8c42f + b8e012a commit 26ec9eb
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 10 deletions.
Empty file.
Empty file added dr_claude/comparisons/claude.py
Empty file.
21 changes: 17 additions & 4 deletions dr_claude/datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,17 @@ def __hash__(self: HasUMLSClass) -> int:
return hash(self.umls_code)

def __eq__(self: HasUMLSClass, other: object) -> bool:
return isinstance(other, HasUMLS) and self.umls_code == other.umls_code
return (
isinstance(other, HasUMLS)
and self.umls_code == other.umls_code
and self.name
)

cls.__hash__ = __hash__
cls.__eq__ = __eq__
return cls


@set_umls_methods
class Symptom(pydantic.BaseModel):
"""
Symptom
Expand All @@ -53,6 +56,9 @@ class Symptom(pydantic.BaseModel):
class Config:
frozen = True

def __eq__(self, other):
return self.name == other.name and self.umls_code == other.umls_code


@set_umls_methods
class Condition(pydantic.BaseModel):
Expand All @@ -70,7 +76,6 @@ class Config:
frozen = True


@set_umls_methods
class WeightedSymptom(Symptom):
"""
Weight
Expand All @@ -81,6 +86,9 @@ class WeightedSymptom(Symptom):
class Config:
frozen = True

def __eq__(self, other):
return self.name == other.name and self.umls_code == other.umls_code


class DiseaseSymptomKnowledgeBase(pydantic.BaseModel):
condition_symptoms: Dict[Condition, List[WeightedSymptom]]
Expand Down Expand Up @@ -151,6 +159,8 @@ def to_numpy(kb: DiseaseSymptomKnowledgeBase) -> ProbabilityMatrix:

## the antagonist
probas = np.zeros((len(symptom_idx), len(disease_idx)))
symptom_idx = dict(symptom_idx)
disease_idx = dict(disease_idx)

## fill noise vals
for symptom, index in symptom_idx.items():
Expand All @@ -159,7 +169,10 @@ def to_numpy(kb: DiseaseSymptomKnowledgeBase) -> ProbabilityMatrix:
## fill known probas
for condition, symptoms in kb.condition_symptoms.items():
for symptom in symptoms:
probas[symptom_idx[symptom], disease_idx[condition]] = symptom.weight
probas[
symptom_idx[SymptomTransformer.to_symptom(symptom)],
disease_idx[condition],
] = symptom.weight

return ProbabilityMatrix(
matrix=probas,
Expand Down
18 changes: 12 additions & 6 deletions tests/test_mcts.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import copy
import random
from typing import Callable, Collection, Self
from typing import Callable, Collection

import loguru
from loguru import logger

from dr_claude import datamodels
from dr_claude.mcts import action_states
Expand Down Expand Up @@ -58,10 +59,10 @@ def log_wrapper(state: ConvergenceTestState) -> float:

def test_convergence():
## load the knowledge base
reader = kb_reading.NYPHKnowldegeBaseReader("data/NYPHKnowldegeBase.html")
# reader = kb_reading.NYPHKnowldegeBaseReader("data/NYPHKnowldegeBase.html")
reader = kb_reading.CSVKnowledgeBaseReader("data/ClaudeKnowledgeBase.csv")
kb = reader.load_knowledge_base()
matrix = datamodels.DiseaseSymptomKnowledgeBaseTransformer.to_numpy(kb)

## create the initial state
conditions = list(matrix.columns.keys())
the_condition = conditions[0] # TODO chose condiiton here
Expand All @@ -71,23 +72,28 @@ def test_convergence():
if matrix[symptom, the_condition] > symptom.noise_rate
]
state = ConvergenceTestState(matrix, discount_rate=1e-9)
# state = action_states.SimulationNextActionState(matrix, discount_rate=1e-9)

state.set_condition(the_condition, the_symptoms)
state.pertinent_pos.update(random.choices(the_symptoms, k=2))
state.pertinent_pos.update(random.choices(the_symptoms, k=1))

## Rollout policy
rollout_policy = action_states.RandomRollOutPolicy()
rollout_policy = action_states.ArgMaxDiagnosisRolloutPolicy()
rollout_policy = logtrueconditionhook(rollout_policy)

## create the initial state
searcher = mcts.mcts(timeLimit=3000, rolloutPolicy=rollout_policy)

action = None
while not isinstance(action, datamodels.Condition):
if action is not None:
assert isinstance(action, datamodels.Symptom)
action = searcher.search(initialState=state)
assert isinstance(action, datamodels.Symptom)
if action in the_symptoms:
logger.info("got a pertinent positive: {}", action)
state.pertinent_pos.add(action)
else:
logger.info("got a pertinent negative: {}", action)
state.pertinent_neg.add(action)
loguru.logger.info(f"action={action}")
diagnosis = action
Expand Down

0 comments on commit 26ec9eb

Please sign in to comment.