Skip to content

Commit

Permalink
ref: set Request.auth and HttpRequest.auth to our real auth type
Browse files Browse the repository at this point in the history
  • Loading branch information
asottile-sentry committed Jan 10, 2025
1 parent 732d636 commit 00cc029
Show file tree
Hide file tree
Showing 13 changed files with 139 additions and 79 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ module = [
"sentry.api.endpoints.event_attachments",
"sentry.api.endpoints.group_integration_details",
"sentry.api.endpoints.group_integrations",
"sentry.api.endpoints.index",
"sentry.api.endpoints.internal.mail",
"sentry.api.endpoints.organization_details",
"sentry.api.endpoints.organization_events_facets_performance",
Expand Down
10 changes: 4 additions & 6 deletions src/sentry/api/bases/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def is_not_2fa_compliant(
if not organization.flags.require_2fa:
return False

if request.user.has_2fa(): # type: ignore[union-attr]
if request.user.is_authenticated and request.user.has_2fa():
return False

if is_active_superuser(request):
Expand Down Expand Up @@ -661,12 +661,10 @@ def has_release_permission(
actor_id = None
has_perms = None
key = None
if getattr(request, "user", None) and request.user.id:
if request.user is not None and request.user.id:
actor_id = "user:%s" % request.user.id
if getattr(request, "auth", None) and getattr(request.auth, "id", None):
actor_id = "apikey:%s" % request.auth.id # type: ignore[union-attr]
elif getattr(request, "auth", None) and getattr(request.auth, "entity_id", None):
actor_id = "apikey:%s" % request.auth.entity_id # type: ignore[union-attr]
if request.auth is not None and getattr(request.auth, "entity_id", None):
actor_id = "apikey:%s" % request.auth.entity_id
if actor_id is not None:
requested_project_ids = project_ids
if requested_project_ids is None:
Expand Down
6 changes: 3 additions & 3 deletions src/sentry/api/endpoints/project_user_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get(self, request: Request, project) -> Response:
:auth: required
"""
# we don't allow read permission with DSNs
if request.auth is not None and request.auth.kind == "project_key": # type: ignore[union-attr]
if request.auth is not None and request.auth.kind == "project_key":
return self.respond(status=401)

paginate_kwargs: _PaginateKwargs = {}
Expand Down Expand Up @@ -118,7 +118,7 @@ def post(self, request: Request, project) -> Response:
:param string email: user's email address
:param string comments: comments supplied by user
"""
if request.auth is not None and project.id != request.auth.project_id: # type: ignore[union-attr] # TODO: real .auth typing
if request.auth is not None and project.id != request.auth.project_id:
return self.respond(status=401)

serializer = UserReportSerializer(data=request.data)
Expand All @@ -133,7 +133,7 @@ def post(self, request: Request, project) -> Response:
except Conflict as e:
return self.respond({"detail": str(e)}, status=409)

if request.auth is not None and request.auth.kind == "project_key": # type: ignore[union-attr]
if request.auth is not None and request.auth.kind == "project_key":
return self.respond(status=200)

return self.respond(
Expand Down
8 changes: 3 additions & 5 deletions src/sentry/api/endpoints/release_deploys.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,10 @@ def post(self, request: Request, organization, version) -> Response:
# Logic here copied from `has_release_permission` (lightly edited for results to be more
# human-readable)
auth = None
if getattr(request, "user", None) and request.user.id:
if request.user is not None and request.user.id:
auth = f"user.id: {request.user.id}"
elif getattr(request, "auth", None) and getattr(request.auth, "id", None):
auth = f"auth.id: {request.auth.id}" # type: ignore[union-attr]
elif getattr(request, "auth", None) and getattr(request.auth, "entity_id", None):
auth = f"auth.entity_id: {request.auth.entity_id}" # type: ignore[union-attr]
elif request.auth is not None and getattr(request.auth, "entity_id", None):
auth = f"auth.entity_id: {request.auth.entity_id}"
if auth is not None:
logging_info.update({"auth": auth})
logger.info(
Expand Down
2 changes: 1 addition & 1 deletion src/sentry/api/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class ScopedPermission(BasePermission):

def has_permission(self, request: Request, view: object) -> bool:
# session-based auth has all scopes for a logged in user
if not getattr(request, "auth", None):
if request.auth is None:
return request.user.is_authenticated

if is_org_auth_token_auth(request.auth):
Expand Down
2 changes: 1 addition & 1 deletion src/sentry/auth/staff.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@


def is_active_staff(request: HttpRequest | Request) -> bool:
if is_system_auth(getattr(request, "auth", None)):
if is_system_auth(request.auth):
return True
staff = getattr(request, "staff", None) or Staff(request)
return staff.is_active
Expand Down
2 changes: 1 addition & 1 deletion src/sentry/auth/superuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def superuser_has_permission(


def is_active_superuser(request: HttpRequest | Request) -> bool:
if is_system_auth(getattr(request, "auth", None)):
if is_system_auth(request.auth):
return True
su = getattr(request, "superuser", None) or Superuser(request)
return su.is_active
Expand Down
4 changes: 0 additions & 4 deletions src/sentry/monitors/endpoints/base_monitor_checkin_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ def get_monitor_checkins(self, request: Request, project, monitor) -> Response:
"""
Retrieve a list of check-ins for a monitor
"""
# we don't allow read permission with DSNs
if request.auth is not None and request.auth.kind == "project_key": # type: ignore[union-attr]
return self.respond(status=401)

start, end = get_date_range_from_params(request.GET)
if start is None or end is None:
raise ParseError(detail="Invalid date range")
Expand Down
2 changes: 2 additions & 0 deletions tests/sentry/api/bases/test_organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ def build_request(self, user=None, active_superuser=False, **params):
if user is None:
user = self.user
request.user = user
request.auth = None
request.access = from_request(request, self.org)
return request

Expand Down Expand Up @@ -383,6 +384,7 @@ def test_none_user(self):
request = RequestFactory().get("/")
request.session = SessionBase()
request.access = NoAccess()
request.auth = None
result = self.endpoint.get_projects(request, self.org)
assert [] == result

Expand Down
1 change: 1 addition & 0 deletions tests/sentry/auth/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
def _set_up_request():
request = RequestFactory().post("/auth/sso/")
request.user = AnonymousUser()
request.auth = None
request.session = Client().session
return request

Expand Down
37 changes: 21 additions & 16 deletions tests/sentry/ratelimits/utils/test_get_ratelimit_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _populate_public_integration_request(self, request) -> None:

with assume_test_silo_mode_of(User):
request.user = User.objects.get(id=install.sentry_app.proxy_user_id)
request.auth = token
request.auth = AuthenticatedToken.from_token(token)

def _populate_internal_integration_request(self, request) -> None:
internal_integration = self.create_internal_integration(
Expand All @@ -69,7 +69,7 @@ def _populate_internal_integration_request(self, request) -> None:

with assume_test_silo_mode_of(User):
request.user = User.objects.get(id=internal_integration.proxy_user_id)
request.auth = token
request.auth = AuthenticatedToken.from_token(token)

def test_ips(self):
# Test for default IP
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_user(self):
)

def test_system_token(self):
self.request.auth = SystemToken()
self.request.auth = AuthenticatedToken.from_token(SystemToken())
assert (
get_rate_limit_key(
self.view, self.request, self.rate_limit_group, self.rate_limit_config
Expand All @@ -119,7 +119,7 @@ def test_system_token(self):
def test_api_token(self):
with assume_test_silo_mode_of(ApiToken):
token = ApiToken.objects.create(user=self.user, scope_list=["event:read", "org:read"])
self.request.auth = token
self.request.auth = AuthenticatedToken.from_token(token)
self.request.user = self.user
assert (
get_rate_limit_key(
Expand All @@ -135,7 +135,7 @@ def test_api_token_replica(self):
)
with assume_test_silo_mode_of(ApiTokenReplica):
token = ApiTokenReplica.objects.get(apitoken_id=apitoken.id)
self.request.auth = token
self.request.auth = AuthenticatedToken.from_token(token)
self.request.user = self.user

assert (
Expand All @@ -146,9 +146,6 @@ def test_api_token_replica(self):
)

def test_authenticated_token(self):
# Ensure AuthenticatedToken kinds are registered
import sentry.auth.services.auth.service # noqa: F401

with assume_test_silo_mode_of(ApiToken):
token = ApiToken.objects.create(user=self.user, scope_list=["event:read", "org:read"])
self.request.auth = AuthenticatedToken.from_token(token)
Expand All @@ -162,8 +159,8 @@ def test_authenticated_token(self):

def test_api_key(self):
self.request.user = AnonymousUser()
self.request.auth = self.create_api_key(
organization=self.organization, scope_list=["project:write"]
self.request.auth = AuthenticatedToken.from_token(
self.create_api_key(organization=self.organization, scope_list=["project:write"])
)

assert (
Expand All @@ -175,8 +172,8 @@ def test_api_key(self):

def test_org_auth_token(self):
self.request.user = AnonymousUser()
self.request.auth = self.create_org_auth_token(
organization_id=self.organization.id, scope_list=["org:ci"]
self.request.auth = AuthenticatedToken.from_token(
self.create_org_auth_token(organization_id=self.organization.id, scope_list=["org:ci"])
)

assert (
Expand All @@ -187,8 +184,11 @@ def test_org_auth_token(self):
)

def test_user_auth_token(self):
token = self.create_user_auth_token(user=self.user, scope_list=["event:read", "org:read"])
self.request.auth = token
with assume_test_silo_mode_of(User):
token = self.create_user_auth_token(
user=self.user, scope_list=["event:read", "org:read"]
)
self.request.auth = AuthenticatedToken.from_token(token)
self.request.user = self.user

assert (
Expand All @@ -210,11 +210,16 @@ def test_integration_tokens(self):

# Test for INTERNAL Integration api tokens
self._populate_internal_integration_request(self.request)
assert self.request.auth is not None
with assume_test_silo_mode_of(SentryAppInstallation, SentryAppInstallationToken):
# Ensure that the internal integration token lives in
# SentryAppInstallationToken instead of SentryAppInstallation
assert not SentryAppInstallation.objects.filter(api_token_id=self.request.auth.id)
assert SentryAppInstallationToken.objects.filter(api_token_id=self.request.auth.id)
assert not SentryAppInstallation.objects.filter(
api_token_id=self.request.auth.entity_id
)
assert SentryAppInstallationToken.objects.filter(
api_token_id=self.request.auth.entity_id
)
assert (
get_rate_limit_key(
self.view, self.request, self.rate_limit_group, self.rate_limit_config
Expand Down
91 changes: 54 additions & 37 deletions tests/tools/mypy_helpers/test_plugin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import os.path
import pathlib
import shutil
import subprocess
import sys
Expand All @@ -10,6 +9,14 @@
import pytest


def _fill_init_pyi(tmpdir: str, path: str) -> str:
os.makedirs(os.path.join(tmpdir, path))
for part in path.split(os.sep):
tmpdir = os.path.join(tmpdir, part)
open(os.path.join(tmpdir, "__init__.pyi"), "a").close()
return tmpdir


def call_mypy(src: str, *, plugins: list[str] | None = None) -> tuple[int, str]:
if plugins is None:
plugins = ["tools.mypy_helpers.plugin"]
Expand All @@ -18,12 +25,32 @@ def call_mypy(src: str, *, plugins: list[str] | None = None) -> tuple[int, str]:
with open(cfg, "w") as f:
f.write(f"[tool.mypy]\nplugins = {plugins!r}\n")

# we stub several files in order to test our plugin
# the tests cannot depend on sentry being importable (it isn't!)
here = os.path.dirname(__file__)

# stubs for lazy_service_wrapper
utils_dir = _fill_init_pyi(tmpdir, "sentry/utils")
sentry_src = os.path.join(here, "../../../src/sentry/utils/lazy_service_wrapper.py")
shutil.copy(sentry_src, utils_dir)
with open(os.path.join(utils_dir, "__init__.pyi"), "w") as f:
f.write("from typing import Any\ndef __getattr__(k: str) -> Any: ...\n")

# stubs for auth types
auth_dir = _fill_init_pyi(tmpdir, "sentry/auth/services/auth")
with open(os.path.join(auth_dir, "model.pyi"), "w") as f:
f.write("class AuthenticatedToken: ...")

ret = subprocess.run(
(
*(sys.executable, "-m", "mypy"),
*("--config", cfg),
*("-c", src),
"--show-traceback",
# we only stub out limited parts of the sentry source tree
"--ignore-missing-imports",
),
env={**os.environ, "MYPYPATH": tmpdir},
capture_output=True,
encoding="UTF-8",
)
Expand Down Expand Up @@ -168,7 +195,30 @@ def test_added_http_request_attribute(attr: str) -> None:
assert ret == 0, (ret, out)


def test_lazy_service_wrapper(tmp_path: pathlib.Path) -> None:
def test_adjusted_drf_request_auth() -> None:
src = """\
from rest_framework.request import Request
x: Request
reveal_type(x.auth)
"""
expected_no_plugins = """\
<string>:3: note: Revealed type is "Union[rest_framework.authtoken.models.Token, Any]"
Success: no issues found in 1 source file
"""
expected_plugins = """\
<string>:3: note: Revealed type is "Union[sentry.auth.services.auth.model.AuthenticatedToken, None]"
Success: no issues found in 1 source file
"""
ret, out = call_mypy(src, plugins=[])
assert ret == 0
assert out == expected_no_plugins

ret, out = call_mypy(src)
assert ret == 0
assert out == expected_plugins


def test_lazy_service_wrapper() -> None:
src = """\
from typing import assert_type, Literal
from sentry.utils.lazy_service_wrapper import LazyServiceWrapper, Service, _EmptyType
Expand All @@ -195,42 +245,9 @@ def f(self) -> int:
Found 2 errors in 1 file (checked 1 source file)
"""

# tools tests aren't allowed to import from `sentry` so we fixture
# the particular source file we are testing
utils_dir = tmp_path.joinpath("sentry/utils")
utils_dir.mkdir(parents=True)

here = os.path.dirname(__file__)
sentry_src = os.path.join(here, "../../../src/sentry/utils/lazy_service_wrapper.py")
shutil.copy(sentry_src, utils_dir)

init_pyi = "from typing import Any\ndef __getattr__(self) -> Any: ...\n"
utils_dir.joinpath("__init__.pyi").write_text(init_pyi)

cfg = tmp_path.joinpath("mypy.toml")
cfg.write_text("[tool.mypy]\nplugins = []\n")

# can't use our helper above because we're fixturing sentry src, so mimic it here
def _mypy() -> tuple[int, str]:
ret = subprocess.run(
(
*(sys.executable, "-m", "mypy"),
*("--config", cfg),
# we only stub out limited parts of the sentry source tree
"--ignore-missing-imports",
*("-c", src),
),
env={**os.environ, "MYPYPATH": str(tmp_path)},
capture_output=True,
encoding="UTF-8",
)
assert not ret.stderr
return ret.returncode, ret.stdout

ret, out = _mypy()
ret, out = call_mypy(src, plugins=[])
assert ret
assert out == expected

cfg.write_text('[tool.mypy]\nplugins = ["tools.mypy_helpers.plugin"]\n')
ret, out = _mypy()
ret, out = call_mypy(src)
assert ret == 0
Loading

0 comments on commit 00cc029

Please sign in to comment.