Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: dedupe auth logic & patch potential timing attack #31

Merged
merged 11 commits into from
Dec 31, 2023
31 changes: 8 additions & 23 deletions modal/punctuator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi import Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security import HTTPAuthorizationCredentials
from pydantic import BaseModel
from runner.shared.common import stub
from shared.config import Config
Expand Down Expand Up @@ -37,9 +38,7 @@ def download_models():
_gpu = gpu.A10G(count=1)
_image = (
Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3")
.pip_install(
"torch==2.0.1+cu118", index_url="https://download.pytorch.org/whl/cu118"
)
.pip_install("torch==2.0.1+cu118", index_url="https://download.pytorch.org/whl/cu118")
.pip_install("sentencepiece")
.pip_install("deepmultilingualpunctuation")
.pip_install("hf-transfer~=0.1")
Expand All @@ -53,8 +52,6 @@ def download_models():
api_key_id="RUNNER_API_KEY",
)

auth_scheme = HTTPBearer()


@stub.cls(
image=_image,
Expand All @@ -75,16 +72,12 @@ def transform(self, input_str: str):
import threading
import time

output = [
None
] # Use a list to hold the output to bypass Python's scoping limitations
output = [None] # Use a list to hold the output to bypass Python's scoping limitations
output_ready = threading.Event()

def punctuate_thread():
try:
output[0] = create_response_text(
self.model.restore_punctuation(input_str)
)
output[0] = create_response_text(self.model.restore_punctuation(input_str))
except Exception as err:
output[0] = create_error_text(err)
print(output[0])
Expand All @@ -109,17 +102,9 @@ def punctuate_thread():
)
@web_endpoint(method="POST")
def punct(
payload: Payload, token: HTTPAuthorizationCredentials = Depends(auth_scheme)
payload: Payload,
_token: Annotated[HTTPAuthorizationCredentials, Depends(config.auth)],
):
import os

if token.credentials != os.environ[config.api_key_id]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect bearer token",
headers={"WWW-Authenticate": "Bearer"},
)

p = Punctuator()

return StreamingResponse(
Expand Down
17 changes: 4 additions & 13 deletions modal/runner/endpoints/completion.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi import Depends, status
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security import HTTPAuthorizationCredentials
from runner.containers import get_container
from runner.shared.common import BACKLOG_THRESHOLD, config
from runner.shared.sampling_params import SamplingParams
Expand All @@ -11,20 +11,11 @@
create_error_response,
)

auth_scheme = HTTPBearer()


def completion(
payload: Payload,
token: HTTPAuthorizationCredentials = Depends(auth_scheme),
_token: Annotated[HTTPAuthorizationCredentials, Depends(config.auth)],
):
if token.credentials != os.environ[config.api_key_id]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect bearer token",
headers={"WWW-Authenticate": "Bearer"},
)

try:
runner = get_container(payload.model)
stats = runner.generate.get_current_stats()
Expand Down
29 changes: 8 additions & 21 deletions modal/shap-e/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional
from typing import Annotated, List, Optional

from fastapi import Depends, HTTPException, status
from fastapi import Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
Expand All @@ -24,6 +24,7 @@ class Generation(BaseModel):
uri: Optional[str] = None
url: Optional[str] = None


class ResponseBody(BaseModel):
    # List of Generation entries that make up the response payload.
    outputs: List[Generation]

Expand Down Expand Up @@ -86,9 +87,7 @@ def generate(self, payload: Payload):
import threading
import time

output = [
None
] # Use a list to hold the output to bypass Python's scoping limitations
output = [None] # Use a list to hold the output to bypass Python's scoping limitations
output_ready = threading.Event()

def make_object():
Expand Down Expand Up @@ -129,14 +128,10 @@ def make_object():
buffer.seek(0)

# Encode the buffer content to base64
base64_data = base64.b64encode(buffer.read()).decode(
"utf-8"
)
base64_data = base64.b64encode(buffer.read()).decode("utf-8")
outputs.append(Generation(uri=f"data:application/x-ply;base64,{base64_data}"))

output[0] = ResponseBody(outputs=outputs).json(
ensure_ascii=False
)
output[0] = ResponseBody(outputs=outputs).json(ensure_ascii=False)

except Exception as err:
output[0] = create_error_text(err)
Expand All @@ -159,17 +154,9 @@ def make_object():
)
@web_endpoint(method="POST")
def create(
payload: Payload, token: HTTPAuthorizationCredentials = Depends(auth_scheme)
payload: Payload,
_token: Annotated[HTTPAuthorizationCredentials, Depends(config.auth)],
):
import os

