Skip to content

Commit

Permalink
Implement scheduled checks #7093 (#7271)
Browse files Browse the repository at this point in the history
* Implement scheduled checks #7093

- Rename `run_backfill` to `run_evaluation` in admin malware view
- Modify `run` and `scan` method signatures to accept `**kwargs`
- Extend `run_check` to accommodate scheduled check functionality

* Reduce unit test flakiness

* Code review changes.

Also replace `check.hooked_object` with `check.hooked_object.value` in
check detail template.

* tests, warehouse: enum fixes

* Fix lint error

Co-authored-by: William Woodruff <william@yossarian.net>
  • Loading branch information
2 people authored and ewdurbin committed Feb 11, 2020
1 parent f2b93df commit d75af1d
Show file tree
Hide file tree
Showing 12 changed files with 197 additions and 67 deletions.
4 changes: 2 additions & 2 deletions tests/common/checks/scheduled.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ class ExampleScheduledCheck(MalwareCheckBase):
def __init__(self, db):
super().__init__(db)

def scan(self):
def scan(self, **kwargs):
project = self.db.query(Project).first()
self.add_verdict(
project_id=project.id,
classification=VerdictClassification.benign,
classification=VerdictClassification.Benign,
confidence=VerdictConfidence.High,
message="Nothing to see here!",
)
4 changes: 2 additions & 2 deletions tests/unit/admin/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def test_includeme():
domain=warehouse,
),
pretend.call(
"admin.checks.run_backfill",
"/admin/checks/{check_name}/run_backfill",
"admin.checks.run_evaluation",
"/admin/checks/{check_name}/run_evaluation",
domain=warehouse,
),
pretend.call("admin.verdicts.list", "/admin/verdicts/", domain=warehouse),
Expand Down
59 changes: 40 additions & 19 deletions tests/unit/admin/views/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from pyramid.httpexceptions import HTTPNotFound

from warehouse.admin.views import checks as views
from warehouse.malware.models import MalwareCheckState
from warehouse.malware.models import MalwareCheckState, MalwareCheckType
from warehouse.malware.tasks import backfill, run_check

from ....common.db.malware import MalwareCheckFactory

Expand Down Expand Up @@ -46,6 +47,7 @@ def test_get_check(self, db_request):
"check": check,
"checks": [check],
"states": MalwareCheckState,
"evaluation_run_size": 10000,
}

def test_get_check_many_versions(self, db_request):
Expand All @@ -56,6 +58,7 @@ def test_get_check_many_versions(self, db_request):
"check": check2,
"checks": [check2, check1],
"states": MalwareCheckState,
"evaluation_run_size": 10000,
}

def test_get_check_not_found(self, db_request):
Expand Down Expand Up @@ -129,17 +132,17 @@ def test_change_to_invalid_state(self, db_request):
assert check.state == initial_state


class TestRunBackfill:
class TestRunEvaluation:
@pytest.mark.parametrize(
("check_state", "message"),
[
(
MalwareCheckState.Disabled,
"Check must be in 'enabled' or 'evaluation' state to run a backfill.",
"Check must be in 'enabled' or 'evaluation' state to manually execute.",
),
(
MalwareCheckState.WipedOut,
"Check must be in 'enabled' or 'evaluation' state to run a backfill.",
"Check must be in 'enabled' or 'evaluation' state to manually execute.",
),
],
)
Expand All @@ -152,23 +155,29 @@ def test_invalid_backfill_parameters(self, db_request, check_state, message):
)

db_request.route_path = pretend.call_recorder(
lambda *a, **kw: "/admin/checks/%s/run_backfill" % check.name
lambda *a, **kw: "/admin/checks/%s/run_evaluation" % check.name
)

views.run_backfill(db_request)
views.run_evaluation(db_request)

assert db_request.session.flash.calls == [pretend.call(message, queue="error")]

def test_sucess(self, db_request):
check = MalwareCheckFactory.create(state=MalwareCheckState.Enabled)
@pytest.mark.parametrize(
("check_type"), [MalwareCheckType.EventHook, MalwareCheckType.Scheduled]
)
def test_success(self, db_request, check_type):

check = MalwareCheckFactory.create(
check_type=check_type, state=MalwareCheckState.Enabled
)
db_request.matchdict["check_name"] = check.name

db_request.session = pretend.stub(
flash=pretend.call_recorder(lambda *a, **kw: None)
)

db_request.route_path = pretend.call_recorder(
lambda *a, **kw: "/admin/checks/%s/run_backfill" % check.name
lambda *a, **kw: "/admin/checks/%s/run_evaluation" % check.name
)

