Skip to content

Commit

Permalink
add proper abort condition to mcts
Browse files Browse the repository at this point in the history
  • Loading branch information
johanneslenfers committed Oct 23, 2024
1 parent 034b4ad commit 3152851
Showing 1 changed file with 71 additions and 20 deletions.
91 changes: 71 additions & 20 deletions src/main/scala/elevate/heuristic_search/heuristics/MCTS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,20 @@ case class Node[P](
}
}

/**
 * Result of a single step of an MCTS rollout.
 *
 * @param solution    the solution chosen for this step; the rollout code in this
 *                    file uses `null` here to signal that no valid candidate
 *                    could be found, so callers must check for `null`
 * @param performance the performance value carried along the rollout, or `None`
 *                    if nothing has been measured yet (presumably the minimum
 *                    seen so far, lower being better — confirm against the
 *                    rollout logic)
 * @tparam P the program type the solution refers to
 */
case class RolloutStep[P](solution: Solution[P], performance: Option[Double])

class MCTS[P] extends Heuristic[P] {

var counter: Int = 0

def start(panel: HeuristicPanel[P], initialSolution: Solution[P], depth: Int, samples: Int): ExplorationResult[P] = {

// execute initial solution as default
val default_performance: Option[Double] = panel.f(initialSolution)

var counter: Int = 1
counter += 1

// create a tree with initial solution as root
val rootNode: Node[P] = Node(
Expand All @@ -68,7 +74,7 @@ class MCTS[P] extends Heuristic[P] {
)

// main exploration loop
for (_ <- 1 to samples) {
while (counter < samples) {

// 1. Selection based on UCB1 value
var node: Node[P] = rootNode
Expand All @@ -90,42 +96,48 @@ class MCTS[P] extends Heuristic[P] {

// 3. Rollout
// we start the rollout at the current node
var rollout: (Solution[P], Option[Double]) = (node.solution, None)
var rollout_step: RolloutStep[P] = RolloutStep[P](
solution = node.solution,
performance = None,
)

var isTerminal: Boolean = false
while (!isTerminal && rollout._1.solutionSteps.count(step => step.strategy != elevate.core.strategies.basic.id[P]) < depth) {
while (!isTerminal && rollout_step.solution.solutionSteps.count(step => step.strategy != elevate.core.strategies.basic.id[P]) < depth) {

val actions = panel.N(rollout._1)
val actions = panel.N(rollout_step.solution)
if (actions.nonEmpty) {

// try to consider only valid ones
// rollout performance should be the minimum that was seen during rollout
rollout = choose_valid_solution_randomly(panel = panel, actions = actions, rollout._2)
rollout_step = choose_valid_solution_randomly(panel = panel, actions = actions, rollout_step.performance)

// check if rollout is empty
if (rollout._1 == null) {
if (rollout_step.solution == null) {
isTerminal = true
} else {
rollout._1.solutionSteps.foreach(step => println(s"""[${step.strategy}, ${step.location}]"""))
rollout_step.solution.solutionSteps.foreach(step => println(s"""[${step.strategy}, ${step.location}]"""))

// check if we have a dead end for this rollout
if (rollout._1 == null) {
if (rollout_step.solution == null) {
isTerminal = true
}
}

} else {
isTerminal = true
}

if (counter >= samples) {
isTerminal = true
}
}

// 4. Backpropagation
counter += 1
while (node != null) {
node.visits += 1

// this can be biased by the ranges
val win: Double = rollout._2 match {
val win: Double = rollout_step.performance match {
case Some(value) => 1 / value
case None => 0
}
Expand All @@ -134,30 +146,69 @@ class MCTS[P] extends Heuristic[P] {
}
}

def choose_valid_solution_randomly(panel: HeuristicPanel[P], actions: Seq[Solution[P]], minimum: Option[Double]): (Solution[P], Option[Double]) = {
def choose_valid_solution_randomly(panel: HeuristicPanel[P], actions: Seq[Solution[P]], minimum: Option[Double]): RolloutStep[P] = {

def findSolution(minimum: Option[Double], attempts: Set[Solution[P]]): (Solution[P], Option[Double]) = {
def findSolution(minimum: Option[Double], attempts: Set[Solution[P]]): RolloutStep[P] = {
val remainingActions = actions.filterNot(attempts.contains)

if (remainingActions.isEmpty) {
(null.asInstanceOf[Solution[P]], None) // No valid solution found
// No valid solution found
RolloutStep[P](
solution = null.asInstanceOf[Solution[P]],
performance = None,
)
} else {
val candidate: Solution[P] = remainingActions(Random.nextInt(remainingActions.size))

// get performance of the chosen candidate
counter += 1
panel.f(candidate) match {
case Some(value) =>
minimum match {
case None =>
(candidate, Some(value)) // Valid solution found

RolloutStep[P](
solution = candidate,
performance = Some(value),
)

case Some(minimum_value) =>
value <= minimum_value match {
case true => (candidate, Some(value))
case false => (candidate, Some(minimum_value))
value < minimum_value match {
case true =>

RolloutStep[P](
solution = candidate,
performance = Some(value),
)

case false =>

RolloutStep[P](
solution = candidate,
performance = Some(minimum_value),
)
}
}
case None =>
counter < samples match {
case true =>
// continue with other candidate
findSolution(minimum, attempts + candidate)
case false =>
// stop and return candidate with minimum found so far
minimum match {
case Some(minimum_value) =>
RolloutStep[P](
solution = candidate,
performance = Some(minimum_value),
)
case None =>
RolloutStep[P](
solution = candidate,
performance = None,
)
}
}
case None => findSolution(minimum, attempts + candidate) // Add to attempts and recurse
}
}
}
Expand Down

0 comments on commit 3152851

Please sign in to comment.