if token.credentials != os.environ[config.api_key_id]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect bearer token",
headers={"WWW-Authenticate": "Bearer"},
)

p = Model()

return StreamingResponse(
Expand Down
Empty file added modal/shared/__init__.py
Empty file.
30 changes: 30 additions & 0 deletions modal/shared/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,36 @@
import os
import secrets
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

# Shared HTTP bearer scheme; `Config.auth` declares it as a dependency to
# receive the request's bearer credentials.
_auth = HTTPBearer()


class Config(BaseModel):
    name: str
    # Name of the environment variable that holds the expected API key
    # (looked up in `os.environ` by `auth`).
    api_key_id: str

    def auth(self, token: Annotated[HTTPAuthorizationCredentials, Depends(_auth)]) -> HTTPAuthorizationCredentials:
        """
        FastAPI dependency guarding protected endpoints with a bearer token.

        The token supplied with the request is compared against the API key stored in the
        environment variable named by `self.api_key_id`.

        Returns:
            The validated credentials, unchanged.

        Raises:
            * HTTPException(403) if no token is provided.
            * HTTPException(401) if the token is invalid.
        """

        supplied = token.credentials.encode()
        expected = os.environ[self.api_key_id].encode()

        # A plain `==`/`!=` may return as soon as the first differing byte is found, which can
        # leak how much of the key was guessed correctly. compare_digest runs in constant time.
        if secrets.compare_digest(supplied, expected):
            return token

        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect bearer token",
            headers={"WWW-Authenticate": "Bearer"},
        )
29 changes: 29 additions & 0 deletions modal/shared/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
from typing import Annotated

from fastapi import Depends, FastAPI, testclient
from fastapi.security import HTTPAuthorizationCredentials

from shared.config import Config


def test_auth():
    """The API auth dependency should prevent unauthorized requests.

    Exercises a throwaway FastAPI app whose only route depends on `Config.auth`:
    a request without a token, one with a wrong token, and one with the right token.
    """
    # Imported locally so the test stays self-contained.
    from unittest import mock

    app = FastAPI()
    config = Config(name="test", api_key_id="RUNNER_API_KEY")

    @app.get("/test")
    def test(_token: Annotated[HTTPAuthorizationCredentials, Depends(config.auth)]):
        return "OK"

    # patch.dict restores os.environ when the block exits, so the fake API key
    # cannot leak into other tests running in the same process.
    with mock.patch.dict(os.environ, {"RUNNER_API_KEY": "abc123"}), testclient.TestClient(app) as client:
        # No Authorization header.
        response = client.get("/test")
        assert response.status_code == 403

        # Bearer token that does not match the configured key.
        response = client.get("/test", headers={"Authorization": "Bearer invalid"})
        assert response.status_code == 401

        # Bearer token that matches the configured key.
        response = client.get("/test", headers={"Authorization": "Bearer abc123"})
        assert response.status_code == 200
18 changes: 4 additions & 14 deletions modal/tuner/endpoints/create_lora.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,13 @@
import os
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi import Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security import HTTPAuthorizationCredentials
from tuner.containers.mistral_7b_lora import Mistral7BLoraContainer
from tuner.shared.common import config

auth_scheme = HTTPBearer()


def create_lora(
token: HTTPAuthorizationCredentials = Depends(auth_scheme),
):
if token.credentials != os.environ[config.api_key_id]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect bearer token",
headers={"WWW-Authenticate": "Bearer"},
)
def create_lora(_token: Annotated[HTTPAuthorizationCredentials, Depends(config.auth)]):
tuner = Mistral7BLoraContainer()
return StreamingResponse(
tuner.generate.remote_gen(),
Expand Down
17 changes: 4 additions & 13 deletions modal/tuner/endpoints/list_lora.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
import os
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi import Depends, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security import HTTPAuthorizationCredentials
from shared.volumes import loras_path
from tuner.shared.common import config

auth_scheme = HTTPBearer()


def list_lora(
token: HTTPAuthorizationCredentials = Depends(auth_scheme),
):
if token.credentials != os.environ[config.api_key_id]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect bearer token",
headers={"WWW-Authenticate": "Bearer"},
)
def list_lora(_token: Annotated[HTTPAuthorizationCredentials, Depends(config.auth)]):
# Get all files from the loras volume

files = os.listdir(loras_path)
Expand Down
Loading