Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Actually pass the forget and remember headers to the new response #457

Merged
merged 1 commit into from
Mar 22, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/accounts/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def test_post_validate_redirects(self, monkeypatch, pyramid_request,

assert isinstance(result, HTTPSeeOther)
assert result.headers["Location"] == "/"
assert result.headers["foo"] == "bar"
assert pyramid_request.find_service.calls == [
pretend.call(ILoginService),
]
Expand All @@ -143,7 +144,6 @@ def test_post_validate_redirects(self, monkeypatch, pyramid_request,
assert pyramid_request.session.invalidate.calls == [pretend.call()]
assert remember.calls == [pretend.call(pyramid_request, 1)]
assert pyramid_request.session.new_csrf_token.calls == [pretend.call()]
assert ("foo", "bar") in pyramid_request.response.headerlist


class TestLogout:
Expand All @@ -164,6 +164,6 @@ def test_post_forgets_user(self, monkeypatch, pyramid_request):

assert isinstance(result, HTTPSeeOther)
assert result.headers["Location"] == "/"
assert result.headers["foo"] == "bar"
assert forget.calls == [pretend.call(pyramid_request)]
assert pyramid_request.session.invalidate.calls == [pretend.call()]
assert ("foo", "bar") in pyramid_request.response.headerlist
6 changes: 2 additions & 4 deletions warehouse/accounts/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def login(request, _form_class=LoginForm):

# Remember the userid using the authentication policy.
headers = remember(request, userid)
request.response.headerlist.extend(headers)

# Cycle the CSRF token since we've crossed an authentication boundary
# and we don't want to continue using the old one.
Expand All @@ -97,7 +96,7 @@ def login(request, _form_class=LoginForm):
# where they were trying to go originally, or to the default view.
# TODO: Implement ?next= support.
# TODO: Figure out a better way to handle the "default view".
return HTTPSeeOther("/")
return HTTPSeeOther("/", headers=dict(headers))

return {"form": form}

Expand All @@ -117,7 +116,6 @@ def logout(request):
# CSRF attacks still because of the CSRF framework, so users will still
# need a post body that contains the CSRF token.
headers = forget(request)
request.response.headerlist.extend(headers)

# When crossing an authentication boundry we want to create a new
# session identifier. We don't want to keep any information in the
Expand All @@ -135,6 +133,6 @@ def logout(request):
# where they were originally, or to the default view.
# TODO: Implement ?next= support.
# TODO: Figure out a better way to handle the "default view".
return HTTPSeeOther("/")
return HTTPSeeOther("/", headers=dict(headers))

return {}