Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

some fixes #19

Merged
merged 1 commit into from
Jul 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
Empty file added dr_claude/comparisons/claude.py
Empty file.
21 changes: 17 additions & 4 deletions dr_claude/datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,17 @@ def __hash__(self: HasUMLSClass) -> int:
return hash(self.umls_code)

def __eq__(self: HasUMLSClass, other: object) -> bool:
return isinstance(other, HasUMLS) and self.umls_code == other.umls_code
return (
isinstance(other, HasUMLS)
and self.umls_code == other.umls_code
and self.name
)

cls.__hash__ = __hash__
cls.__eq__ = __eq__
return cls


@set_umls_methods
class Symptom(pydantic.BaseModel):
"""
Symptom
Expand All @@ -53,6 +56,9 @@ class Symptom(pydantic.BaseModel):
class Config:
frozen = True

def __eq__(self, other):
return self.name == other.name and self.umls_code == other.umls_code


@set_umls_methods
class Condition(pydantic.BaseModel):
Expand All @@ -70,7 +76,6 @@ class Config:
frozen = True


@set_umls_methods
class WeightedSymptom(Symptom):
"""
Weight
Expand All @@ -81,6 +86,9 @@ class WeightedSymptom(Symptom):
class Config:
frozen = True

def __eq__(self, other):
return self.name == other.name and self.umls_code == other.umls_code


class DiseaseSymptomKnowledgeBase(pydantic.BaseModel):
condition_symptoms: Dict[Condition, List[WeightedSymptom]]
Expand Down Expand Up @@ -151,6 +159,8 @@ def to_numpy(kb: DiseaseSymptomKnowledgeBase) -> ProbabilityMatrix:

## the antagonist
probas = np.zeros((len(symptom_idx), len(disease_idx)))
symptom_idx = dict(symptom_idx)
disease_idx = dict(disease_idx)

## fill noise vals
for symptom, index in symptom_idx.items():
Expand All @@ -159,7 +169,10 @@ def to_numpy(kb: DiseaseSymptomKnowledgeBase) -> ProbabilityMatrix:
## fill known probas
for condition, symptoms in kb.condition_symptoms.items():
for symptom in symptoms:
probas[symptom_idx[symptom], disease_idx[condition]] = symptom.weight
probas[
symptom_idx[SymptomTransformer.to_symptom(symptom)],
disease_idx[condition],
] = symptom.weight

return ProbabilityMatrix(
matrix=probas,
Expand Down
18 changes: 12 additions & 6 deletions tests/test_mcts.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import copy
import random
from typing import Callable, Collection, Self
from typing import Callable, Collection

import loguru
from loguru import logger

from dr_claude import datamodels
from dr_claude.mcts import action_states
Expand Down Expand Up @@ -58,10 +59,10 @@ def log_wrapper(state: ConvergenceTestState) -> float:

def test_convergence():
## load the knowledge base
reader = kb_reading.NYPHKnowldegeBaseReader("data/NYPHKnowldegeBase.html")
# reader = kb_reading.NYPHKnowldegeBaseReader("data/NYPHKnowldegeBase.html")
reader = kb_reading.CSVKnowledgeBaseReader("data/ClaudeKnowledgeBase.csv")
kb = reader.load_knowledge_base()
matrix = datamodels.DiseaseSymptomKnowledgeBaseTransformer.to_numpy(kb)

## create the initial state
conditions = list(matrix.columns.keys())
the_condition = conditions[0] # TODO choose condition here
Expand All @@ -71,23 +72,28 @@ def test_convergence():
if matrix[symptom, the_condition] > symptom.noise_rate
]
state = ConvergenceTestState(matrix, discount_rate=1e-9)
# state = action_states.SimulationNextActionState(matrix, discount_rate=1e-9)

state.set_condition(the_condition, the_symptoms)
state.pertinent_pos.update(random.choices(the_symptoms, k=2))
state.pertinent_pos.update(random.choices(the_symptoms, k=1))

## Rollout policy
rollout_policy = action_states.RandomRollOutPolicy()
rollout_policy = action_states.ArgMaxDiagnosisRolloutPolicy()
rollout_policy = logtrueconditionhook(rollout_policy)

## create the MCTS searcher
searcher = mcts.mcts(timeLimit=3000, rolloutPolicy=rollout_policy)

action = None
while not isinstance(action, datamodels.Condition):
if action is not None:
assert isinstance(action, datamodels.Symptom)
action = searcher.search(initialState=state)
assert isinstance(action, datamodels.Symptom)
if action in the_symptoms:
logger.info("got a pertinent positive: {}", action)
state.pertinent_pos.add(action)
else:
logger.info("got a pertinent negative: {}", action)
state.pertinent_neg.add(action)
loguru.logger.info(f"action={action}")
diagnosis = action
Expand Down