Skip to content

Commit

Permalink
redis: Use a single redis connection per process (MarkUsProject#409)
Browse files Browse the repository at this point in the history
  • Loading branch information
mishaschwartz committed Nov 2, 2022
1 parent 7c1876d commit 1d93906
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 65 deletions.
71 changes: 27 additions & 44 deletions client/autotest_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
SETTINGS_JOB_TIMEOUT = os.environ.get("SETTINGS_JOB_TIMEOUT", 600)
REDIS_URL = os.environ["REDIS_URL"]

REDIS_CONNECTION = redis.Redis.from_url(REDIS_URL)

app = Flask(__name__)


Expand All @@ -35,23 +37,6 @@ def _open_log(log, mode="a", fallback=sys.stdout):
yield fallback


def _redis_connection(decode_responses=True) -> redis.Redis:
return redis.Redis.from_url(REDIS_URL, decode_responses=decode_responses)


def _rq_connection() -> redis.Redis:
"""
Return the currently open redis connection object. If there is no
connection currently open, one is created using the url specified in
REDIS_URL.
"""
conn = rq.get_current_connection()
if conn:
return conn
rq.use_connection(redis=redis.Redis.from_url(REDIS_URL))
return rq.get_current_connection()


@app.errorhandler(Exception)
def _handle_error(e):
code = 500
Expand All @@ -71,31 +56,29 @@ def _handle_error(e):


def _check_rate_limit(api_key):
conn = _redis_connection()
key = f"autotest:ratelimit:{api_key}:{datetime.now().minute}"
n_requests = conn.get(key) or 0
user_limit = conn.get(f"autotest:ratelimit:{api_key}:limit") or 20 # TODO: make default limit configurable
n_requests = REDIS_CONNECTION.get(key) or 0
user_limit = REDIS_CONNECTION.get(f"autotest:ratelimit:{api_key}:limit") or 20 # TODO: make default configurable
if int(n_requests) > int(user_limit):
abort(make_response(jsonify(message="Too many requests"), 429))
else:
with conn.pipeline() as pipe:
with REDIS_CONNECTION.pipeline() as pipe:
pipe.incr(key)
pipe.expire(key, 59)
pipe.execute()


def _authorize_user():
api_key = request.headers.get("Api-Key")
user_name = (_redis_connection().hgetall("autotest:user_credentials") or {}).get(api_key)
if user_name is None:
if api_key is None or (REDIS_CONNECTION.hgetall("autotest:user_credentials") or {}).get(api_key.encode()) is None:
abort(make_response(jsonify(message="Unauthorized"), 401))
_check_rate_limit(api_key)
return api_key


def _authorize_settings(user, settings_id=None, **_kw):
if settings_id:
settings_ = _redis_connection().hget("autotest:settings", settings_id)
settings_ = REDIS_CONNECTION.hget("autotest:settings", settings_id)
if settings_ is None:
abort(make_response(jsonify(message="Settings not found"), 404))
if json.loads(settings_).get("_user") != user:
Expand All @@ -104,10 +87,10 @@ def _authorize_settings(user, settings_id=None, **_kw):

def _authorize_tests(tests_id=None, settings_id=None, **_kw):
if settings_id and tests_id:
test_setting = _redis_connection().hget("autotest:tests", tests_id)
test_setting = REDIS_CONNECTION.hget("autotest:tests", tests_id)
if test_setting is None:
abort(make_response(jsonify(message="Test not found"), 404))
if test_setting != settings_id:
if int(test_setting) != int(settings_id):
abort(make_response(jsonify(message="Unauthorized"), 401))


Expand All @@ -125,7 +108,7 @@ def _update_settings(settings_id, user):
if error:
abort(make_response(jsonify(message=error), 422))

queue = rq.Queue("settings", connection=_rq_connection())
queue = rq.Queue("settings", connection=REDIS_CONNECTION)
data = {"user": user, "settings_id": settings_id, "test_settings": test_settings, "file_url": file_url}
queue.enqueue_call(
"autotest_server.update_test_settings",
Expand All @@ -137,12 +120,12 @@ def _update_settings(settings_id, user):

def _get_jobs(test_ids, settings_id):
for id_ in test_ids:
test_setting = _redis_connection().hget("autotest:tests", id_)
if test_setting is None or test_setting != settings_id:
test_setting = REDIS_CONNECTION.hget("autotest:tests", id_)
if test_setting is None or int(test_setting) != int(settings_id):
yield None
else:
try:
yield rq.job.Job.fetch(str(id_), connection=_rq_connection())
yield rq.job.Job.fetch(str(id_), connection=REDIS_CONNECTION)
except rq.exceptions.NoSuchJobError:
yield None

Expand Down Expand Up @@ -181,7 +164,7 @@ def register():
credentials = request.json.get("credentials")
key = base64.b64encode(os.urandom(24)).decode("utf-8")
data = {"auth_type": auth_type, "credentials": credentials}
while not _redis_connection().hsetnx("autotest:user_credentials", key=key, value=json.dumps(data)):
while not REDIS_CONNECTION.hsetnx("autotest:user_credentials", key=key, value=json.dumps(data)):
key = base64.b64encode(os.urandom(24)).decode("utf-8")
return {"api_key": key}

Expand All @@ -192,20 +175,20 @@ def reset_credentials(user):
auth_type = request.json.get("auth_type")
credentials = request.json.get("credentials")
data = {"auth_type": auth_type, "credentials": credentials}
_redis_connection().hset("autotest:user_credentials", key=user, value=json.dumps(data))
REDIS_CONNECTION.hset("autotest:user_credentials", key=user, value=json.dumps(data))
return jsonify(success=True)


@app.route("/schema", methods=["GET"])
@authorize
def schema(**_kwargs):
return json.loads(_redis_connection().get("autotest:schema") or "{}")
return json.loads(REDIS_CONNECTION.get("autotest:schema") or "{}")


@app.route("/settings/<settings_id>", methods=["GET"])
@authorize
def settings(settings_id, **_kw):
settings_ = json.loads(_redis_connection().hget("autotest:settings", key=settings_id) or "{}")
settings_ = json.loads(REDIS_CONNECTION.hget("autotest:settings", key=settings_id) or "{}")
if settings_.get("_error"):
raise Exception(f"Settings Error: {settings_['_error']}")
return {k: v for k, v in settings_.items() if not k.startswith("_")}
Expand All @@ -214,8 +197,8 @@ def settings(settings_id, **_kw):
@app.route("/settings", methods=["POST"])
@authorize
def create_settings(user):
settings_id = _redis_connection().incr("autotest:settings_id")
_redis_connection().hset("autotest:settings", key=settings_id, value=json.dumps({"_user": user}))
settings_id = REDIS_CONNECTION.incr("autotest:settings_id")
REDIS_CONNECTION.hset("autotest:settings", key=settings_id, value=json.dumps({"_user": user}))
_update_settings(settings_id, user)
return {"settings_id": settings_id}

Expand All @@ -234,7 +217,7 @@ def run_tests(settings_id, user):
categories = request.json["categories"]
high_priority = request.json.get("request_high_priority")
queue_name = "batch" if len(test_data) > 1 else ("high" if high_priority else "low")
queue = rq.Queue(queue_name, connection=_rq_connection())
queue = rq.Queue(queue_name, connection=REDIS_CONNECTION)

timeout = 0

Expand All @@ -246,8 +229,8 @@ def run_tests(settings_id, user):
for data in test_data:
url = data["file_url"]
test_env_vars = data.get("env_vars", {})
id_ = _redis_connection().incr("autotest:tests_id")
_redis_connection().hset("autotest:tests", key=id_, value=settings_id)
id_ = REDIS_CONNECTION.incr("autotest:tests_id")
REDIS_CONNECTION.hset("autotest:tests", key=id_, value=settings_id)
ids.append(id_)
data = {
"settings_id": settings_id,
Expand All @@ -272,30 +255,30 @@ def run_tests(settings_id, user):
@app.route("/settings/<settings_id>/test/<tests_id>", methods=["GET"])
@authorize
def get_result(settings_id, tests_id, **_kw):
job = rq.job.Job.fetch(tests_id, connection=_rq_connection())
job = rq.job.Job.fetch(tests_id, connection=REDIS_CONNECTION)
job_status = job.get_status()
result = {"status": job_status}
if job_status == "finished":
test_result = _redis_connection().get(f"autotest:test_result:{tests_id}")
test_result = REDIS_CONNECTION.get(f"autotest:test_result:{tests_id}")
try:
result.update(json.loads(test_result))
except json.JSONDecodeError:
result.update({"error": f"invalid json: {test_result}"})
elif job_status == "failed":
result.update({"error": str(job.exc_info)})
job.delete()
_redis_connection().delete(f"autotest:test_result:{tests_id}")
REDIS_CONNECTION.delete(f"autotest:test_result:{tests_id}")
return result


@app.route("/settings/<settings_id>/test/<tests_id>/feedback/<feedback_id>", methods=["GET"])
@authorize
def get_feedback_file(settings_id, tests_id, feedback_id, **_kw):
key = f"autotest:feedback_file:{tests_id}:{feedback_id}"
data = _redis_connection(decode_responses=False).get(key)
data = REDIS_CONNECTION.get(key)
if data is None:
abort(make_response(jsonify(message="File doesn't exist"), 404))
_redis_connection().delete(key)
REDIS_CONNECTION.delete(key)
return send_file(io.BytesIO(data), mimetype="application/gzip", as_attachment=True, download_name=str(feedback_id))


Expand Down
10 changes: 2 additions & 8 deletions client/autotest_client/tests/test_flask_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,12 @@ def client():

@pytest.fixture
def fake_redis_conn():
yield fakeredis.FakeStrictRedis(decode_responses=True)


@pytest.fixture
def fake_rq_conn():
conn = fakeredis.FakeStrictRedis(decode_responses=False)
autotest_client.rq.use_connection(conn)
yield fakeredis.FakeStrictRedis()


@pytest.fixture(autouse=True)
def fake_redis_db(monkeypatch, fake_redis_conn):
monkeypatch.setattr(autotest_client.redis.Redis, "from_url", lambda *a, **kw: fake_redis_conn)
monkeypatch.setattr(autotest_client, "REDIS_CONNECTION", fake_redis_conn)


class TestRegister:
Expand Down
4 changes: 2 additions & 2 deletions server/autotest_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@
import importlib
import psycopg2
import mimetypes
import rq
from typing import Optional, Dict, Union, List, Tuple, Callable, Type
from types import TracebackType

from .config import config
from .utils import loads_partial_json, set_rlimits_before_test, extract_zip_stream, recursive_iglob, copy_tree

DEFAULT_ENV_DIR = "defaultvenv"
REDIS_URL = config["redis_url"]
TEST_SCRIPT_DIR = os.path.join(config["workspace"], "scripts")

ResultData = Dict[str, Union[str, int, type(None), Dict]]


def redis_connection() -> redis.Redis:
return redis.Redis.from_url(REDIS_URL, decode_responses=True)
return rq.get_current_job().connection


def run_test_command(test_username: Optional[str] = None) -> str:
Expand Down
17 changes: 14 additions & 3 deletions server/autotest_server/tests/test_autotest_server.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
import pytest
import fakeredis
import rq
import autotest_server


@pytest.fixture
def fake_redis_conn():
yield fakeredis.FakeStrictRedis(decode_responses=True)
yield fakeredis.FakeStrictRedis()


@pytest.fixture
def fake_queue(fake_redis_conn):
yield rq.Queue(is_async=False, connection=fake_redis_conn)


@pytest.fixture
def fake_job(fake_queue):
yield fake_queue.enqueue(lambda: None)


@pytest.fixture(autouse=True)
def fake_redis_db(monkeypatch, fake_redis_conn):
monkeypatch.setattr(autotest_server.redis.Redis, "from_url", lambda *a, **kw: fake_redis_conn)
def fake_redis_db(monkeypatch, fake_job):
monkeypatch.setattr(autotest_server.rq, "get_current_job", lambda *a, **kw: fake_job)


def test_redis_connection(fake_redis_conn):
Expand Down
9 changes: 6 additions & 3 deletions server/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
import json
import subprocess
import getpass
import redis
from autotest_server.config import config
from autotest_server import redis_connection, run_test_command
from autotest_server import run_test_command
from autotest_server.testers import install as install_testers

REDIS_CONNECTION = redis.Redis.from_url(config["redis_url"])


def _print(*args, **kwargs):
print("[AUTOTESTER]", *args, **kwargs)
Expand All @@ -19,7 +22,7 @@ def _print(*args, **kwargs):
def check_dependencies():
_print("checking if redis url is valid:")
try:
redis_connection().keys()
REDIS_CONNECTION.ping()
except Exception as e:
raise Exception(f'Cannot connect to redis database with url: {config["redis_url"]}') from e
for w in config["workers"]:
Expand Down Expand Up @@ -66,7 +69,7 @@ def install_all_testers():
skeleton = json.load(f)
skeleton["definitions"]["installed_testers"]["enum"] = list(settings.keys())
skeleton["definitions"]["tester_schemas"]["oneOf"] = list(settings.values())
redis_connection().set("autotest:schema", json.dumps(skeleton))
REDIS_CONNECTION.set("autotest:schema", json.dumps(skeleton))


def install():
Expand Down
8 changes: 3 additions & 5 deletions server/start_stop.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@
"""


def redis_connection() -> redis.Redis:
return redis.Redis.from_url(config["redis_url"], decode_responses=True)
REDIS_CONNECTION = redis.Redis.from_url(config["redis_url"], decode_responses=True)


def create_enqueuer_wrapper(rq):
Expand Down Expand Up @@ -82,7 +80,7 @@ def stat(rq, extra_args):


def clean(age, dry_run):
for settings_id, settings in dict(redis_connection().hgetall("autotest:settings") or {}).items():
for settings_id, settings in dict(REDIS_CONNECTION.hgetall("autotest:settings") or {}).items():
settings = json.loads(settings)
last_access_timestamp = settings.get("_last_access")
access = int(time.time() - (last_access_timestamp or 0))
Expand All @@ -93,7 +91,7 @@ def clean(age, dry_run):
print(f"{dir_path} -> last accessed {last_access or '< 1'} days ago")
else:
settings["_error"] = "the settings for this test have expired, please re-upload the settings."
redis_connection().hset("autotest:settings", key=settings_id, value=json.dumps(settings))
REDIS_CONNECTION.hset("autotest:settings", key=settings_id, value=json.dumps(settings))
if os.path.isdir(dir_path):
shutil.rmtree(dir_path)

Expand Down

0 comments on commit 1d93906

Please sign in to comment.