backfill_recorder = pretend.stub(
Expand All @@ -177,13 +186,25 @@ def test_sucess(self, db_request):

db_request.task = pretend.call_recorder(lambda *a, **kw: backfill_recorder)

views.run_backfill(db_request)

assert db_request.session.flash.calls == [
pretend.call(
"Running %s on 10000 %ss!" % (check.name, check.hooked_object.value),
queue="success",
)
]

assert backfill_recorder.delay.calls == [pretend.call(check.name, 10000)]
views.run_evaluation(db_request)

if check_type == MalwareCheckType.EventHook:
assert db_request.session.flash.calls == [
pretend.call(
"Running %s on 10000 %ss!"
% (check.name, check.hooked_object.value),
queue="success",
)
]
assert db_request.task.calls == [pretend.call(backfill)]
assert backfill_recorder.delay.calls == [pretend.call(check.name, 10000)]
elif check_type == MalwareCheckType.Scheduled:
assert db_request.session.flash.calls == [
pretend.call("Running %s now!" % check.name, queue="success",)
]
assert db_request.task.calls == [pretend.call(run_check)]
assert backfill_recorder.delay.calls == [
pretend.call(check.name, manually_triggered=True)
]
else:
raise Exception("Invalid check type: %s" % check_type)
10 changes: 10 additions & 0 deletions tests/unit/malware/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@

import pretend

from celery.schedules import crontab

from warehouse import malware
from warehouse.malware import utils
from warehouse.malware.interfaces import IMalwareCheckService
from warehouse.malware.tasks import run_check

from ...common import checks as test_checks
from ...common.db.accounts import UserFactory
Expand Down Expand Up @@ -165,10 +168,17 @@ def test_includeme(monkeypatch):
registry=pretend.stub(
settings={"malware_check.backend": "TestMalwareCheckService"}
),
add_periodic_task=pretend.call_recorder(lambda *a, **kw: None),
)

malware.includeme(config)

assert config.register_service_factory.calls == [
pretend.call(malware_check_class.create_service, IMalwareCheckService)
]

assert config.add_periodic_task.calls == [
pretend.call(
crontab(minute="0", hour="*/8"), run_check, args=("ExampleScheduledCheck",)
)
]
71 changes: 54 additions & 17 deletions tests/unit/malware/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
import pretend
import pytest

from sqlalchemy.orm.exc import NoResultFound

from warehouse.malware import tasks
from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict

Expand All @@ -34,45 +32,86 @@ def test_success(self, db_request, monkeypatch):
name="ExampleHookedCheck", state=MalwareCheckState.Enabled
)
task = pretend.stub()
tasks.run_check(task, db_request, "ExampleHookedCheck", file0.id)
tasks.run_check(task, db_request, "ExampleHookedCheck", obj_id=file0.id)

assert db_request.route_url.calls == [
pretend.call("packaging.file", path=file0.path)
]
assert db_request.db.query(MalwareVerdict).one()

def test_disabled_check(self, db_request, monkeypatch):
@pytest.mark.parametrize(("manually_triggered"), [True, False])
def test_evaluation_run(self, db_session, monkeypatch, manually_triggered):
monkeypatch.setattr(tasks, "checks", test_checks)
MalwareCheckFactory.create(
name="ExampleScheduledCheck", state=MalwareCheckState.Evaluation
)
ProjectFactory.create()
task = pretend.stub()

request = pretend.stub(
db=db_session,
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)),
)

tasks.run_check(
task,
request,
"ExampleScheduledCheck",
manually_triggered=manually_triggered,
)

if manually_triggered:
assert db_session.query(MalwareVerdict).one()
else:
assert request.log.info.calls == [
pretend.call(
"ExampleScheduledCheck is in the `evaluation` state and must be \
manually triggered to run."
)
]
assert db_session.query(MalwareVerdict).all() == []

def test_disabled_check(self, db_session, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)
MalwareCheckFactory.create(
name="ExampleHookedCheck", state=MalwareCheckState.Disabled
)
task = pretend.stub()
request = pretend.stub(
db=db_session,
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)),
)

file = FileFactory.create()

with pytest.raises(NoResultFound):
tasks.run_check(task, db_request, "ExampleHookedCheck", file.id)
tasks.run_check(
task, request, "ExampleHookedCheck", obj_id=file.id,
)

assert request.log.info.calls == [
pretend.call("Check ExampleHookedCheck isn't active. Aborting.")
]

def test_missing_check(self, db_request, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)
task = pretend.stub()

file = FileFactory.create()

with pytest.raises(AttributeError):
tasks.run_check(task, db_request, "DoesNotExistCheck", file.id)
tasks.run_check(
task, db_request, "DoesNotExistCheck",
)

def test_retry(self, db_session, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)
exc = Exception("Scan failed")

def scan(self, **kwargs):
raise exc

