Skip to content

Commit

Permalink
GH-64: Add CSRF protection to review scheme
Browse files Browse the repository at this point in the history
  • Loading branch information
markhobson committed Feb 29, 2024
1 parent 854c320 commit 458909e
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 8 deletions.
2 changes: 0 additions & 2 deletions schemes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ def bindings(binder: Binder) -> None:
csrf.exempt(authorities.add_schemes)
csrf.exempt(authorities.clear)
app.register_blueprint(schemes.bp, url_prefix="/schemes")
# TODO: add CSRF to scheme review form
csrf.exempt(schemes.schemes.review)
csrf.exempt(schemes.schemes.clear)
app.register_blueprint(users.bp, url_prefix="/users")
csrf.exempt(users.clear)
Expand Down
14 changes: 13 additions & 1 deletion schemes/views/schemes/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
session,
url_for,
)
from flask_wtf import FlaskForm
from werkzeug import Response as BaseResponse

from schemes.dicts import as_shallow_dict, inverse_dict
Expand Down Expand Up @@ -166,6 +167,7 @@ class SchemeContext:
funding: SchemeFundingContext
milestones: SchemeMilestonesContext
outputs: SchemeOutputsContext
review: SchemeReviewContext

@classmethod
def from_domain(
Expand All @@ -180,6 +182,7 @@ def from_domain(
funding=SchemeFundingContext.from_domain(scheme.funding),
milestones=SchemeMilestonesContext.from_domain(scheme.milestones),
outputs=SchemeOutputsContext.from_domain(scheme.outputs.current_output_revisions),
review=SchemeReviewContext(),
)


Expand Down Expand Up @@ -234,6 +237,15 @@ def from_domain(cls, funding_programme: FundingProgramme | None) -> FundingProgr
return cls(name=cls._NAMES[funding_programme] if funding_programme else None)


class SchemeReviewForm(FlaskForm): # type: ignore
pass


@dataclass(frozen=True)
class SchemeReviewContext:
form: SchemeReviewForm = field(default_factory=SchemeReviewForm)


@bp.get("<int:scheme_id>/spend-to-date")
@bearer_auth
@inject.autoparams("users", "schemes")
Expand Down Expand Up @@ -318,7 +330,7 @@ def milestones(users: UserRepository, clock: Clock, schemes: SchemeRepository, s
return redirect(url_for("schemes.get", scheme_id=scheme_id))


@bp.post("<int:scheme_id>/review")
@bp.post("<int:scheme_id>")
@bearer_auth
@inject.autoparams("clock", "schemes")
def review(clock: Clock, schemes: SchemeRepository, scheme_id: int) -> BaseResponse:
Expand Down
2 changes: 2 additions & 0 deletions schemes/views/templates/scheme/_review.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
{% from "govuk_frontend_jinja/components/checkboxes/macro.html" import govukCheckboxes %}

<form method="post" action="{{ url_for('schemes.review', scheme_id=id) }}" aria-label="Review scheme">
{{ review.form.csrf_token }}

{{ govukCheckboxes({
"name": "review",
"fieldset": {
Expand Down
9 changes: 9 additions & 0 deletions schemes/views/templates/scheme/index.html
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{% extends "service_base.html" %}
{% from "govuk_frontend_jinja/components/back-link/macro.html" import govukBackLink %}
{% from "govuk_frontend_jinja/components/notification-banner/macro.html" import govukNotificationBanner %}
{% from "govuk_frontend_jinja/components/tag/macro.html" import govukTag %}

{% block beforeContent %}
Expand All @@ -10,6 +11,14 @@
{% endblock %}

{% block content %}
{% with messages = get_flashed_messages() %}
{% if messages %}
{{ govukNotificationBanner(params={
"text": messages | first
}) }}
{% endif %}
{% endwith %}

<h1 class="govuk-heading-xl">
<span class="govuk-caption-xl">{{ authority_name }}</span>
<span>{{ name }}</span>
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ class SchemePage(PageObject):
def __init__(self, response: TestResponse):
super().__init__(response)
self.back_url = one(self._soup.select("a.govuk-back-link"))["href"]
notification_banner_tag = self._soup.select_one(".govuk-notification-banner")
self.notification_banner = (
NotificationBannerComponent(notification_banner_tag) if notification_banner_tag else None
)
self.authority = one(self._soup.select("main h1 .govuk-caption-xl")).string
self.name = one(self._soup.select("main h1 span:nth-child(2)")).string
tag = self._soup.select_one("main h1 .govuk-tag")
Expand Down
34 changes: 29 additions & 5 deletions tests/integration/test_scheme_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,19 @@ def test_scheme_shows_confirm(self, schemes: SchemeRepository, client: FlaskClie

scheme_page = SchemePage.open(client, id_=1)

assert scheme_page.review.confirm_url == "/schemes/1/review"
assert scheme_page.review.confirm_url == "/schemes/1"

def test_review_updates_last_reviewed(self, clock: Clock, schemes: SchemeRepository, client: FlaskClient) -> None:
def test_review_updates_last_reviewed(
self, clock: Clock, schemes: SchemeRepository, client: FlaskClient, csrf_token: str
) -> None:
clock.now = datetime(2023, 4, 24, 12)
scheme = Scheme(id_=1, name="Wirral Package", authority_id=1)
scheme.update_authority_review(
AuthorityReview(id_=1, review_date=datetime(2020, 1, 2), source=DataSource.ATF4_BID)
)
schemes.add(scheme)

client.post("/schemes/1/review", data={})
client.post("/schemes/1", data={"csrf_token": csrf_token})

actual_scheme = schemes.get(1)
assert actual_scheme
Expand All @@ -45,9 +47,31 @@ def test_review_updates_last_reviewed(self, clock: Clock, schemes: SchemeReposit
and authority_review2.source == DataSource.AUTHORITY_UPDATE
)

def test_review_shows_schemes(self, schemes: SchemeRepository, client: FlaskClient) -> None:
def test_review_shows_schemes(self, schemes: SchemeRepository, client: FlaskClient, csrf_token: str) -> None:
schemes.add(Scheme(id_=1, name="Wirral Package", authority_id=1))

response = client.post("/schemes/1/review", data={})
response = client.post("/schemes/1", data={"csrf_token": csrf_token})

assert response.status_code == 302 and response.location == "/schemes"

def test_cannot_review_when_no_csrf_token(self, schemes: SchemeRepository, client: FlaskClient) -> None:
schemes.add(Scheme(id_=1, name="Wirral Package", authority_id=1))

scheme_page = SchemePage(client.post("/schemes/1", data={}, follow_redirects=True))

assert scheme_page.name == "Wirral Package"
assert (
scheme_page.notification_banner
and scheme_page.notification_banner.heading == "The form you were submitting has expired. Please try again."
)

def test_cannot_review_when_incorrect_csrf_token(self, schemes: SchemeRepository, client: FlaskClient) -> None:
schemes.add(Scheme(id_=1, name="Wirral Package", authority_id=1))

scheme_page = SchemePage(client.post("/schemes/1", data={"csrf_token": "x"}, follow_redirects=True))

assert scheme_page.name == "Wirral Package"
assert (
scheme_page.notification_banner
and scheme_page.notification_banner.heading == "The form you were submitting has expired. Please try again."
)
21 changes: 21 additions & 0 deletions tests/views/schemes/test_schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from decimal import Decimal

import pytest
from flask_wtf import FlaskForm

from schemes.domain.authorities import Authority
from schemes.domain.dates import DateRange
Expand Down Expand Up @@ -43,6 +44,8 @@
FundingProgrammeRepr,
SchemeContext,
SchemeOverviewContext,
SchemeReviewContext,
SchemeReviewForm,
SchemeRowContext,
SchemesContext,
SchemeTypeContext,
Expand Down Expand Up @@ -131,6 +134,7 @@ def test_from_domain_sets_last_reviewed(self) -> None:
assert context.last_reviewed == datetime(2020, 1, 3, 12)


@pytest.mark.usefixtures("app")
class TestSchemeContext:
def test_from_domain(self) -> None:
authority = Authority(id_=2, name="Liverpool City Region Combined Authority")
Expand All @@ -144,6 +148,7 @@ def test_from_domain(self) -> None:
and context.name == "Wirral Package"
and not context.needs_review
and context.overview.reference == "ATE00001"
and isinstance(context.review.form, SchemeReviewForm)
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -345,6 +350,22 @@ def test_from_domain(self, funding_programme: FundingProgramme | None, expected_
assert context == FundingProgrammeContext(name=expected_name)


@pytest.mark.usefixtures("app")
class TestSchemeReviewContext:
def test_create(self) -> None:
context = SchemeReviewContext()

assert isinstance(context.form, SchemeReviewForm)


@pytest.mark.usefixtures("app")
class TestSchemeReviewForm:
def test_create(self) -> None:
form = SchemeReviewForm()

assert isinstance(form, FlaskForm)


class TestSchemeRepr:
def test_from_domain(self) -> None:
scheme = Scheme(id_=1, name="Wirral Package", authority_id=2)
Expand Down

0 comments on commit 458909e

Please sign in to comment.