Skip to content

Commit

Permalink
Update cold arm logic
Browse files Browse the repository at this point in the history
Signed-off-by: Kilitcioglu, Doruk <doruk.kilitcioglu@fmr.com>
  • Loading branch information
dorukkilitcioglu committed Jan 26, 2023
1 parent 091d07b commit 93971bc
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 15 deletions.
11 changes: 6 additions & 5 deletions mabwiser/base_mab.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ def trained_arms(self) -> List[Arm]:
Arms for which at least one decision has been observed are deemed trained."""
return [arm for arm in self.arms if self.arm_to_status[arm][STATUS_TRAINED]]

@property
def cold_arms(self) -> List[Arm]:
return [arm for arm in self.arms if ((not self.arm_to_status[arm][STATUS_TRAINED]) and
(not self.arm_to_status[arm][STATUS_WARM]))]

def add_arm(self, arm: Arm, binarizer: Callable = None) -> NoReturn:
"""Introduces a new arm to the bandit.
Expand Down Expand Up @@ -375,14 +380,10 @@ def _get_cold_arm_to_warm_arm(self, arm_to_features, distance_quantile):
distance_from_to = self._get_pairwise_distances(arm_to_features)
distance_threshold = self._get_distance_threshold(distance_from_to, quantile=distance_quantile)

# Cold arms
cold_arms = [arm for arm in self.arms if ((arm not in self.trained_arms) and
(not self.arm_to_status[arm][STATUS_WARM]))]

# New cold arm to warm arm dictionary
new_cold_arm_to_warm_arm = dict()

for cold_arm in cold_arms:
for cold_arm in self.cold_arms:

# Collect distance from cold arm to warm arms
arm_to_distance = {}
Expand Down
20 changes: 12 additions & 8 deletions mabwiser/mab.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,6 @@ def __init__(self,
# Create the random number generator
self._rng = create_rng(self.seed)
self._is_initial_fit = False
self.cold_arms = set()

# Create the learning policy implementor
lp = None
Expand Down Expand Up @@ -970,6 +969,18 @@ def neighborhood_policy(self):
else:
return None

@property
def cold_arms(self) -> List[Arm]:
if not self.neighborhood_policy:
# No neighborhood policy, cold arms are calculated at the learning policy level
return self._imp.cold_arms

else:
# With neighborhood policies, we end up training and doing inference within the neighborhood.
# Each neighborhood can have a different set of trained arms, and if warm start is used,
# a different set of cold arms. Therefore, cold arms aren't defined for neighborhood policies.
return list()

def add_arm(self, arm: Arm, binarizer: Callable = None) -> NoReturn:
""" Adds an _arm_ to the list of arms.
Expand Down Expand Up @@ -1010,7 +1021,6 @@ def add_arm(self, arm: Arm, binarizer: Callable = None) -> NoReturn:
self._validate_arm(arm)
self.arms.append(arm)
self._imp.add_arm(arm, binarizer)
self._refresh_cold_arms()

def remove_arm(self, arm: Arm) -> NoReturn:
"""Removes an _arm_ from the list of arms.
Expand All @@ -1037,7 +1047,6 @@ def remove_arm(self, arm: Arm) -> NoReturn:
self._validate_arm(arm)
self.arms.remove(arm)
self._imp.remove_arm(arm)
self._refresh_cold_arms()

def fit(self,
decisions: Union[List[Arm], np.ndarray, pd.Series], # Decisions that are made
Expand Down Expand Up @@ -1096,7 +1105,6 @@ def fit(self,

# Turn initial to true
self._is_initial_fit = True
self._refresh_cold_arms()

def partial_fit(self,
decisions: Union[List[Arm], np.ndarray, pd.Series],
Expand Down Expand Up @@ -1154,9 +1162,6 @@ def partial_fit(self,
else:
self.fit(decisions, rewards, contexts)

# Refresh the list of cold arms
self._refresh_cold_arms()

def predict(self,
contexts: Union[None, List[Num], List[List[Num]],
np.ndarray, pd.Series, pd.DataFrame] = None # Contexts, optional
Expand Down Expand Up @@ -1271,7 +1276,6 @@ def warm_start(self, arm_to_features: Dict[Arm, List[Num]], distance_quantile: f
check_true(set(self.arms) == set(arm_to_features.keys()),
ValueError("The arms in arm features do not match arms."))
self._imp.warm_start(arm_to_features, distance_quantile)
self._refresh_cold_arms()

@staticmethod
def _validate_mab_args(arms, learning_policy, neighborhood_policy, seed, n_jobs, backend):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_mab.py
Original file line number Diff line number Diff line change
Expand Up @@ -1784,8 +1784,8 @@ def test_cold_arms(self):

# Before warm start
self.assertEqual(mab._imp.trained_arms, [1, 2])
self.assertSetEqual(mab.cold_arms, {3})
self.assertListEqual(mab.cold_arms, [3])

# Warm start
mab.warm_start(arm_to_features={1: [0, 1], 2: [0, 0], 3: [0.5, 0.5]}, distance_quantile=0.5)
self.assertSetEqual(mab.cold_arms, set())
self.assertListEqual(mab.cold_arms, list())

0 comments on commit 93971bc

Please sign in to comment.