Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store job information #33

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions backend/src/predicTCR_server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
from flask_jwt_extended import JWTManager
from flask_cors import cross_origin
from predicTCR_server.logger import get_logger
from predicTCR_server.utils import timestamp_now
from predicTCR_server.model import (
db,
Sample,
User,
Job,
Status,
Settings,
add_new_user,
add_new_runner_user,
Expand Down Expand Up @@ -281,6 +284,16 @@ def admin_users():
)
return jsonify(users=[user.as_dict() for user in users])

@app.route("/api/admin/jobs", methods=["GET"])
@jwt_required()
def admin_jobs():
if not current_user.is_admin:
return jsonify(message="Admin account required"), 400
jobs = (
db.session.execute(db.select(Job).order_by(db.desc(Job.id))).scalars().all()
)
return jsonify(jobs=[job.as_dict() for job in jobs])

@app.route("/api/admin/runner_token", methods=["GET"])
@jwt_required()
def admin_runner_token():
Expand All @@ -305,7 +318,17 @@ def runner_request_job():
sample_id = request_job()
if sample_id is None:
return jsonify(message="No job available"), 204
return {"sample_id": sample_id}
new_job = Job(
id=None,
sample_id=sample_id,
timestamp_start=timestamp_now(),
timestamp_end=0,
status=Status.RUNNING,
error_message="",
)
db.session.add(new_job)
db.session.commit()
return {"job_id": new_job.id, "sample_id": sample_id}

@app.route("/api/runner/result", methods=["POST"])
@cross_origin()
Expand All @@ -317,6 +340,9 @@ def runner_result():
sample_id = form_as_dict.get("sample_id", None)
if sample_id is None:
return jsonify(message="Missing key: sample_id"), 400
job_id = form_as_dict.get("job_id", None)
if job_id is None:
return jsonify(message="Missing key: job_id"), 400
success = form_as_dict.get("success", None)
if success is None or success.lower() not in ["true", "false"]:
logger.info(" -> missing success key")
Expand All @@ -328,19 +354,22 @@ def runner_result():
return jsonify(message="Result has success=True but no file"), 400
runner_hostname = form_as_dict.get("runner_hostname", "")
logger.info(
f"Result upload for '{sample_id}' from runner {current_user.email} / {runner_hostname}"
f"Job '{job_id}' uploaded result for '{sample_id}' from runner {current_user.email} / {runner_hostname}"
)
error_message = form_as_dict.get("error_message", None)
if error_message is not None:
error_message = form_as_dict.get("error_message", "")
if error_message != "":
logger.info(f" -> error message: {error_message}")
message, code = process_result(sample_id, success, zipfile)
message, code = process_result(
int(job_id), int(sample_id), success, error_message, zipfile
)
return jsonify(message=message), code

