Skip to content

Commit

Permalink
feat: Add First play urgency enhancement to selection phase
Browse files Browse the repository at this point in the history
The First Play Urgency parameter allows the algorithm to exploit more than explore from the
beginning when there is a good reason to do so. With FPU, the algorithm no longer waits to expand
every child node before exploiting some of them.
  • Loading branch information
Philippe Vaillancourt committed May 5, 2018
1 parent 57072b3 commit 3388c05
Show file tree
Hide file tree
Showing 10 changed files with 310 additions and 258 deletions.
17 changes: 14 additions & 3 deletions src/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ export class Controller<State extends Playerwise, Action> {
private mcts_!: MCTSFacade<State, Action>
private duration_!: number
private explorationParam_!: number
private fpuParam_!: number
private simulate_!: string[]
private expand_!: string[]
private select_!: string[]

/**
* Creates an instance of Controller.
Expand All @@ -58,12 +61,18 @@ export class Controller<State extends Playerwise, Action> {
config: {
duration: number
explorationParam?: number
fpuParam?: number
simulate?: string[]
expand?: string[]
select?: string[]
}
) {
this.duration_ = config.duration
this.explorationParam_ = config.explorationParam || 1.414
this.fpuParam_ = config.fpuParam || Infinity
this.simulate_ = config.simulate || []
this.expand_ = config.expand || []
this.select_ = config.select || []

this.init(funcs)
}
Expand Down Expand Up @@ -92,10 +101,12 @@ export class Controller<State extends Playerwise, Action> {
// This is where we bootstrap the library according to initialization options.
const data: Map<string, MCTSState<State, Action>> = new Map()
const dataStore = new DataStore(data)
const ucb1: UCB1<State, Action> = new DefaultUCB1(this.explorationParam_)
const bestChild = new DefaultBestChild(ucb1)

const expand = new DefaultExpand(funcs.applyAction, funcs.generateActions, dataStore)
const UCB1: UCB1<State, Action> = new DefaultUCB1()
const bestChild = new DefaultBestChild(UCB1)
const select = new DefaultSelect(funcs.stateIsTerminal, expand, bestChild)

const select = new DefaultSelect(funcs.stateIsTerminal, expand, bestChild, ucb1, this.fpuParam_)

let simulate: Simulate<State, Action>
if (this.simulate_.includes('decisive')) {
Expand Down
20 changes: 19 additions & 1 deletion src/macao.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,16 @@ export class Macao<State extends Playerwise, Action> {
* @param {object} config Configuration options
* @param {number} config.duration Run time of the algorithm, in milliseconds.
* @param {number | undefined} config.explorationParam The exploration parameter constant.
* Used in [UCT](https://en.wikipedia.org/wiki/Monte_Carlo_tree_search). Defaults to 1.414.
* @param {number | undefined} config.fpuParam The First play urgency parameter. Used to encourage
* early exploitation. Defaults to `Infinity`.
* See [Exploration exploitation in Go:
* UCT for Monte-Carlo Go](https://hal.archives-ouvertes.fr/hal-00115330/document)
* @param {string[]} config.simulate An array of the simulation algorithm enhancements
* you wish to use.
* used in [UCT](https://en.wikipedia.org/wiki/Monte_Carlo_tree_search). Defaults to 1.414.
* @param {string[]} config.expand An array of the expand algorithm enhancements
* you wish to use.
*
*/
constructor(
funcs: {
Expand All @@ -80,11 +87,22 @@ export class Macao<State extends Playerwise, Action> {
config: {
duration: number
explorationParam?: number
fpuParam?: number
/**
* An array of the `simulate` algorithm enhancements you wish to use.
* Valid options: "decisive", "anti-decisive".
*/
simulate?: string[]
/**
* An array of the `expand` algorithm enhancements you wish to use.
* Valid options: none at the moment.
*/
expand?: string[]
/**
* An array of the `select` algorithm enhancements you wish to use.
* Valid options: none at the moment.
*/
select?: string[]
}
) {
this.controller_ = new Controller(funcs, config)
Expand Down
15 changes: 1 addition & 14 deletions src/mcts/expand/expand.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { MCTSNode, Playerwise, ApplyAction, GenerateActions, MCTSState } from '../../entities'
import { DataGateway } from '../mcts'
import { spliceRandom } from '../../utils'
import { BestChild } from '../select/best-child/best-child'

/**
*
Expand All @@ -26,26 +27,12 @@ export interface Expand<State, Action> {
* @template Action
*/
export class DefaultExpand<State extends Playerwise, Action> implements Expand<State, Action> {
/**
* Creates an instance of DefaultExpand.
* @param {ApplyAction<State, Action>} applyAction_
* @param {GenerateActions<State, Action>} generateActions_
* @param {DataGateway<string, MCTSState<State, Action>>} dataStore_
* @memberof DefaultExpand
*/
constructor(
private applyAction_: ApplyAction<State, Action>,
private generateActions_: GenerateActions<State, Action>,
private dataStore_: DataGateway<string, MCTSState<State, Action>>
) {}

/**
*
*
* @param {MCTSNode<State, Action>} node
* @returns {MCTSNode<State, Action>}
* @memberof DefaultExpand
*/
run(node: MCTSNode<State, Action>): MCTSNode<State, Action> {
const action = spliceRandom(node.possibleActionsLeftToExpand)
const state = this.applyAction_(node.mctsState.state, action)
Expand Down
6 changes: 3 additions & 3 deletions src/mcts/mcts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ export class DefaultMCTSFacade<State extends Playerwise, Action>
getAction(state: State, duration?: number): Action {
const rootNode = this.createRootNode_(state)
loopFor(duration || this.duration_).milliseconds(() => {
const node = this.select_.run(rootNode, this.explorationParam_)
const node = this.select_.run(rootNode)
const score = this.simulate_.run(node.mctsState.state)
this.backPropagate_.run(node, score)
})
const bestChild = this.bestChild_.run(rootNode, 0)
return bestChild.action as Action
const bestChild = this.bestChild_.run(rootNode, true)
return bestChild!.action as Action
}

/**
Expand Down
26 changes: 10 additions & 16 deletions src/mcts/select/best-child/best-child.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { MCTSNode, Playerwise, MCTSState } from '../../../entities'
* @template Action
*/
export interface BestChild<State, Action> {
run: (node: MCTSNode<State, Action>, explorationParam: number) => MCTSNode<State, Action>
run: (node: MCTSNode<State, Action>, exploit?: boolean) => MCTSNode<State, Action> | undefined
}

/**
Expand Down Expand Up @@ -40,14 +40,14 @@ export class DefaultBestChild<State extends Playerwise, Action>
* @returns {MCTSNode<State, Action>}
* @memberof DefaultBestChild
*/
run(node: MCTSNode<State, Action>, explorationParam: number): MCTSNode<State, Action> {
run(node: MCTSNode<State, Action>, exploit = false): MCTSNode<State, Action> | undefined {
if (!node.children.length) {
throw new Error('Cannot find the best children as the current node does not have children')
return undefined
}

const selectedNode = node.children.reduce((p, c) => {
return this.UCB1_.run(node.mctsState, p.mctsState, explorationParam) >
this.UCB1_.run(node.mctsState, c.mctsState, explorationParam)
return this.UCB1_.run(node.mctsState, p.mctsState, exploit) >
this.UCB1_.run(node.mctsState, c.mctsState, exploit)
? p
: c
})
Expand All @@ -66,11 +66,7 @@ export class DefaultBestChild<State extends Playerwise, Action>
* @template Action
*/
export interface UCB1<State, Action> {
run(
parent: MCTSState<State, Action>,
child: MCTSState<State, Action>,
explorationParam: number
): number
run(parent: MCTSState<State, Action>, child: MCTSState<State, Action>, exploit?: boolean): number
}

/**
Expand All @@ -84,6 +80,7 @@ export interface UCB1<State, Action> {
* @template Action
*/
export class DefaultUCB1<State, Action> implements UCB1<State, Action> {
constructor(private explorationParam_: number) {}
/**
*
*
Expand All @@ -93,13 +90,10 @@ export class DefaultUCB1<State, Action> implements UCB1<State, Action> {
* @returns {number}
* @memberof DefaultUCB1
*/
run(
parent: MCTSState<State, Action>,
child: MCTSState<State, Action>,
explorationParam: number
): number {
run(parent: MCTSState<State, Action>, child: MCTSState<State, Action>, exploit = false): number {
if (exploit) this.explorationParam_ = 0
const exploitationTerm = child.reward / child.visits
const explorationTerm = Math.sqrt(Math.log(parent.visits) / child.visits)
return exploitationTerm + explorationParam * explorationTerm
return exploitationTerm + this.explorationParam_ * explorationTerm
}
}
44 changes: 19 additions & 25 deletions src/mcts/select/select.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { MCTSNode, Playerwise, StateIsTerminal } from '../../entities'
import { Expand } from '../expand/expand'
import { BestChild } from './best-child/best-child'
import { BestChild, UCB1 } from './best-child/best-child'

/**
*
* The Select interface represents the Selection part of the Monte Carlo Tree
* Search algorithm. This part of the algorithm deals with choosing which node
* in the tree to run a simulation on.
* @hidden
* @internal
* @export
Expand All @@ -12,47 +14,39 @@ import { BestChild } from './best-child/best-child'
* @template Action
*/
export interface Select<State, Action> {
run: (node: MCTSNode<State, Action>, explorationParam: number) => MCTSNode<State, Action>
run: (node: MCTSNode<State, Action>) => MCTSNode<State, Action>
}

/**
*
* The DefaultSelect class provides the standard Monte Carlo Tree Search algorithm
* with the selection phase. Through it's [[run]] method, when supplied with a tree
* node, it will provide another tree node from which to run a simulation.
* @hidden
* @internal
* @export
* @class DefaultSelect
* @implements {Select<State, Action>}
* @template State
* @template Action
*/
export class DefaultSelect<State extends Playerwise, Action> implements Select<State, Action> {
/**
* Creates an instance of DefaultSelect.
* @param {StateIsTerminal<State>} stateIsTerminal_
* @param {Expand<State, Action>} expand_
* @param {BestChild<State, Action>} bestChild_
* @memberof DefaultSelect
*/
constructor(
private stateIsTerminal_: StateIsTerminal<State>,
private expand_: Expand<State, Action>,
private bestChild_: BestChild<State, Action>
private bestChild_: BestChild<State, Action>,
private ucb1_: UCB1<State, Action>,
private fpuParam_: number
) {}

/**
*
*
* @param {MCTSNode<State, Action>} node
* @param {number} explorationParam
* @returns {MCTSNode<State, Action>}
* @memberof DefaultSelect
*/
run(node: MCTSNode<State, Action>, explorationParam: number): MCTSNode<State, Action> {
run(node: MCTSNode<State, Action>): MCTSNode<State, Action> {
while (!this.stateIsTerminal_(node.mctsState.state)) {
const child = this.bestChild_.run(node)
if (!child) return this.expand_.run(node)
if (node.isNotFullyExpanded()) {
return this.expand_.run(node)
const ucb1 = this.ucb1_.run(node.mctsState, child.mctsState)
if (ucb1 < this.fpuParam_) {
return this.expand_.run(node)
}
}
node = this.bestChild_.run(node, explorationParam)
node = child
}
return node
}
Expand Down
14 changes: 7 additions & 7 deletions test/mcts.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
ticTacToeBoard
} from './tic-tac-toe/tic-tac-toe'
import { MCTSState, MCTSNode } from '../src/entities'
import { Expand, DefaultExpand } from '../src/mcts/expand/expand'
import { Expand, DefaultExpand, FullExpand } from '../src/mcts/expand/expand'
import {
BestChild,
UCB1,
Expand Down Expand Up @@ -37,9 +37,9 @@ beforeEach(() => {
const map = new Map()
dataStore = new DataStore(map)
expand = new DefaultExpand(ticTacToeFuncs.applyAction, ticTacToeFuncs.generateActions, dataStore)
ucb1 = new DefaultUCB1()
ucb1 = new DefaultUCB1(1.414)
bestChild = new DefaultBestChild(ucb1)
select = new DefaultSelect(ticTacToeFuncs.stateIsTerminal, expand, bestChild)
select = new DefaultSelect(ticTacToeFuncs.stateIsTerminal, expand, bestChild, ucb1, Infinity)
simulate = new DefaultSimulate(
ticTacToeFuncs.stateIsTerminal,
ticTacToeFuncs.generateActions,
Expand Down Expand Up @@ -70,7 +70,7 @@ describe('The DefaultSelect instance', () => {
const mtcsState = new MCTSState(state)
const node = new MCTSNode(mtcsState, ticTacToeFuncs.generateActions(state))
it('should return the current node', () => {
expect(select.run(node, 1.414)).toBe(node)
expect(select.run(node)).toBe(node)
})
})
describe('when the current node is not terminal', () => {
Expand All @@ -83,7 +83,7 @@ describe('The DefaultSelect instance', () => {
it('should return a node that is not the current node.', () => {
const mtcsState = new MCTSState(state)
const node = new MCTSNode(mtcsState, ticTacToeFuncs.generateActions(state))
const result = select.run(node, 1.414)
const result = select.run(node)
expect(result).toBeInstanceOf(MCTSNode)
expect(result).not.toBe(node)
})
Expand All @@ -103,7 +103,7 @@ describe('The DefaultUCB1 function', () => {
parent.visits = 300
child.visits = 100
child.reward = 50
expect(ucb1.run(parent, child, 1.414)).toBeCloseTo(0.8377)
expect(ucb1.run(parent, child)).toBeCloseTo(0.8377)
})
})
})
Expand Down Expand Up @@ -141,7 +141,7 @@ describe('The DefaultBestChild instance', () => {
child3State.visits = 50
child3State.reward = 25

expect(bestChild.run(parentNode, 1.414)).toBe(parentNode.children[2])
expect(bestChild.run(parentNode)).toBe(parentNode.children[2])
})
})
})
Expand Down
Loading

0 comments on commit 3388c05

Please sign in to comment.