diff --git a/dr_claude/mcts_module/__init__.py b/dr_claude/mcts/__init__.py similarity index 100% rename from dr_claude/mcts_module/__init__.py rename to dr_claude/mcts/__init__.py diff --git a/dr_claude/mcts_module/action_states.py b/dr_claude/mcts/action_states.py similarity index 89% rename from dr_claude/mcts_module/action_states.py rename to dr_claude/mcts/action_states.py index 4473281..bf5705c 100644 --- a/dr_claude/mcts_module/action_states.py +++ b/dr_claude/mcts/action_states.py @@ -5,11 +5,11 @@ import random from typing import Collection, Dict, Generic, List, Optional, Set, TypeVar, Union from typing_extensions import Self -import mcts + import numpy as np from dr_claude import datamodels -from dr_claude.mcts_module import probability_calcs +from dr_claude.mcts import probability_calcs T = TypeVar("T") @@ -200,26 +200,3 @@ class ArgMaxDiagnosisRolloutPolicy(RollOutPolicy): def __call__(self, state: SimulationNextActionState) -> float: actions_space = state.getConditionProbabilityDict() return max(actions_space.values()) - - -""" -The MCTS algorithm -""" - - -class MCTS(mcts.mcts): - def getBestChild(self, node: mcts.treeNode, explorationValue: float): - bestValue = float("-inf") - bestNodes = [] - for child in node.children.values(): - nodeValue = ( - node.state.getCurrentPlayer() * child.totalReward / child.numVisits - + explorationValue - * math.sqrt(2 * math.log(node.numVisits) / child.numVisits) - ) - if nodeValue > bestValue: - bestValue = nodeValue - bestNodes = [child] - elif nodeValue == bestValue: - bestNodes.append(child) - return random.choice(bestNodes) diff --git a/dr_claude/mcts_module/probability_calcs.py b/dr_claude/mcts/probability_calcs.py similarity index 100% rename from dr_claude/mcts_module/probability_calcs.py rename to dr_claude/mcts/probability_calcs.py diff --git a/tests/test_conditional_proba.py b/tests/test_conditional_proba.py index 7211ecd..84ba8c5 100644 --- a/tests/test_conditional_proba.py +++ b/tests/test_conditional_proba.py @@ -1,5 +1,5 @@ from dr_claude import datamodels -from dr_claude.mcts_module import probability_calcs +from dr_claude.mcts import probability_calcs def test_conditional_condition_proba(): diff --git a/tests/test_mcts.py b/tests/test_mcts.py index 6468112..b989db9 100644 --- a/tests/test_mcts.py +++ b/tests/test_mcts.py @@ -1,12 +1,10 @@ import copy -import random from typing import Callable, Collection, Self import loguru from dr_claude import datamodels -from dr_claude import mcts_module -from dr_claude.mcts_module import action_states +from dr_claude.mcts import action_states import mcts @@ -73,14 +71,13 @@ def test_convergence(): ] state = ConvergenceTestState(matrix, discount_rate=1e-9) state.set_condition(the_condition, the_symptoms) - state.pertinent_pos.add(random.choice(the_symptoms)) ## Rollout policy - rollout_policy = action_states.ArgMaxDiagnosisRolloutPolicy() + rollout_policy = action_states.RandomRollOutPolicy() rollout_policy = logtrueconditionhook(rollout_policy) ## create the initial state - searcher = action_states.MCTS(timeLimit=3000, rolloutPolicy=rollout_policy) + searcher = mcts.mcts(timeLimit=3000, rolloutPolicy=rollout_policy) action = None while not isinstance(action, datamodels.Condition):