diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..0203654 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,27 @@ +name: Test + +on: + push: + branches: [main] + pull_request: + +jobs: + unit: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [ "3.10", "3.11" ] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - uses: abatilo/actions-poetry@v2 + + - name: Install dependencies + run: poetry install --no-root + + - name: Run tests + run: poetry run pytest diff --git a/modal/punctuator/__init__.py b/modal/punctuator/__init__.py index 69136fa..0b26f67 100644 --- a/modal/punctuator/__init__.py +++ b/modal/punctuator/__init__.py @@ -1,8 +1,8 @@ import os -from fastapi import Depends, HTTPException, status +from fastapi import Depends from fastapi.responses import StreamingResponse -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastapi.security import HTTPAuthorizationCredentials from pydantic import BaseModel from runner.shared.common import stub from shared.config import Config @@ -53,8 +53,6 @@ def download_models(): api_key_id="RUNNER_API_KEY", ) -auth_scheme = HTTPBearer() - @stub.cls( image=_image, @@ -109,17 +107,9 @@ def punctuate_thread(): ) @web_endpoint(method="POST") def punct( - payload: Payload, token: HTTPAuthorizationCredentials = Depends(auth_scheme) + payload: Payload, + _token: HTTPAuthorizationCredentials = Depends(config.auth), ): - import os - - if token.credentials != os.environ[config.api_key_id]: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect bearer token", - headers={"WWW-Authenticate": "Bearer"}, - ) - p = Punctuator() return StreamingResponse( diff --git a/modal/runner/endpoints/completion.py b/modal/runner/endpoints/completion.py index e2b17ba..cacf300 100644 --- a/modal/runner/endpoints/completion.py +++ b/modal/runner/endpoints/completion.py @@ -1,8 +1,6 @@ -import os - -from fastapi import Depends, HTTPException, status +from fastapi import Depends, status from fastapi.responses import StreamingResponse -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastapi.security import HTTPAuthorizationCredentials from runner.containers import get_container from runner.shared.common import BACKLOG_THRESHOLD, config from runner.shared.sampling_params import SamplingParams @@ -11,20 +9,11 @@ create_error_response, ) -auth_scheme = HTTPBearer() - def completion( payload: Payload, - token: HTTPAuthorizationCredentials = Depends(auth_scheme), + _token: HTTPAuthorizationCredentials = Depends(config.auth), ): - if token.credentials != os.environ[config.api_key_id]: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect bearer token", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: runner = get_container(payload.model) stats = runner.generate.get_current_stats() diff --git a/modal/shap-e/__init__.py b/modal/shap-e/__init__.py index 1e7917e..c0a034f 100644 --- a/modal/shap-e/__init__.py +++ b/modal/shap-e/__init__.py @@ -1,6 +1,6 @@ from typing import List, Optional -from fastapi import Depends, HTTPException, status +from fastapi import Depends from fastapi.responses import StreamingResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseModel @@ -24,6 +24,7 @@ class Generation(BaseModel): uri: Optional[str] = None url: Optional[str] = None + class ResponseBody(BaseModel): outputs: List[Generation] @@ -132,7 +133,11 @@ def make_object(): base64_data = base64.b64encode(buffer.read()).decode( "utf-8" ) - outputs.append(Generation(uri=f"data:application/x-ply;base64,{base64_data}")) + outputs.append( + Generation( + uri=f"data:application/x-ply;base64,{base64_data}" + ) + ) output[0] = ResponseBody(outputs=outputs).json( ensure_ascii=False @@ -159,17 +164,9 @@ def make_object(): ) @web_endpoint(method="POST") def create( - payload: Payload, token: HTTPAuthorizationCredentials = Depends(auth_scheme) + payload: Payload, + _token: HTTPAuthorizationCredentials = Depends(config.auth), ): - import os - - if token.credentials != os.environ[config.api_key_id]: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect bearer token", - headers={"WWW-Authenticate": "Bearer"}, - ) - p = Model() return StreamingResponse( diff --git a/modal/shared/__init__.py b/modal/shared/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modal/shared/config.py b/modal/shared/config.py index ca37d74..1ae1637 100644 --- a/modal/shared/config.py +++ b/modal/shared/config.py @@ -1,6 +1,37 @@ +import os +import secrets + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseModel +_auth = HTTPBearer() + class Config(BaseModel): name: str api_key_id: str + + def auth( + self, token: HTTPAuthorizationCredentials = Depends(_auth) + ) -> HTTPAuthorizationCredentials: + """ + API Authentication dependency for protected endpoints. Checks that the request's bearer token + matches the server's configured API key. + + Raises: + * HTTPException(403) if no token is provided. + * HTTPException(401) if the token is invalid. + """ + + # Timing attacks possible through direct comparison. Prevent it with a constant time comparison here. + got_credential = token.credentials.encode() + want_credential = os.environ[self.api_key_id].encode() + if not secrets.compare_digest(got_credential, want_credential): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect bearer token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return token diff --git a/modal/shared/test_config.py b/modal/shared/test_config.py new file mode 100644 index 0000000..9a00fd7 --- /dev/null +++ b/modal/shared/test_config.py @@ -0,0 +1,32 @@ +import os + +from fastapi import Depends, FastAPI, testclient +from fastapi.security import HTTPAuthorizationCredentials + +from shared.config import Config + + +def test_auth(): + """The API auth dependency should prevent unauthorized requests.""" + + app = FastAPI() + config = Config(name="test", api_key_id="RUNNER_API_KEY") + os.environ["RUNNER_API_KEY"] = "abc123" + + @app.get("/test") + def test(_token: HTTPAuthorizationCredentials = Depends(config.auth)): + return "OK" + + with testclient.TestClient(app) as client: + response = client.get("/test") + assert response.status_code == 403 + + response = client.get( + "/test", headers={"Authorization": "Bearer invalid"} + ) + assert response.status_code == 401 + + response = client.get( + "/test", headers={"Authorization": "Bearer abc123"} + ) + assert response.status_code == 200 diff --git a/modal/tuner/endpoints/create_lora.py b/modal/tuner/endpoints/create_lora.py index 3506c7e..f12a645 100644 --- a/modal/tuner/endpoints/create_lora.py +++ b/modal/tuner/endpoints/create_lora.py @@ -1,23 +1,11 @@ -import os - -from fastapi import Depends, HTTPException, status +from fastapi import Depends from fastapi.responses import StreamingResponse -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastapi.security import HTTPAuthorizationCredentials from tuner.containers.mistral_7b_lora import Mistral7BLoraContainer from tuner.shared.common import config -auth_scheme = HTTPBearer() - -def create_lora( - token: HTTPAuthorizationCredentials = Depends(auth_scheme), -): - if token.credentials != os.environ[config.api_key_id]: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect bearer token", - headers={"WWW-Authenticate": "Bearer"}, - ) +def create_lora(_token: HTTPAuthorizationCredentials = Depends(config.auth)): tuner = Mistral7BLoraContainer() return StreamingResponse( tuner.generate.remote_gen(), diff --git a/modal/tuner/endpoints/list_lora.py b/modal/tuner/endpoints/list_lora.py index ec53dd0..e3571a5 100644 --- a/modal/tuner/endpoints/list_lora.py +++ b/modal/tuner/endpoints/list_lora.py @@ -1,23 +1,13 @@ import os -from fastapi import Depends, HTTPException, status +from fastapi import Depends, status from fastapi.responses import JSONResponse -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastapi.security import HTTPAuthorizationCredentials from shared.volumes import loras_path from tuner.shared.common import config -auth_scheme = HTTPBearer() - -def list_lora( - token: HTTPAuthorizationCredentials = Depends(auth_scheme), -): - if token.credentials != os.environ[config.api_key_id]: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect bearer token", - headers={"WWW-Authenticate": "Bearer"}, - ) +def list_lora(_token: HTTPAuthorizationCredentials = Depends(config.auth)): # Get all files from the loras volume files = os.listdir(loras_path) diff --git a/poetry.lock b/poetry.lock index 614c0d6..4efdcea 100644 --- a/poetry.lock +++ b/poetry.lock @@ -721,6 +721,17 @@ multidict = "*" [package.extras] protobuf = ["protobuf (>=3.15.0)"] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + [[package]] name = "h2" version = "4.1.0" @@ -747,6 +758,51 @@ files = [ {file = "hpack-4.0.0.tar.gz", hash = "sha256:fc41de0c63e687ebffde81187a948221294896f6bdc0ae2312708df339430095"}, ] +[[package]] +name = "httpcore" +version = "1.0.2" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"}, + {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.23.0)"] + +[[package]] +name = "httpx" +version = "0.26.0" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.26.0-py3-none-any.whl", hash = "sha256:8915f5a3627c4d47b73e8202457cb28f1266982d1159bd5779d86a80c0eab1cd"}, + {file = "httpx-0.26.0.tar.gz", hash = "sha256:451b55c30d5185ea6b23c2c793abf9bb237d2a7dfb901ced6ff69ad37ec1dfaf"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + [[package]] name = "huggingface-hub" version = "0.17.3" @@ -835,6 +891,17 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker perf = ["ipython"] testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + [[package]] name = "jinja2" version = "3.1.2" @@ -1322,6 +1389,21 @@ files = [ docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] +[[package]] +name = "pluggy" +version = "1.3.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"}, + {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + [[package]] name = "pre-commit" version = "3.6.0" @@ -1587,6 +1669,28 @@ files = [ [package.extras] plugins = ["importlib-metadata"] +[[package]] +name = "pytest" +version = "7.4.3" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"}, + {file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -2387,6 +2491,17 @@ files = [ {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, ] +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + [[package]] name = "torch" version = "2.1.0" @@ -2964,4 +3079,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "7293a8b45f3926365eec2b827316c05b4e4dfc02cfd50b89bc834cefa991a0dc" +content-hash = "c071d9226cf9a6cbb6c553260c11b7da2a2ea67be3f022db1df08a3308d19323" diff --git a/pyproject.toml b/pyproject.toml index 29f8f85..b2210e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,10 +17,11 @@ datasets = "^2.14.5" scipy = "^1.11.3" wandb = "^0.15.12" - [tool.poetry.group.dev.dependencies] ruff = "^0.1.9" pre-commit = "^3.6.0" +pytest = "^7.4.3" +httpx = "^0.26.0" [build-system] requires = ["poetry-core"] diff --git a/scripts/shared.ts b/scripts/shared.ts index 801b13d..6c53516 100644 --- a/scripts/shared.ts +++ b/scripts/shared.ts @@ -16,10 +16,11 @@ export async function completion( model = defaultModel, max_tokens = 16, stream = false, - stop = [''] + stop = [''], + apiKey = key } = {} ) { - if (!url || !key) { + if (!url || !apiKey) { throw new Error('Missing url or key'); } @@ -27,7 +28,7 @@ export async function completion( id: Math.random().toString(36).substring(7), prompt, model, - params: { max_tokens, stop }, + params: {max_tokens, stop}, stream }; @@ -35,7 +36,7 @@ export async function completion( method: 'POST', headers: { 'Content-Type': 'application/json', - Authorization: `Bearer ${key}` + Authorization: `Bearer ${apiKey}` }, body: JSON.stringify(bodyPayload) }); @@ -47,6 +48,10 @@ export async function completion( const output = await p.text(); console.log(output); } + + if (!p.ok) { + throw new Error(`Status: ${p.status}`); + } } export function isEntryFile(url: string) { diff --git a/scripts/test-simple.ts b/scripts/test-simple.ts index ab1ae7d..06dfd04 100644 --- a/scripts/test-simple.ts +++ b/scripts/test-simple.ts @@ -15,6 +15,17 @@ async function main(model?: string) { max_tokens: 1024, stop: [''], }); + + // Unauthorized requests should fail with a 401 + let gotExpectedError = false; + try { + await completion(prompt, {model, apiKey: "BADKEY"}); + } catch (e: any) { + gotExpectedError = e.message == "Status: 401"; + } + if (!gotExpectedError) { + throw new Error("Unauthorized request returned unexpected response") + } } runIfCalledAsScript(main, import.meta.url);