with app.app_context():
db.create_all()
if db.session.get(Settings, 1) is None:
db.session.add(
Settings(
id=None,
default_personal_submission_quota=10,
default_personal_submission_interval_mins=30,
global_quota=1000,
Expand Down
109 changes: 70 additions & 39 deletions backend/src/predicTCR_server/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import re
import flask
from enum import Enum
import enum
import argon2
import pathlib
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass, Mapped, mapped_column
from werkzeug.datastructures import FileStorage
from sqlalchemy.inspection import inspect
from sqlalchemy import Integer, String, Boolean, Enum
from dataclasses import dataclass
from predicTCR_server.email import send_email
from predicTCR_server.settings import predicTCR_url
Expand All @@ -20,12 +22,17 @@
decode_password_reset_token,
)

db = SQLAlchemy()

class Base(DeclarativeBase, MappedAsDataclass):
pass


db = SQLAlchemy(model_class=Base)
ph = argon2.PasswordHasher()
logger = get_logger()


class Status(str, Enum):
class Status(str, enum.Enum):
QUEUED = "queued"
RUNNING = "running"
COMPLETED = "completed"
Expand All @@ -34,34 +41,45 @@ class Status(str, Enum):

@dataclass
class Settings(db.Model):
id: int = db.Column(db.Integer, primary_key=True)
default_personal_submission_quota: int = db.Column(db.Integer, nullable=False)
default_personal_submission_interval_mins: int = db.Column(
db.Integer, nullable=False
id: Mapped[int] = mapped_column(Integer, primary_key=True)
default_personal_submission_quota: Mapped[int] = mapped_column(
Integer, nullable=False
)
default_personal_submission_interval_mins: Mapped[int] = mapped_column(
Integer, nullable=False
)
global_quota: int = db.Column(db.Integer, nullable=False)
tumor_types: str = db.Column(db.String, nullable=False)
sources: str = db.Column(db.String, nullable=False)
csv_required_columns: str = db.Column(db.String, nullable=False)
global_quota: Mapped[int] = mapped_column(Integer, nullable=False)
tumor_types: Mapped[str] = mapped_column(String, nullable=False)
sources: Mapped[str] = mapped_column(String, nullable=False)
csv_required_columns: Mapped[str] = mapped_column(String, nullable=False)

def as_dict(self):
return {
c: getattr(self, c)
for c in inspect(self).attrs.keys()
if c != "password_hash"
}
return {c: getattr(self, c) for c in inspect(self).attrs.keys()}


@dataclass
class Job(db.Model):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
sample_id: Mapped[int] = mapped_column(Integer, nullable=False)
timestamp_start: Mapped[int] = mapped_column(Integer, nullable=False)
timestamp_end: Mapped[int] = mapped_column(Integer, nullable=False)
status: Mapped[Status] = mapped_column(Enum(Status), nullable=False)
error_message: Mapped[str] = mapped_column(String, nullable=False)

def as_dict(self):
return {c: getattr(self, c) for c in inspect(self).attrs.keys()}


@dataclass
class Sample(db.Model):
id: int = db.Column(db.Integer, primary_key=True)
email: str = db.Column(db.String(256), nullable=False)
name: str = db.Column(db.String(128), nullable=False)
tumor_type: str = db.Column(db.String(128), nullable=False)
source: str = db.Column(db.String(128), nullable=False)
timestamp: int = db.Column(db.Integer, nullable=False)
status: Status = db.Column(db.Enum(Status), nullable=False)
has_results_zip: bool = db.Column(db.Boolean, nullable=False)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
email: Mapped[str] = mapped_column(String(256), nullable=False)
name: Mapped[str] = mapped_column(String(128), nullable=False)
tumor_type: Mapped[str] = mapped_column(String(128), nullable=False)
source: Mapped[str] = mapped_column(String(128), nullable=False)
timestamp: Mapped[int] = mapped_column(Integer, nullable=False)
status: Mapped[Status] = mapped_column(Enum(Status), nullable=False)
has_results_zip: Mapped[bool] = mapped_column(Boolean, nullable=False)

def _base_path(self) -> pathlib.Path:
data_path = flask.current_app.config["PREDICTCR_DATA_PATH"]
Expand All @@ -79,17 +97,17 @@ def result_file_path(self) -> pathlib.Path:

@dataclass
class User(db.Model):
id: int = db.Column(db.Integer, primary_key=True)
email: str = db.Column(db.Text, nullable=False, unique=True)
password_hash: str = db.Column(db.Text, nullable=False)
activated: bool = db.Column(db.Boolean, nullable=False)
enabled: bool = db.Column(db.Boolean, nullable=False)
quota: int = db.Column(db.Integer, nullable=False)
submission_interval_minutes: int = db.Column(db.Integer, nullable=False)
last_submission_timestamp: int = db.Column(db.Integer, nullable=False)
is_admin: bool = db.Column(db.Boolean, nullable=False)
is_runner: bool = db.Column(db.Boolean, nullable=False)
full_results: bool = db.Column(db.Boolean, nullable=False)
id: int = mapped_column(Integer, primary_key=True)
email: str = mapped_column(String, nullable=False, unique=True)
password_hash: str = mapped_column(String, nullable=False)
activated: bool = mapped_column(Boolean, nullable=False)
enabled: bool = mapped_column(Boolean, nullable=False)
quota: int = mapped_column(Integer, nullable=False)
submission_interval_minutes: int = mapped_column(Integer, nullable=False)
last_submission_timestamp: int = mapped_column(Integer, nullable=False)
is_admin: bool = mapped_column(Boolean, nullable=False)
is_runner: bool = mapped_column(Boolean, nullable=False)
full_results: bool = mapped_column(Boolean, nullable=False)

def set_password_nocheck(self, new_password: str):
self.password_hash = ph.hash(new_password)
Expand Down Expand Up @@ -145,17 +163,26 @@ def request_job() -> int | None:


def process_result(
sample_id: str, success: bool, result_zip_file: FileStorage | None
job_id: int,
sample_id: int,
success: bool,
error_message: str,
result_zip_file: FileStorage | None,
) -> tuple[str, int]:
sample = db.session.execute(
db.select(Sample).filter_by(id=sample_id)
).scalar_one_or_none()
sample = db.session.get(Sample, sample_id)
if sample is None:
logger.warning(f" --> Unknown sample id {sample_id}")
return f"Unknown sample id {sample_id}", 400
job = db.session.get(Job, job_id)
if job is None:
logger.warning(f" --> Unknown job id {job_id}")
return f"Unknown job id {job_id}", 400
job.timestamp_end = timestamp_now()
if success is False:
sample.has_results_zip = False
sample.status = Status.FAILED
job.status = Status.FAILED
job.error_message = error_message
db.session.commit()
return "Result processed", 200
if result_zip_file is None:
Expand All @@ -165,6 +192,7 @@ def process_result(
result_zip_file.save(sample.result_file_path())
sample.has_results_zip = True
sample.status = Status.COMPLETED
job.status = Status.COMPLETED
db.session.commit()
return "Result processed", 200

Expand Down Expand Up @@ -244,6 +272,7 @@ def add_new_user(email: str, password: str, is_admin: bool) -> tuple[str, int]:
try:
db.session.add(
User(
id=None,
email=email,
password_hash=ph.hash(password),
activated=False,
Expand Down Expand Up @@ -282,6 +311,7 @@ def add_new_runner_user() -> User | None:
runner_name = f"runner{runner_number}"
db.session.add(
User(
id=None,
email=runner_name,
password_hash="",
activated=False,
Expand Down Expand Up @@ -419,6 +449,7 @@ def add_new_sample(
settings = db.session.get(Settings, 1)
settings.global_quota -= 1
new_sample = Sample(
id=None,
email=email,
name=name,
tumor_type=tumor_type,
Expand Down
2 changes: 2 additions & 0 deletions backend/tests/helpers/flask_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def add_test_users(app):
email = f"{name}@abc.xy"
db.session.add(
User(
id=None,
email=email,
password_hash=ph.hash(name),
activated=True,
Expand Down Expand Up @@ -46,6 +47,7 @@ def add_test_samples(app, data_path: pathlib.Path):
with open(f"{ref_dir}/input.{input_file_type}", "w") as f:
f.write(input_file_type)
new_sample = Sample(
id=None,
email="user@abc.xy",
name=name,
tumor_type=f"tumor_type{sample_id}",
Expand Down
59 changes: 46 additions & 13 deletions backend/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,13 @@ def test_result_invalid(client):
assert "No results available" in response.json["message"]


def _upload_result(client, result_zipfile: pathlib.Path, sample_id: int):
def _upload_result(client, result_zipfile: pathlib.Path, job_id: int, sample_id: int):
headers = _get_auth_headers(client, "runner@abc.xy", "runner")
with open(result_zipfile, "rb") as f:
response = client.post(
"/api/runner/result",
data={
"job_id": job_id,
"sample_id": sample_id,
"success": True,
"file": (io.BytesIO(f.read()), result_zipfile.name),
Expand All @@ -225,19 +226,57 @@ def _upload_result(client, result_zipfile: pathlib.Path, sample_id: int):
return response


def test_result_valid(client, result_zipfile):
headers = _get_auth_headers(client, "user@abc.xy", "user")
sample_id = 1
assert _upload_result(client, result_zipfile, sample_id).status_code == 200
def test_runner_valid_success(client, result_zipfile):
headers = _get_auth_headers(client, "runner@abc.xy", "runner")
# request job
request_job_response = client.post(
"/api/runner/request_job",
json={"runner_hostname": "me"},
headers=headers,
)
assert request_job_response.status_code == 200
assert request_job_response.json == {"sample_id": 1, "job_id": 1}
# upload successful result
assert _upload_result(client, result_zipfile, 1, 1).status_code == 200
response = client.post(
"/api/result",
json={"sample_id": sample_id},
headers=headers,
json={"sample_id": 1},
headers=_get_auth_headers(client, "user@abc.xy", "user"),
)
assert response.status_code == 200
assert len(response.data) > 1


def test_runner_valid_failure(client, result_zipfile):
headers = _get_auth_headers(client, "runner@abc.xy", "runner")
# request job
request_job_response = client.post(
"/api/runner/request_job",
json={"runner_hostname": "me"},
headers=headers,
)
assert request_job_response.status_code == 200
assert request_job_response.json == {"sample_id": 1, "job_id": 1}
# upload failure result
result_response = client.post(
"/api/runner/result",
data={
"job_id": 1,
"sample_id": 1,
"success": False,
"error_message": "Something went wrong",
},
headers=headers,
)
assert result_response.status_code == 200
response = client.post(
"/api/result",
json={"sample_id": 1},
headers=_get_auth_headers(client, "user@abc.xy", "user"),
)
assert response.status_code == 400


def test_admin_samples_valid(client):
headers = _get_auth_headers(client, "admin@abc.xy", "admin")
response = client.get("/api/admin/samples", headers=headers)
Expand Down Expand Up @@ -288,12 +327,6 @@ def test_admin_users_valid(client):
assert "users" in response.json


def test_runner_result_valid(client, result_zipfile):
response = _upload_result(client, result_zipfile, 1)
assert response.status_code == 200
assert "result processed" in response.json["message"].lower()


def test_admin_update_user_valid(client):
headers = _get_auth_headers(client, "admin@abc.xy", "admin")
user = client.get("/api/admin/users", headers=headers).json["users"][0]
Expand Down
Loading
Loading