Skip to content

Commit

Permalink
Added ValueError test for Epsilon class and comments about raising Va…
Browse files Browse the repository at this point in the history
…lueError.
  • Loading branch information
norases committed Aug 1, 2014
1 parent 77aff65 commit ec4c694
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 1 deletion.
3 changes: 2 additions & 1 deletion moe/bandit/epsilon.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ def get_winning_arm_names(arms_sampled):
:type arms_sampled: dictionary of (String(), SampleArm()) pairs
:return: of set of names of the winning arms
:rtype: frozenset(String())
:raise: ValueError when ``arms_sampled`` are empty.
"""
if not arms_sampled:
raise ValueError('sample_arms is empty!')
raise ValueError('arms_sampled is empty!')

avg_payoff_arm_name_list = []
for arm_name, sampled_arm in arms_sampled.iteritems():
Expand Down
1 change: 1 addition & 0 deletions moe/bandit/epsilon_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def allocate_arms(self):
:return: the dictionary of (arm, allocation) key-value pairs
:rtype: a dictionary of (String(), float64) pairs
:raise: ValueError when ``sample_arms`` are empty.
"""
arms_sampled = self._historical_info.arms_sampled
Expand Down
1 change: 1 addition & 0 deletions moe/bandit/epsilon_greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def allocate_arms(self):
:return: the dictionary of (arm, allocation) key-value pairs
:rtype: a dictionary of (String(), float64) pairs
:raise: ValueError when ``sample_arms`` are empty.
"""
arms_sampled = self._historical_info.arms_sampled
Expand Down
16 changes: 16 additions & 0 deletions moe/tests/bandit/epsilon_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
Test functions in :class:`moe.bandit.epsilon.Epsilon`
"""
import logging

import testify as T

from moe.bandit.epsilon import Epsilon
Expand All @@ -14,6 +16,20 @@ class EpsilonTest(EpsilonTestCase):

"""Verify that different sample_arms return correct results."""

@T.class_setup
def disable_logging(self):
"""Disable logging (for the duration of this test case)."""
logging.disable(logging.CRITICAL)

@T.class_teardown
def enable_logging(self):
"""Re-enable logging (so other test cases are unaffected)."""
logging.disable(logging.NOTSET)

def test_empty_arm_invalid(self):
"""Test empty ``sample_arms`` causes an ValueError."""
T.assert_raises(ValueError, Epsilon.get_winning_arm_names, {})

def test_two_new_arms(self):
"""Check that the two-new-arms case always returns both arms as winning arms. This tests num_winning_arms == num_arms > 1."""
T.assert_sets_equal(Epsilon.get_winning_arm_names(self.two_new_arms_test_case.arms_sampled), frozenset(["arm1", "arm2"]))
Expand Down

0 comments on commit ec4c694

Please sign in to comment.