Skip to content

Commit 79ced9e

Browse files
author
Emily Strong
committed
fix for updating expectations in ucb1 partial fit
Signed-off-by: Emily Strong <emily.strong@fmr.com>
1 parent c6f1660 commit 79ced9e

8 files changed

+34
-17
lines changed

CHANGELOG.txt

+7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
MABWiser CHANGELOG
33
=====================
44

5+
6+
-------------------------------------------------------------------------------
7+
December, 17, 2019 1.7.1
8+
-------------------------------------------------------------------------------
9+
minor:
10+
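- Bug fix for partial fitting in UCB1: expectations (upper confidence bounds) are now recomputed for every previously observed arm on each partial fit, including arms that do not appear in the new batch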
11+
512
-------------------------------------------------------------------------------
613
November, 27, 2019 1.7.0
714
-------------------------------------------------------------------------------

dist/mabwiser-1.7.0-py3-none-any.whl

-41.6 KB
Binary file not shown.

dist/mabwiser-1.7.1-py3-none-any.whl

41.7 KB
Binary file not shown.

mabwiser/mab.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66
:Author: FMR LLC
77
:Email: mabwiser@fmr.com
8-
:Version: 1.7.0 of November 27, 2019
8+
:Version: 1.7.1 of December 17, 2019
99
1010
This module defines the public interface of the **MABWiser Library** providing access to the following modules:
1111
@@ -32,7 +32,7 @@
3232

3333
__author__ = "FMR LLC"
3434
__email__ = "mabwiser@fmr.com"
35-
__version__ = "1.7.0"
35+
__version__ = "1.7.1"
3636
__copyright__ = "Copyright (C) 2019, FMR LLC"
3737

3838

mabwiser/simulator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55
:Author: FMR LLC
66
:Email: mabwiser@fmr.com
7-
:Version: 1.7.0 of November 27, 2019
7+
:Version: 1.7.1 of December 17, 2019
88
99
This module provides a simulation utility for comparing algorithms and hyper-parameter tuning.
1010
"""

mabwiser/ucb.py

+2
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def _fit_arm(self, arm: Arm, decisions: np.ndarray, rewards: np.ndarray, context
6363
self.arm_to_sum[arm] += arm_rewards.sum()
6464
self.arm_to_count[arm] += arm_rewards.size
6565
self.arm_to_mean[arm] = self.arm_to_sum[arm] / self.arm_to_count[arm]
66+
67+
if self.arm_to_count[arm]:
6668
self.arm_to_expectation[arm] = _UCB1._get_ucb(self.arm_to_mean[arm], self.alpha,
6769
self.total_count, self.arm_to_count[arm])
6870

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
name="mabwiser",
88
description="MABWiser: Parallelizable Contextual Multi-Armed Bandits Library",
99
long_description=long_description,
10-
version="1.7.0",
10+
version="1.7.1",
1111
author="FMR LLC",
1212
url="https://github.com/fmr-llc/mabwiser",
1313
packages=setuptools.find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),

tests/test_ucb.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -310,28 +310,36 @@ def test_partial_fit(self):
310310

311311
mean = mab._imp.arm_to_mean[1]
312312
ci = mab._imp.arm_to_expectation[1]
313+
mean1 = mab._imp.arm_to_mean[2]
314+
ci1 = mab._imp.arm_to_expectation[2]
313315
self.assertAlmostEqual(0.3333333333333333, mean)
314316
self.assertAlmostEqual(1.5723073962832794, ci)
315-
316-
mean1 = mab._imp.arm_to_mean[4]
317-
ci1 = mab._imp.arm_to_expectation[4]
318317
self.assertEqual(mean1, 0)
319-
self.assertEqual(ci1, 0)
318+
self.assertAlmostEqual(ci1, 1.5174271293851465)
319+
320+
mean2 = mab._imp.arm_to_mean[4]
321+
ci2 = mab._imp.arm_to_expectation[4]
322+
self.assertEqual(mean2, 0)
323+
self.assertEqual(ci2, 0)
320324

321325
# Fit again
322326
decisions2 = [1, 3, 4]
323327
rewards2 = [0, 1, 1]
324328
mab.partial_fit(decisions2, rewards2)
325329

326-
mean2 = mab._imp.arm_to_mean[1]
327-
ci2 = mab._imp.arm_to_expectation[1]
328-
mean3 = mab._imp.arm_to_mean[4]
329-
ci3 = mab._imp.arm_to_expectation[4]
330-
331-
self.assertEqual(mean2, 0.25)
332-
self.assertAlmostEqual(1.3824639856219572, ci2)
333-
self.assertEqual(mean3, 1)
334-
self.assertAlmostEqual(3.2649279712439143, ci3)
330+
mean3 = mab._imp.arm_to_mean[1]
331+
ci3 = mab._imp.arm_to_expectation[1]
332+
mean4 = mab._imp.arm_to_mean[4]
333+
ci4 = mab._imp.arm_to_expectation[4]
334+
mean5 = mab._imp.arm_to_mean[2]
335+
ci5 = mab._imp.arm_to_expectation[2]
336+
337+
self.assertEqual(mean3, 0.25)
338+
self.assertAlmostEqual(1.3824639856219572, ci3)
339+
self.assertEqual(mean4, 1)
340+
self.assertAlmostEqual(3.2649279712439143, ci4)
341+
self.assertEqual(mean5, 0)
342+
self.assertAlmostEqual(ci5, 1.6015459273656616)
335343

336344
def test_add_arm(self):
337345

0 commit comments

Comments
 (0)