Skip to content

Commit

Permalink
Addressed Eric's comments. Wrote test for static function in class Ep…
Browse files Browse the repository at this point in the history
…silon
  • Loading branch information
norases committed Aug 1, 2014
1 parent 4811dd4 commit 77aff65
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 15 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
* Features

* Added multi-armed bandit endpoint. (#255)
* Implemented epsilon-greedy.
* Implemented epsilon-greedy. (#255)
* Implemented epsilon-first. (#335)
* Added support for the L-BFGS-B optimizer. (#296)

* Changes
Expand Down
7 changes: 5 additions & 2 deletions moe/bandit/epsilon.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
"""Construct an Epsilon object.
:param historical_info: a dictionary of arms sampled
:type historical_info: dictionary of (String(), SingleArm()) pairs
:type historical_info: dictionary of (String(), SampleArm()) pairs (see :class:`moe.bandit.data_containers.SampleArm` for more details)
:param subtype: subtype of the epsilon bandit algorithm (default: None)
:type subtype: String()
:param epsilon: epsilon hyperparameter for the epsilon bandit algorithm (default: :const:`~moe.bandit.constant.DEFAULT_EPSILON`)
Expand All @@ -43,12 +43,15 @@ def __init__(
self._subtype = subtype
self._epsilon = epsilon

def _get_winning_arm_names(self, arms_sampled):
@staticmethod
def get_winning_arm_names(arms_sampled):
r"""Compute the set of winning arm names based on the given ``arms_sampled``.
Throws an exception when arms_sampled is empty.
Implementers of this interface will never override this method.
:param arms_sampled: a dictionary of arm name to :class:`moe.bandit.data_containers.SampleArm`
:type arms_sampled: dictionary of (String(), SampleArm()) pairs
:return: set of names of the winning arms
:rtype: frozenset(String())
Expand Down
4 changes: 1 addition & 3 deletions moe/bandit/epsilon_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
See :class:`moe.bandit.epsilon.Epsilon` for further details on bandit.
"""
import numpy

from moe.bandit.constant import DEFAULT_EPSILON, DEFAULT_TOTAL_SAMPLES, EPSILON_SUBTYPE_FIRST
from moe.bandit.epsilon import Epsilon

Expand Down Expand Up @@ -106,7 +104,7 @@ def allocate_arms(self):
return arms_to_allocations

# Exploitation phase, trials epsilon * T+1, ..., T
winning_arm_names = self._get_winning_arm_names(arms_sampled)
winning_arm_names = self.get_winning_arm_names(arms_sampled)

num_winning_arms = len(winning_arm_names)
arms_to_allocations = {}
Expand Down
4 changes: 1 addition & 3 deletions moe/bandit/epsilon_greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
See :class:`moe.bandit.epsilon.Epsilon` for further details on this bandit.
"""
import numpy

from moe.bandit.constant import DEFAULT_EPSILON, EPSILON_SUBTYPE_GREEDY
from moe.bandit.epsilon import Epsilon

Expand Down Expand Up @@ -69,7 +67,7 @@ def allocate_arms(self):
if not arms_sampled:
raise ValueError('sample_arms are empty!')

winning_arm_names = self._get_winning_arm_names(arms_sampled)
winning_arm_names = self.get_winning_arm_names(arms_sampled)

num_winning_arms = len(winning_arm_names)
epsilon_allocation = self._epsilon / num_arms
Expand Down
1 change: 1 addition & 0 deletions moe/tests/bandit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
* :mod:`moe.tests.bandit.bandit_test_case`: base test case for bandit tests with a simple integration test case
* :mod:`moe.tests.bandit.epsilon_first_test`: tests for :mod:`moe.bandit.epsilon_first.EpsilonFirst`
* :mod:`moe.tests.bandit.epsilon_greedy_test`: tests for :mod:`moe.bandit.epsilon_greedy.EpsilonGreedy`
* :mod:`moe.tests.bandit.epsilon_test`: tests for :mod:`moe.bandit.epsilon.Epsilon`
* :mod:`moe.tests.bandit.epsilon_test_case`: test cases for classes under :mod:`moe.bandit.epsilon.Epsilon`
* :mod:`moe.tests.bandit.linkers_test`: tests for :mod:`moe.bandit.linkers`
Expand Down
27 changes: 27 additions & 0 deletions moe/tests/bandit/epsilon_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
"""Test epsilon bandit implementation (functions common to epsilon bandit).
Test functions in :class:`moe.bandit.epsilon.Epsilon`
"""
import testify as T

from moe.bandit.epsilon import Epsilon
from moe.tests.bandit.epsilon_test_case import EpsilonTestCase


class EpsilonTest(EpsilonTestCase):

    """Exercise the static ``Epsilon.get_winning_arm_names`` helper on the shared epsilon test cases."""

    def test_two_new_arms(self):
        """With two new arms, both must be reported as winners (num_winning_arms == num_arms > 1)."""
        arms_sampled = self.two_new_arms_test_case.arms_sampled
        expected_winners = frozenset(["arm1", "arm2"])
        actual_winners = Epsilon.get_winning_arm_names(arms_sampled)
        T.assert_sets_equal(actual_winners, expected_winners)

    def test_three_arms_two_winners(self):
        """With three arms and two winners, exactly those two are reported (num_arms > num_winning_arms > 1)."""
        arms_sampled = self.three_arms_two_winners_test_case.arms_sampled
        expected_winners = frozenset(["arm1", "arm2"])
        actual_winners = Epsilon.get_winning_arm_names(arms_sampled)
        T.assert_sets_equal(actual_winners, expected_winners)


# When this module is executed directly as a script, hand control to testify's runner.
if __name__ == "__main__":
T.run()
6 changes: 3 additions & 3 deletions moe/views/rest/bandit_epsilon.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ def get_params_from_request(self):
params = super(BanditEpsilonView, self).get_params_from_request()

# colander deserialized results are READ-ONLY. We will potentially be overwriting
# fields of ``params['optimizer_info']``, so we need to copy it first.
# fields of ``params['hyperparameter_info']``, so we need to copy it first.
params['hyperparameter_info'] = copy.deepcopy(params['hyperparameter_info'])

# Find the schma class that corresponds to the ``optimizer_type`` of the request
# optimizer_parameters has *not been validated yet*, so we need to validate manually.
# Find the schema class that corresponds to the ``subtype`` of the request
# hyperparameter_info has *not been validated yet*, so we need to validate manually.
schema_class = BANDIT_EPSILON_SUBTYPES_TO_HYPERPARAMETER_INFO_SCHEMA_CLASSES[params['subtype']]()

# Deserialize and validate the parameters
Expand Down
6 changes: 3 additions & 3 deletions moe/views/schemas/bandit_pretty_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ class BanditEpsilonFirstHyperparameterInfo(base_schemas.StrictMappingSchema):
:ivar epsilon: (*0.0 <= float64 <= 1.0*) epsilon value for epsilon-first bandit. This strategy pulls the optimal arm
(best expected return) if it is in the exploitation phase (number sampled > epsilon * total_samples). Otherwise a random arm is pulled (exploration).
:ivar total_samples: total number of samples for epsilon-first bandit. total_samples is T from :doc:`bandit`.
:ivar total_samples: (*int >= 0*) total number of samples for epsilon-first bandit. total_samples is T from :doc:`bandit`.
"""

epsilon = colander.SchemaNode(
colander.Float(),
validator=colander.Range(min=0),
validator=colander.Range(min=0.0, max=1.0),
missing=DEFAULT_EPSILON,
)

Expand All @@ -103,7 +103,7 @@ class BanditEpsilonGreedyHyperparameterInfo(base_schemas.StrictMappingSchema):

epsilon = colander.SchemaNode(
colander.Float(),
validator=colander.Range(min=0),
validator=colander.Range(min=0.0, max=1.0),
missing=DEFAULT_EPSILON,
)

Expand Down

0 comments on commit 77aff65

Please sign in to comment.