Skip to content

Commit

Permalink
Improve gateway auth issues troubleshooting (#1569)
Browse files Browse the repository at this point in the history
If `dstack-gateway` experiences issues when
checking authorization:
- Log a more understandable error message
- Send 5xx error code to the user
  • Loading branch information
jvstme authored Aug 16, 2024
1 parent f561e4f commit c18c37e
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 33 deletions.
17 changes: 5 additions & 12 deletions gateway/src/dstack/gateway/auth/routes.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
from fastapi import APIRouter, Depends, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi import APIRouter, Depends

from dstack.gateway.core.auth import AuthProvider, get_auth
from dstack.gateway.core.auth import access_to_project_required

router = APIRouter()


# TODO(egor-s): support Authorization header alternative for web browsers


@router.get("/{project}")
async def get_auth(
project: str,
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
auth: AuthProvider = Depends(get_auth),
):
if await auth.has_access(project, token.credentials):
return {"status": "ok"}
raise HTTPException(status_code=403)
@router.get("/{project}", dependencies=[Depends(access_to_project_required)])
async def get_auth():
return {"status": "ok"}
35 changes: 28 additions & 7 deletions gateway/src/dstack/gateway/core/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import httpx
from aiocache import cached
from fastapi import Depends, HTTPException, Security, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from dstack.gateway.common import AsyncClientWrapper

Expand All @@ -14,20 +16,39 @@ class AuthProvider:
def __init__(self):
self.client = AsyncClientWrapper(base_url=f"http://localhost:{DSTACK_SERVER_TUNNEL_PORT}")

@cached(ttl=60, noself=True)
async def has_access(self, project: str, token: str) -> bool:
@cached(ttl=60, noself=True, skip_cache_func=lambda r: r is None)
async def has_access(self, project: str, token: str) -> bool | None:
"""True - yes, False - no, None - failed checking"""

try:
resp = await self.client.post(
f"/api/projects/{project}/get",
headers={"Authorization": f"Bearer {token}"},
)
if resp.status_code == 200:
return True
except httpx.RequestError as e:
logger.debug("Failed to check access: %r", e)
return False
if resp.status_code == httpx.codes.FORBIDDEN:
return False
resp.raise_for_status()
except httpx.HTTPError as e:
logger.error("Failed requesting dstack-server to check access: %r", e)
return None
return True


@lru_cache()
def get_auth() -> AuthProvider:
return AuthProvider()


async def access_to_project_required(
project: str,
auth: AuthProvider = Depends(get_auth),
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
) -> None:
has_access = await auth.has_access(project, token.credentials)
if has_access is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Internal error when checking authorization. Try again later",
)
if not has_access:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
17 changes: 3 additions & 14 deletions gateway/src/dstack/gateway/openai/routes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from fastapi import APIRouter, Depends, HTTPException, Security
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from typing_extensions import Annotated, AsyncIterator

from dstack.gateway.core.auth import AuthProvider, get_auth
from dstack.gateway.core.auth import access_to_project_required
from dstack.gateway.openai.schemas import (
ChatCompletionsChunk,
ChatCompletionsRequest,
Expand All @@ -12,17 +11,7 @@
)
from dstack.gateway.openai.store import OpenAIStore, get_store


async def auth_required(
project: str,
auth: AuthProvider = Depends(get_auth),
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
):
if not await auth.has_access(project, token.credentials):
raise HTTPException(status_code=403)


router = APIRouter(dependencies=[Depends(auth_required)])
router = APIRouter(dependencies=[Depends(access_to_project_required)])


@router.get("/{project}/models")
Expand Down

0 comments on commit c18c37e

Please sign in to comment.