monkeypatch.setattr(tasks, "checks", test_checks)
monkeypatch.setattr(tasks.checks.ExampleHookedCheck, "scan", scan)

MalwareCheckFactory.create(
name="ExampleHookedCheck", state=MalwareCheckState.Evaluation
name="ExampleHookedCheck", state=MalwareCheckState.Enabled
)

task = pretend.stub(
Expand All @@ -87,7 +126,7 @@ def scan(self, **kwargs):
file = FileFactory.create()

with pytest.raises(celery.exceptions.Retry):
tasks.run_check(task, request, "ExampleHookedCheck", file.id)
tasks.run_check(task, request, "ExampleHookedCheck", obj_id=file.id)

assert request.log.error.calls == [
pretend.call("Error executing check ExampleHookedCheck: Scan failed")
Expand All @@ -108,9 +147,8 @@ def test_invalid_check_name(self, db_request, monkeypatch):
)
def test_run(self, db_session, num_objects, num_runs, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)
files = []
for i in range(num_objects):
files.append(FileFactory.create())
FileFactory.create()

MalwareCheckFactory.create(
name="ExampleHookedCheck", state=MalwareCheckState.Enabled
Expand All @@ -133,15 +171,14 @@ def test_run(self, db_session, num_objects, num_runs, monkeypatch):
pretend.call("Running backfill on %d Files." % num_runs)
]

assert enqueue_recorder.delay.calls == [
pretend.call("ExampleHookedCheck", files[i].id) for i in range(num_runs)
]
assert len(enqueue_recorder.delay.calls) == num_runs


class TestSyncChecks:
def test_no_updates(self, db_session, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)
monkeypatch.setattr(tasks.checks.ExampleScheduledCheck, "version", 2)

MalwareCheckFactory.create(
name="ExampleHookedCheck", state=MalwareCheckState.Disabled
)
Expand Down
4 changes: 2 additions & 2 deletions warehouse/admin/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def includeme(config):
domain=warehouse,
)
config.add_route(
"admin.checks.run_backfill",
"/admin/checks/{check_name}/run_backfill",
"admin.checks.run_evaluation",
"/admin/checks/{check_name}/run_evaluation",
domain=warehouse,
)
config.add_route("admin.verdicts.list", "/admin/verdicts/", domain=warehouse)
Expand Down
18 changes: 16 additions & 2 deletions warehouse/admin/templates/admin/malware/checks/detail.html
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,22 @@ <h4>Revision History</h4>
<tr>
<th>Version</th>
<th>State</th>
{% if check.check_type.value == "event_hook" %}
<th>Hooked Object</th>
{% else %}
<th>Schedule</th>
{% endif %}
<th>Created</th>
</tr>
{% for c in checks %}
<tr>
<td>{{ c.version }}</td>
<td>{{ c.state.value }}</td>
{% if check.check_type.value == "event_hook" %}
<td>{{ c.hooked_object.value }}</td>
{% else %}
<td><pre>{{ c.schedule }}</pre></td>
{% endif %}
<td>{{ c.created }}</td>
</tr>
{% endfor %}
Expand Down Expand Up @@ -69,10 +79,14 @@ <h3 class="box-title">Change State</h3>
<div class="box-header with-border">
<h3 class="box-title">Run Evaluation</h3>
</div>
<form method="POST" action="{{ request.route_path('admin.checks.run_backfill', check_name=check.name) }}">
<form method="POST" action="{{ request.route_path('admin.checks.run_evaluation', check_name=check.name) }}">
<input name="csrf_token" type="hidden" value="{{ request.session.get_csrf_token() }}">
<div class="box-body">
<p>Run this check against 10,000 {{ check.hooked_object.value }}s, selected at random. This is used to evaluate the efficacy of a check.</p>
{% if check.check_type.value == "event_hook" %}
<p>Run this check against {{ evaluation_run_size }} {{ check.hooked_object.value }}s, selected at random. This is used to evaluate the efficacy of a check.</p>
{% else %}
<p>Execute this check now.</p>
{% endif %}
<div class="pull-right col-sm-4">
<button type="submit" class="btn btn-primary pull-right">Run</button>
</div>
Expand Down
4 changes: 2 additions & 2 deletions warehouse/admin/templates/admin/malware/checks/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
<tr>
<th>Check Name</th>
<th>State</th>
<th>Revisions</th>
<th>Type</th>
<th>Last Modified</th>
<th>Description</th>
</tr>
Expand All @@ -38,7 +38,7 @@
</a>
</td>
<td>{{ check.state.value }}</td>
<td>{{ check.version }}</td>
<td>{{ check.check_type.value }}</td>
<td>{{ check.created }}</td>
<td>{{ check.short_description }}</td>
</tr>
Expand Down
Loading

0 comments on commit d75af1d

Please sign in to comment.