diff --git a/tests/test_active.py b/tests/test_active.py index 96c45da79..4640de951 100644 --- a/tests/test_active.py +++ b/tests/test_active.py @@ -82,17 +82,21 @@ def test_active_queries_generated(server, sampler, logs): # test is designed to make sure no unexpected errors are thrown in # active portion (not that it generates a good embedding) - n = 6 + n = 16 + config = { "targets": [_ for _ in range(n)], "samplers": {sampler: {}}, "sampling": {"common": {"d": 1, "R": 1}}, } + for _ in range(2): + server.reset() + sleep(2) with logs: server.authorize() server.post("/init_exp", data={"exp": config}) n_active_queries = 0 - for k in range(6 * n + 1): + for k in range(20 * n + 1): q = server.get("/query").json() ans = random.choice([q["left"], q["right"]]) @@ -108,18 +112,21 @@ def test_active_queries_generated(server, sampler, logs): sleep(1) break - c = 1 # for github actions + # if github_actions: + # sleep(1) + # else: ... if k % n == 0: - sleep(1 * c) - if k == n + 1: - sleep(2 * c) + sleep(1) - d = server.get("/responses").json() + r = server.get("/responses") + d = r.json() + _ = "foobar" df = pd.DataFrame(d) random_queries = df["score"] == -9999 active_queries = ~random_queries - assert active_queries.sum() and random_queries.sum() + assert active_queries.sum() + assert random_queries.sum() samplers = set(df.alg_ident.unique()) assert samplers == {sampler} diff --git a/tests/utils.py b/tests/utils.py index 71689cd87..8bf708db6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -36,7 +36,8 @@ def get(self, endpoint, status_code=200, **kwargs): logger.info(f"Getting {endpoint}") r = requests.get(self.url + endpoint, **kwargs) logger.info("done") - assert r.status_code == status_code, (r.status_code, status_code, r.text) + if status_code: + assert r.status_code == status_code, (r.status_code, status_code, r.text) return r def reset(self):