Skip to content

Commit

Permalink
Merge pull request #14 from fadynakhla/arthur/sentient
Browse files Browse the repository at this point in the history
Arthur/sentient
  • Loading branch information
WianStipp authored Jul 30, 2023
2 parents 1e927b4 + b09fff6 commit 50e408d
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 6 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import random
from typing import Collection, Dict, Generic, List, Optional, Set, TypeVar, Union
from typing_extensions import Self

import mcts
import numpy as np

from dr_claude import datamodels
from dr_claude.mcts import probability_calcs
from dr_claude.mcts_module import probability_calcs


T = TypeVar("T")
Expand Down Expand Up @@ -200,3 +200,26 @@ class ArgMaxDiagnosisRolloutPolicy(RollOutPolicy):
def __call__(self, state: SimulationNextActionState) -> float:
actions_space = state.getConditionProbabilityDict()
return max(actions_space.values())


"""
The MCTS algorithm
"""


class MCTS(mcts.mcts):
def getBestChild(self, node: mcts.treeNode, explorationValue: float):
bestValue = float("-inf")
bestNodes = []
for child in node.children.values():
nodeValue = (
node.state.getCurrentPlayer() * child.totalReward / child.numVisits
+ explorationValue
* math.sqrt(2 * math.log(node.numVisits) / child.numVisits)
)
if nodeValue > bestValue:
bestValue = nodeValue
bestNodes = [child]
elif nodeValue == bestValue:
bestNodes.append(child)
return random.choice(bestNodes)
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/test_conditional_proba.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dr_claude import datamodels
from dr_claude.mcts import probability_calcs
from dr_claude.mcts_module import probability_calcs


def test_conditional_condition_proba():
Expand Down
9 changes: 6 additions & 3 deletions tests/test_mcts.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import copy
import random
from typing import Callable, Collection, Self

import loguru

from dr_claude import datamodels
from dr_claude.mcts import action_states
from dr_claude import mcts_module
from dr_claude.mcts_module import action_states

import mcts

Expand Down Expand Up @@ -71,13 +73,14 @@ def test_convergence():
]
state = ConvergenceTestState(matrix, discount_rate=1e-9)
state.set_condition(the_condition, the_symptoms)
state.pertinent_pos.add(random.choice(the_symptoms))

## Rollout policy
rollout_policy = action_states.RandomRollOutPolicy()
rollout_policy = action_states.ArgMaxDiagnosisRolloutPolicy()
rollout_policy = logtrueconditionhook(rollout_policy)

## create the initial state
searcher = mcts.mcts(timeLimit=3000, rolloutPolicy=rollout_policy)
searcher = action_states.MCTS(timeLimit=3000, rolloutPolicy=rollout_policy)

action = None
while not isinstance(action, datamodels.Condition):
Expand Down

0 comments on commit 50e408d

Please sign in to comment.