Skip to content

Commit

Permalink
fix: Correct auth login rate limit routes (#4698)
Browse files Browse the repository at this point in the history
* Use fully qualified auth paths for rate limiter

* Add playwright rate limiter test

* Disable auth rate limits in ITs

* Format w/ prettier
  • Loading branch information
anticorrelator authored and RogerHYang committed Sep 21, 2024
1 parent 8892180 commit e79a54c
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 13 deletions.
6 changes: 6 additions & 0 deletions app/playwright.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ export default defineConfig({
name: "webkit",
use: { ...devices["Desktop Safari"] },
},
{
name: "rate limit",
use: { ...devices["Desktop Chrome"] },
dependencies: ["chromium", "firefox", "webkit"],
testMatch: "**/*.rate-limit.spec.ts",
},
/* Test against mobile viewports. */
// {
// name: 'Mobile Chrome',
Expand Down
6 changes: 5 additions & 1 deletion app/src/pages/auth/LoginForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ export function LoginForm(props: LoginFormProps) {
body: JSON.stringify(params),
});
if (!response.ok) {
setError("Invalid login");
const errorMessage =
response.status === 429
? "Too many requests. Please try again later."
: "Invalid login";
setError(errorMessage);
return;
}
} catch (error) {
Expand Down
17 changes: 17 additions & 0 deletions app/tests/login.rate-limit.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { expect, test } from "@playwright/test";

test("that login gets rate limited after too many attempts", async ({ page }) => {
await page.goto("/login");
await page.waitForURL("**/login");

const email = `fakeuser@localhost.com`;
// Add the user
await page.getByLabel("Email").fill(email);
await page.getByLabel("Password *", { exact: true }).fill("not-a-password");

const numberOfAttempts = 10;
for (let i = 0; i < numberOfAttempts; i++) {
await page.getByRole("button", { name: "Login" }).click();
}
await expect(page.getByText("Too many requests. Please try again later.")).toBeVisible();
});
2 changes: 2 additions & 0 deletions integration_tests/auth/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from faker import Faker
from phoenix.auth import DEFAULT_SECRET_LENGTH
from phoenix.config import (
ENV_PHOENIX_DISABLE_RATE_LIMIT,
ENV_PHOENIX_ENABLE_AUTH,
ENV_PHOENIX_SECRET,
ENV_PHOENIX_SMTP_HOSTNAME,
Expand Down Expand Up @@ -41,6 +42,7 @@ def _app(
) -> Iterator[None]:
values = (
(ENV_PHOENIX_ENABLE_AUTH, "true"),
(ENV_PHOENIX_DISABLE_RATE_LIMIT, "true"),
(ENV_PHOENIX_SECRET, _secret),
(ENV_PHOENIX_SMTP_HOSTNAME, "127.0.0.1"),
(ENV_PHOENIX_SMTP_PORT, str(pick_unused_port())),
Expand Down
8 changes: 8 additions & 0 deletions src/phoenix/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@

# Authentication settings
ENV_PHOENIX_ENABLE_AUTH = "PHOENIX_ENABLE_AUTH"
ENV_PHOENIX_DISABLE_RATE_LIMIT = "PHOENIX_DISABLE_RATE_LIMIT"
ENV_PHOENIX_SECRET = "PHOENIX_SECRET"
ENV_PHOENIX_API_KEY = "PHOENIX_API_KEY"
ENV_PHOENIX_USE_SECURE_COOKIES = "PHOENIX_USE_SECURE_COOKIES"
Expand Down Expand Up @@ -235,6 +236,13 @@ def get_env_enable_auth() -> bool:
return _bool_val(ENV_PHOENIX_ENABLE_AUTH, False)


def get_env_disable_rate_limit() -> bool:
"""
Gets the value of the PHOENIX_DISABLE_RATE_LIMIT environment variable.
"""
return _bool_val(ENV_PHOENIX_DISABLE_RATE_LIMIT, False)


def get_env_phoenix_secret() -> Optional[str]:
"""
Gets the value of the PHOENIX_SECRET environment variable
Expand Down
20 changes: 10 additions & 10 deletions src/phoenix/server/api/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
set_refresh_token_cookie,
validate_password_format,
)
from phoenix.config import get_base_url
from phoenix.config import get_base_url, get_env_disable_rate_limit
from phoenix.db import models
from phoenix.server.bearer_auth import PhoenixUser, create_access_and_refresh_tokens
from phoenix.server.email.templates.types import PasswordResetTemplateBody
Expand All @@ -48,23 +48,23 @@

rate_limiter = ServerRateLimiter(
per_second_rate_limit=0.2,
enforcement_window_seconds=30,
enforcement_window_seconds=60,
partition_seconds=60,
active_partitions=2,
)
login_rate_limiter = fastapi_ip_rate_limiter(
rate_limiter,
paths=[
"/login",
"/logout",
"/refresh",
"/password-reset-email",
"/password-reset",
"/auth/login",
"/auth/logout",
"/auth/refresh",
"/auth/password-reset-email",
"/auth/password-reset",
],
)
router = APIRouter(
prefix="/auth", include_in_schema=False, dependencies=[Depends(login_rate_limiter)]
)

auth_dependencies = [Depends(login_rate_limiter)] if not get_env_disable_rate_limit() else []
router = APIRouter(prefix="/auth", include_in_schema=False, dependencies=auth_dependencies)


@router.post("/login")
Expand Down
12 changes: 10 additions & 2 deletions src/phoenix/server/api/routers/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
set_oauth2_state_cookie,
set_refresh_token_cookie,
)
from phoenix.config import get_env_disable_rate_limit
from phoenix.db import models
from phoenix.db.enums import UserRole
from phoenix.server.bearer_auth import create_access_and_refresh_tokens
Expand Down Expand Up @@ -64,8 +65,15 @@
include_in_schema=False,
)

if not get_env_disable_rate_limit():
login_dependencies = [Depends(login_rate_limiter)]
create_tokens_dependencies = [Depends(create_tokens_rate_limiter)]
else:
login_dependencies = []
create_tokens_dependencies = []

@router.post("/{idp_name}/login", dependencies=[Depends(login_rate_limiter)])

@router.post("/{idp_name}/login", dependencies=login_dependencies)
async def login(
request: Request,
idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
Expand Down Expand Up @@ -99,7 +107,7 @@ async def login(
return response


@router.get("/{idp_name}/tokens", dependencies=[Depends(create_tokens_rate_limiter)])
@router.get("/{idp_name}/tokens", dependencies=create_tokens_dependencies)
async def create_tokens(
request: Request,
idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
Expand Down

0 comments on commit e79a54c

Please sign in to comment.