Skip to content

Commit

Permalink
Merge pull request #31 from sambarnes/auth-refactor
Browse files Browse the repository at this point in the history
fix: dedupe auth logic & patch potential timing attack
  • Loading branch information
louisgv authored Dec 31, 2023
2 parents af3f494 + c3aef53 commit 9261a6e
Show file tree
Hide file tree
Showing 13 changed files with 250 additions and 74 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: Test

on:
push:
branches: [main]
pull_request:

jobs:
unit:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [ "3.10", "3.11" ]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- uses: abatilo/actions-poetry@v2

- name: Install dependencies
run: poetry install --no-root

- name: Run tests
run: poetry run pytest
18 changes: 4 additions & 14 deletions modal/punctuator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os

from fastapi import Depends, HTTPException, status
from fastapi import Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security import HTTPAuthorizationCredentials
from pydantic import BaseModel
from runner.shared.common import stub
from shared.config import Config
Expand Down Expand Up @@ -53,8 +53,6 @@ def download_models():
api_key_id="RUNNER_API_KEY",
)

auth_scheme = HTTPBearer()


@stub.cls(
image=_image,
Expand Down Expand Up @@ -109,17 +107,9 @@ def punctuate_thread():
)
@web_endpoint(method="POST")
def punct(
payload: Payload, token: HTTPAuthorizationCredentials = Depends(auth_scheme)
payload: Payload,
_token: HTTPAuthorizationCredentials = Depends(config.auth),
):
import os

if token.credentials != os.environ[config.api_key_id]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect bearer token",
headers={"WWW-Authenticate": "Bearer"},
)

p = Punctuator()

return StreamingResponse(
Expand Down
17 changes: 3 additions & 14 deletions modal/runner/endpoints/completion.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os

from fastapi import Depends, HTTPException, status
from fastapi import Depends, status
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security import HTTPAuthorizationCredentials
from runner.containers import get_container
from runner.shared.common import BACKLOG_THRESHOLD, config
from runner.shared.sampling_params import SamplingParams
Expand All @@ -11,20 +9,11 @@
create_error_response,
)

auth_scheme = HTTPBearer()


def completion(
payload: Payload,
token: HTTPAuthorizationCredentials = Depends(auth_scheme),
_token: HTTPAuthorizationCredentials = Depends(config.auth),
):
if token.credentials != os.environ[config.api_key_id]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect bearer token",
headers={"WWW-Authenticate": "Bearer"},
)

try:
runner = get_container(payload.model)
stats = runner.generate.get_current_stats()
Expand Down
21 changes: 9 additions & 12 deletions modal/shap-e/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional

from fastapi import Depends, HTTPException, status
from fastapi import Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
Expand All @@ -24,6 +24,7 @@ class Generation(BaseModel):
uri: Optional[str] = None
url: Optional[str] = None


class ResponseBody(BaseModel):
outputs: List[Generation]

Expand Down Expand Up @@ -132,7 +133,11 @@ def make_object():
base64_data = base64.b64encode(buffer.read()).decode(
"utf-8"
)
outputs.append(Generation(uri=f"data:application/x-ply;base64,{base64_data}"))
outputs.append(
Generation(
uri=f"data:application/x-ply;base64,{base64_data}"
)
)

output[0] = ResponseBody(outputs=outputs).json(
ensure_ascii=False
Expand All @@ -159,17 +164,9 @@ def make_object():
)
@web_endpoint(method="POST")
def create(
payload: Payload, token: HTTPAuthorizationCredentials = Depends(auth_scheme)
payload: Payload,
_token: HTTPAuthorizationCredentials = Depends(config.auth),
):
import os

if token.credentials != os.environ[config.api_key_id]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect bearer token",
headers={"WWW-Authenticate": "Bearer"},
)

p = Model()

return StreamingResponse(
Expand Down
Empty file added modal/shared/__init__.py
Empty file.
31 changes: 31 additions & 0 deletions modal/shared/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@
import os
import secrets

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

_auth = HTTPBearer()


class Config(BaseModel):
name: str
api_key_id: str

def auth(
self, token: HTTPAuthorizationCredentials = Depends(_auth)
) -> HTTPAuthorizationCredentials:
"""
API Authentication dependency for protected endpoints. Checks that the request's bearer token
matches the server's configured API key.
Raises:
* HTTPException(403) if no token is provided.
* HTTPException(401) if the token is invalid.
"""

# Timing attacks possible through direct comparison. Prevent it with a constant time comparison here.
got_credential = token.credentials.encode()
want_credential = os.environ[self.api_key_id].encode()
if not secrets.compare_digest(got_credential, want_credential):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect bearer token",
headers={"WWW-Authenticate": "Bearer"},
)

return token
32 changes: 32 additions & 0 deletions modal/shared/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os

from fastapi import Depends, FastAPI, testclient
from fastapi.security import HTTPAuthorizationCredentials

from shared.config import Config


def test_auth():
"""The API auth dependency should prevent unauthorized requests."""

app = FastAPI()
config = Config(name="test", api_key_id="RUNNER_API_KEY")
os.environ["RUNNER_API_KEY"] = "abc123"

@app.get("/test")
def test(_token: HTTPAuthorizationCredentials = Depends(config.auth)):
return "OK"

with testclient.TestClient(app) as client:
response = client.get("/test")
assert response.status_code == 403

response = client.get(
"/test", headers={"Authorization": "Bearer invalid"}
)
assert response.status_code == 401

response = client.get(
"/test", headers={"Authorization": "Bearer abc123"}
)
assert response.status_code == 200
18 changes: 3 additions & 15 deletions modal/tuner/endpoints/create_lora.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,11 @@
import os

from fastapi import Depends, HTTPException, status
from fastapi import Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security import HTTPAuthorizationCredentials
from tuner.containers.mistral_7b_lora import Mistral7BLoraContainer
from tuner.shared.common import config

auth_scheme = HTTPBearer()


def create_lora(
token: HTTPAuthorizationCredentials = Depends(auth_scheme),
):
if token.credentials != os.environ[config.api_key_id]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect bearer token",
headers={"WWW-Authenticate": "Bearer"},
)
def create_lora(_token: HTTPAuthorizationCredentials = Depends(config.auth)):
tuner = Mistral7BLoraContainer()
return StreamingResponse(
tuner.generate.remote_gen(),
Expand Down
16 changes: 3 additions & 13 deletions modal/tuner/endpoints/list_lora.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,13 @@
import os

from fastapi import Depends, HTTPException, status
from fastapi import Depends, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security import HTTPAuthorizationCredentials
from shared.volumes import loras_path
from tuner.shared.common import config

auth_scheme = HTTPBearer()


def list_lora(
token: HTTPAuthorizationCredentials = Depends(auth_scheme),
):
if token.credentials != os.environ[config.api_key_id]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect bearer token",
headers={"WWW-Authenticate": "Bearer"},
)
def list_lora(_token: HTTPAuthorizationCredentials = Depends(config.auth)):
# Get all files from the loras volume

files = os.listdir(loras_path)
Expand Down
Loading

0 comments on commit 9261a6e

Please sign in to comment.