Skip to content

Commit

Permalink
Merge pull request #12 from fadynakhla/arthur/mcts
Browse files Browse the repository at this point in the history
Arthur/mcts
  • Loading branch information
ArthurBook authored Jul 30, 2023
2 parents ca54953 + 51662d7 commit 5b33d79
Show file tree
Hide file tree
Showing 6 changed files with 6,243 additions and 877 deletions.
117 changes: 72 additions & 45 deletions dr_claude/mcts/action_states.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,35 @@
import abc
import loguru
import copy
import math
import random
from typing import Collection, Dict, List, Optional, Set, Union
from typing import Collection, Dict, Generic, List, Optional, Set, TypeVar, Union
from typing_extensions import Self

import numpy as np

from dr_claude import datamodels
from dr_claude.mcts import probability_calcs


class ActionState(abc.ABC):
T = TypeVar("T")


class ActionState(abc.ABC, Generic[T]):
"""
Base class for the state of the game
"""

@abc.abstractmethod
def getCurrentPlayer(self):
def getCurrentPlayer(self) -> int:
...

@abc.abstractmethod
def getPossibleActions(self):
def getPossibleActions(self) -> List[T]:
...

@abc.abstractmethod
def takeAction(self, action):
def takeAction(self, action: T) -> Self:
...

@abc.abstractmethod
Expand All @@ -36,18 +41,14 @@ def getReward(self) -> float:
...


class NextBestActionState(ActionState, abc.ABC):
class NextBestActionState(
ActionState[Union[datamodels.Symptom, datamodels.Condition]], abc.ABC
):
"""
base class for simulating next best action for diagnosis
"""

diagnosis: Optional[datamodels.Condition] = None
pertinent_pos: Set[
datamodels.Symptom
] # symptoms that are confirmed pertinent positives
pertinent_neg: Set[
datamodels.Symptom
] # symptoms that are confirmed pertinent negatives

def takeAction(
self, action: Union[datamodels.Symptom, datamodels.Condition]
Expand Down Expand Up @@ -77,7 +78,35 @@ def handleDiagnostic(
DEFAULT_DISCOUNT_RATE = 0.1


class SimulationNextActionState(NextBestActionState):
class SimulationMixin:
dynamics: datamodels.ProbabilityMatrix
pertinent_pos: Set[datamodels.Symptom]
pertinent_neg: Set[datamodels.Symptom]

def getSymptomProbabilityDict(self) -> Dict[datamodels.Symptom, float]:
condition_posterior = self.getConditionProbabilityVector()
symptom_posterior = probability_calcs.compute_symptom_posterior_flat_prior_dict(
matrix=self.dynamics,
condition_probas=condition_posterior,
)
return symptom_posterior

def getConditionProbabilityVector(self) -> np.ndarray:
return probability_calcs.compute_condition_posterior_flat_prior(
self.dynamics,
pertinent_positives=self.pertinent_pos,
pertinent_negatives=self.pertinent_neg,
)

def getConditionProbabilityDict(self) -> Dict[datamodels.Condition, float]:
return probability_calcs.compute_condition_posterior_flat_prior_dict(
self.dynamics,
pertinent_positives=self.pertinent_pos,
pertinent_negatives=self.pertinent_neg,
)


class SimulationNextActionState(SimulationMixin, NextBestActionState):
def __init__(
self,
matrix: datamodels.ProbabilityMatrix,
Expand Down Expand Up @@ -109,16 +138,6 @@ def getPossibleActions(
) -> List[Union[datamodels.Symptom, datamodels.Condition]]:
return list(self.remaining_symptoms.union(self.conditions))

def takeAction(
self, action: Union[datamodels.Symptom, datamodels.Condition]
) -> "NextBestActionState":
if isinstance(action, datamodels.Symptom):
return self.handleSymptom(action)
elif isinstance(action, datamodels.Condition):
return self.handleDiagnostic(action)
else:
raise ValueError(f"Unknown action type {action}")

def handleSymptom(self, symptom: datamodels.Symptom) -> "NextBestActionState":
next_self = copy.deepcopy(self)
next_self.remaining_symptoms.remove(symptom)
Expand All @@ -130,28 +149,6 @@ def handleSymptom(self, symptom: datamodels.Symptom) -> "NextBestActionState":
self.increment_discount_factor()
return next_self

def getSymptomProbabilityDict(self) -> Dict[datamodels.Symptom, float]:
condition_posterior = self.getConditionProbabilityVector()
symptom_posterior = probability_calcs.compute_symptom_posterior_flat_prior_dict(
matrix=self.dynamics,
condition_probas=condition_posterior,
)
return symptom_posterior

def getConditionProbabilityVector(self) -> np.ndarray:
return probability_calcs.compute_condition_posterior_flat_prior(
self.dynamics,
pertinent_positives=self.pertinent_pos,
pertinent_negatives=self.pertinent_neg,
)

def getConditionProbabilityDict(self) -> Dict[datamodels.Condition, float]:
return probability_calcs.compute_condition_posterior_flat_prior_dict(
self.dynamics,
pertinent_positives=self.pertinent_pos,
pertinent_negatives=self.pertinent_neg,
)

def handleDiagnostic(
self, condition: datamodels.Condition
) -> "NextBestActionState":
Expand All @@ -171,3 +168,33 @@ def getReward(self) -> float:
assert self.diagnosis is not None
return conditions[self.dynamics.columns[self.diagnosis]]


"""
Rollout policy
"""


class RollOutPolicy(abc.ABC):
@abc.abstractmethod
def __call__(self, state: ActionState) -> float:
...


class RandomRollOutPolicy(RollOutPolicy):
def __call__(self, state: NextBestActionState) -> float:
while not state.isTerminal():
try:
actions_space = state.getPossibleActions()
action = random.choice(actions_space)
except IndexError:
raise Exception(
"Non-terminal state has no possible actions: " + str(state)
)
state = state.takeAction(action)
return state.getReward()


class ArgMaxDiagnosisRolloutPolicy(RollOutPolicy):
def __call__(self, state: SimulationNextActionState) -> float:
actions_space = state.getConditionProbabilityDict()
return max(actions_space.values())
Loading

0 comments on commit 5b33d79

Please sign in to comment.