From b8e012a90ff96ddefb1399f743f5de19fed33552 Mon Sep 17 00:00:00 2001 From: Wian Stipp Date: Sun, 30 Jul 2023 00:09:34 -0500 Subject: [PATCH] some fixes --- dr_claude/comparisons/__init__.py | 0 dr_claude/comparisons/claude.py | 0 dr_claude/datamodels.py | 21 +++++++++++++++++---- tests/test_mcts.py | 18 ++++++++++++------ 4 files changed, 29 insertions(+), 10 deletions(-) create mode 100644 dr_claude/comparisons/__init__.py create mode 100644 dr_claude/comparisons/claude.py diff --git a/dr_claude/comparisons/__init__.py b/dr_claude/comparisons/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dr_claude/comparisons/claude.py b/dr_claude/comparisons/claude.py new file mode 100644 index 0000000..e69de29 diff --git a/dr_claude/datamodels.py b/dr_claude/datamodels.py index 54ccd23..25ad6b1 100644 --- a/dr_claude/datamodels.py +++ b/dr_claude/datamodels.py @@ -33,14 +33,17 @@ def __hash__(self: HasUMLSClass) -> int: return hash(self.umls_code) def __eq__(self: HasUMLSClass, other: object) -> bool: - return isinstance(other, HasUMLS) and self.umls_code == other.umls_code + return ( + isinstance(other, HasUMLS) + and self.umls_code == other.umls_code + and self.name + ) cls.__hash__ = __hash__ cls.__eq__ = __eq__ return cls -@set_umls_methods class Symptom(pydantic.BaseModel): """ Symptom @@ -53,6 +56,9 @@ class Symptom(pydantic.BaseModel): class Config: frozen = True + def __eq__(self, other): + return self.name == other.name and self.umls_code == other.umls_code + @set_umls_methods class Condition(pydantic.BaseModel): @@ -70,7 +76,6 @@ class Config: frozen = True -@set_umls_methods class WeightedSymptom(Symptom): """ Weight @@ -81,6 +86,9 @@ class WeightedSymptom(Symptom): class Config: frozen = True + def __eq__(self, other): + return self.name == other.name and self.umls_code == other.umls_code + class DiseaseSymptomKnowledgeBase(pydantic.BaseModel): condition_symptoms: Dict[Condition, List[WeightedSymptom]] @@ -151,6 +159,8 @@ def to_numpy(kb: DiseaseSymptomKnowledgeBase) -> ProbabilityMatrix: ## the antagonist probas = np.zeros((len(symptom_idx), len(disease_idx))) + symptom_idx = dict(symptom_idx) + disease_idx = dict(disease_idx) ## fill noise vals for symptom, index in symptom_idx.items(): @@ -159,7 +169,10 @@ def to_numpy(kb: DiseaseSymptomKnowledgeBase) -> ProbabilityMatrix: ## fill known probas for condition, symptoms in kb.condition_symptoms.items(): for symptom in symptoms: - probas[symptom_idx[symptom], disease_idx[condition]] = symptom.weight + probas[ + symptom_idx[SymptomTransformer.to_symptom(symptom)], + disease_idx[condition], + ] = symptom.weight return ProbabilityMatrix( matrix=probas, diff --git a/tests/test_mcts.py b/tests/test_mcts.py index 7b721ef..699fe25 100644 --- a/tests/test_mcts.py +++ b/tests/test_mcts.py @@ -1,8 +1,9 @@ import copy import random -from typing import Callable, Collection, Self +from typing import Callable, Collection import loguru +from loguru import logger from dr_claude import datamodels from dr_claude.mcts import action_states @@ -58,10 +59,10 @@ def log_wrapper(state: ConvergenceTestState) -> float: def test_convergence(): ## load the knowledge base - reader = kb_reading.NYPHKnowldegeBaseReader("data/NYPHKnowldegeBase.html") + # reader = kb_reading.NYPHKnowldegeBaseReader("data/NYPHKnowldegeBase.html") + reader = kb_reading.CSVKnowledgeBaseReader("data/ClaudeKnowledgeBase.csv") kb = reader.load_knowledge_base() matrix = datamodels.DiseaseSymptomKnowledgeBaseTransformer.to_numpy(kb) - ## create the initial state conditions = list(matrix.columns.keys()) the_condition = conditions[0] # TODO chose condiiton here @@ -71,11 +72,13 @@ def test_convergence(): if matrix[symptom, the_condition] > symptom.noise_rate ] state = ConvergenceTestState(matrix, discount_rate=1e-9) + # state = action_states.SimulationNextActionState(matrix, discount_rate=1e-9) + state.set_condition(the_condition, the_symptoms) - state.pertinent_pos.update(random.choices(the_symptoms, k=2)) + state.pertinent_pos.update(random.choices(the_symptoms, k=1)) ## Rollout policy - rollout_policy = action_states.RandomRollOutPolicy() + rollout_policy = action_states.ArgMaxDiagnosisRolloutPolicy() rollout_policy = logtrueconditionhook(rollout_policy) ## create the initial state @@ -83,11 +86,14 @@ def test_convergence(): action = None while not isinstance(action, datamodels.Condition): + if action is not None: + assert isinstance(action, datamodels.Symptom) action = searcher.search(initialState=state) - assert isinstance(action, datamodels.Symptom) if action in the_symptoms: + logger.info("got a pertinent positive: {}", action) state.pertinent_pos.add(action) else: + logger.info("got a pertinent negative: {}", action) state.pertinent_neg.add(action) loguru.logger.info(f"action={action}") diagnosis = action