Skip to content

Commit

Permalink
Rename UserRoundRobin -> RoundRobin
Browse files Browse the repository at this point in the history
  • Loading branch information
stsievert committed Mar 14, 2022
1 parent b358e71 commit 070891e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion salmon/triplets/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._adaptive_runners import ARR, CKL, GNMDS, SOE, SRR, STE, TSTE, Adaptive
from ._random_sampling import Random
from ._round_robin import RoundRobin, UserRoundRobin
from ._round_robin import RoundRobin
from ._test_runner import Test
from ._validation import Validation
12 changes: 6 additions & 6 deletions salmon/triplets/samplers/_round_robin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _score_query(q: Tuple[int, int, int]) -> float:
return float(score)


class RoundRobin(Sampler):
class _RoundRobin(Sampler):
"""
Let the head of the triplet query rotate through the available items while choosing the bottom two items randomly.
"""
Expand Down Expand Up @@ -89,10 +89,11 @@ def run(self, *args, **kwargs):
return None


class UserRoundRobin(RoundRobin):
class RoundRobin(_RoundRobin):
"""
Rotate through "heads" in each query (just like
:class:`~salmon.triplets.samplers.RoundRobin`) for each user.
Let the head of the triplet query rotate through the available items while choosing
the bottom two items randomly. This class is user specific if the
``/query?puid=foo`` endpoint is hit.
"""
def __init__(self, *args, **kwargs):
self.rr_args = args
Expand All @@ -102,9 +103,8 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def get_query(self, puid: str = "") -> Tuple[Query, float]:
logger.warning(f"puid = {puid}")
if puid not in self.samplers:
self.samplers[puid] = RoundRobin(*self.rr_args, **self.rr_kwargs)
self.samplers[puid] = _RoundRobin(*self.rr_args, **self.rr_kwargs)
return self.samplers[puid].get_query()

def process_answers(self, ans: List[Answer]):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_passive.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_round_robin(server, logs):
def test_round_robin_per_user(server):
N = 5
R = 2
config = {"targets": N, "samplers": {"UserRoundRobin": {}}}
config = {"targets": N, "samplers": {"RoundRobin": {}}}
server.authorize()
server.post("/init_exp", data={"exp": config})

Expand Down

0 comments on commit 070891e

Please sign in to comment.