diff --git a/app/src/Routes.tsx b/app/src/Routes.tsx index 65dc68309b..d97d015dce 100644 --- a/app/src/Routes.tsx +++ b/app/src/Routes.tsx @@ -25,6 +25,7 @@ import { ExperimentComparePage, experimentsLoader, ExperimentsPage, + ForgotPasswordPage, homeLoader, LoginPage, ModelPage, @@ -37,6 +38,7 @@ import { ProjectsRoot, resetPasswordLoader, ResetPasswordPage, + ResetPasswordWithTokenPage, SettingsPage, TracePage, TracingRoot, @@ -51,6 +53,11 @@ const router = createBrowserRouter( element={} loader={resetPasswordLoader} /> + } + /> + } /> } loader={authenticatedRootLoader}> }> (null); + const [error, setError] = useState(null); + const [isLoading, setIsLoading] = useState(false); + const onSubmit = useCallback( + async (params: ForgotPasswordFormParams) => { + setMessage(null); + setError(null); + setIsLoading(true); + try { + const response = await fetch("/auth/password-reset-email", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(params), + }); + if (!response.ok) { + setError("Failed attempt"); + return; + } + } catch (error) { + setError("Failed attempt"); + return; + } finally { + setIsLoading(() => false); + } + setMessage( + "A link to reset your password has been sent. Check your email for details." + ); + }, + [setMessage, setError] + ); + const { control, handleSubmit } = useForm({ + defaultValues: { email: "" }, + }); + return ( + <> + {message ? ( + + {message} + + ) : null} + {error ? ( + + {error} + + ) : null} +
+ ( + + )} + /> +
+ +
+ + + ); +} diff --git a/app/src/pages/auth/ForgotPasswordPage.tsx b/app/src/pages/auth/ForgotPasswordPage.tsx new file mode 100644 index 0000000000..3c5a83181b --- /dev/null +++ b/app/src/pages/auth/ForgotPasswordPage.tsx @@ -0,0 +1,20 @@ +import React from "react"; + +import { Flex, View } from "@arizeai/components"; + +import { AuthLayout } from "./AuthLayout"; +import { ForgotPasswordForm } from "./ForgotPasswordForm"; +import { PhoenixLogo } from "./PhoenixLogo"; + +export function ForgotPasswordPage() { + return ( + + + + + + + + + ); +} diff --git a/app/src/pages/auth/LoginForm.tsx b/app/src/pages/auth/LoginForm.tsx index 39e825a7d3..a7bc185966 100644 --- a/app/src/pages/auth/LoginForm.tsx +++ b/app/src/pages/auth/LoginForm.tsx @@ -5,6 +5,7 @@ import { css } from "@emotion/react"; import { Alert, Button, Form, TextField, View } from "@arizeai/components"; +import { Link } from "@phoenix/components"; import { getReturnUrl } from "@phoenix/utils/routingUtils"; type LoginFormParams = { @@ -105,6 +106,7 @@ export function LoginForm() { > Login + Forgot password? diff --git a/app/src/pages/auth/ResetPasswordWithTokenForm.tsx b/app/src/pages/auth/ResetPasswordWithTokenForm.tsx new file mode 100644 index 0000000000..c1fa7f3eb6 --- /dev/null +++ b/app/src/pages/auth/ResetPasswordWithTokenForm.tsx @@ -0,0 +1,157 @@ +import React, { useCallback, useState } from "react"; +import { Controller, useForm } from "react-hook-form"; +import { useNavigate } from "react-router"; + +import { + Alert, + Button, + Flex, + Form, + TextField, + View, +} from "@arizeai/components"; + +const MIN_PASSWORD_LENGTH = 4; + +export type ResetPasswordWithTokenFormParams = { + resetToken: string; + newPassword: string; + confirmPassword: string; +}; + +interface ResetPasswordWithTokenFormProps { + resetToken: string; +} + +export function ResetPasswordWithTokenForm({ + resetToken, +}: ResetPasswordWithTokenFormProps) { + const navigate = useNavigate(); + const [message, setMessage] = useState(null); + const [error, setError] = useState(null); + const [isLoading, setIsLoading] = useState(false); + const onSubmit = useCallback( + async ({ resetToken, newPassword }: ResetPasswordWithTokenFormParams) => { + setMessage(null); + setError(null); + setIsLoading(true); + try { + const response = await fetch("/auth/password-reset", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ token: resetToken, password: newPassword }), + }); + if (!response.ok) { + setError("Failed attempt"); + return; + } + } catch (error) { + setError("Failed attempt"); + return; + } finally { + setIsLoading(() => false); + } + setMessage("Success"); + navigate("/login"); + }, + [setMessage, setError, navigate] + ); + const { + control, + handleSubmit, + formState: { isDirty }, + } = useForm({ + defaultValues: { + resetToken: resetToken, + newPassword: "", + confirmPassword: "", + }, + }); + return ( + <> + {message ? ( + + {message} + + ) : null} + {error ? ( + + {error} + + ) : null} +
+ ( + + )} + /> + + value === formValues.newPassword || "Passwords do not match", + }} + render={({ + field: { name, onChange, onBlur, value }, + fieldState: { invalid, error }, + }) => ( + + )} + /> + + + + + + + + ); +} diff --git a/app/src/pages/auth/ResetPasswordWithTokenPage.tsx b/app/src/pages/auth/ResetPasswordWithTokenPage.tsx new file mode 100644 index 0000000000..6ea2075d18 --- /dev/null +++ b/app/src/pages/auth/ResetPasswordWithTokenPage.tsx @@ -0,0 +1,29 @@ +import React from "react"; +import { useNavigate } from "react-router"; +import { useSearchParams } from "react-router-dom"; + +import { Flex, View } from "@arizeai/components"; + +import { AuthLayout } from "./AuthLayout"; +import { PhoenixLogo } from "./PhoenixLogo"; +import { ResetPasswordWithTokenForm } from "./ResetPasswordWithTokenForm"; + +export function ResetPasswordWithTokenPage() { + const navigate = useNavigate(); + const [searchParams] = useSearchParams(); + const token = searchParams.get("token"); + if (!token) { + navigate("/login"); + return null; + } + return ( + + + + + + + + + ); +} diff --git a/app/src/pages/auth/index.tsx b/app/src/pages/auth/index.tsx index da907c439a..42edb98745 100644 --- a/app/src/pages/auth/index.tsx +++ b/app/src/pages/auth/index.tsx @@ -1,3 +1,5 @@ export * from "./LoginPage"; export * from "./ResetPasswordPage"; +export * from "./ResetPasswordWithTokenPage"; export * from "./resetPasswordLoader"; +export * from "./ForgotPasswordPage"; diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py index c92ee184ac..f85b0c7982 100644 --- a/integration_tests/_helpers.py +++ b/integration_tests/_helpers.py @@ -9,6 +9,7 @@ from contextvars import ContextVar from dataclasses import dataclass, replace from datetime import datetime, timezone +from email.message import Message from functools import cached_property from io import BytesIO from subprocess import PIPE, STDOUT @@ -32,12 +33,14 @@ Union, cast, ) -from urllib.parse import urljoin +from urllib.parse import parse_qs, urljoin, urlparse from urllib.request import urlopen +import bs4 import httpx import jwt import pytest +import smtpdfix from httpx import Headers, HTTPStatusError from openinference.semconv.resource import ResourceAttributes from opentelemetry.sdk.resources import Resource @@ -58,6 +61,7 @@ get_env_database_schema, get_env_grpc_port, get_env_host, + get_env_smtp_mail_from, ) from phoenix.server.api.auth import IsAdmin from phoenix.server.api.exceptions import Unauthorized @@ -217,6 +221,19 @@ def delete_api_key(self, api_key: _ApiKey, /) -> None: def export_embeddings(self, filename: str) -> None: _export_embeddings(self, filename=filename) + def initiate_password_reset( + self, + smtpd: smtpdfix.AuthController, + /, + *, + should_receive_email: bool = True, + ) -> Optional[_PasswordResetToken]: + return _initiate_password_reset( + self.email, + smtpd, + should_receive_email=should_receive_email, + ) + _SYSTEM_USER_GID = _GqlId(GlobalID(type_name="User", node_id="1")) _DEFAULT_ADMIN = _User( @@ -257,6 +274,11 @@ def kind(self) -> _ApiKeyKind: class _Token(_String, ABC): ... +class _PasswordResetToken(_Token): + def reset(self, password: _Password, /) -> None: + return _reset_password(self, password=password) + + class _AccessToken(_Token, _CanLogOut[None]): def log_out(self) -> None: _log_out(self) @@ -856,6 +878,37 @@ def _log_out( resp.raise_for_status() +def _initiate_password_reset( + email: _Email, + smtpd: smtpdfix.AuthController, + /, + *, + should_receive_email: bool = True, +) -> Optional[_PasswordResetToken]: + old_msg_count = len(smtpd.messages) + json_ = dict(email=email) + resp = _httpx_client().post("auth/password-reset-email", json=json_) + resp.raise_for_status() + new_msg_count = len(smtpd.messages) - old_msg_count + assert new_msg_count == int(should_receive_email) + if not should_receive_email: + return None + msg = smtpd.messages[-1] + assert msg["to"] == email + assert msg["from"] == get_env_smtp_mail_from() + return _extract_password_reset_token(msg) + + +def _reset_password( + token: _PasswordResetToken, + /, + password: _Password, +) -> None: + json_ = dict(token=token, password=password) + resp = _httpx_client().post("auth/password-reset", json=json_) + resp.raise_for_status() + + def _export_embeddings(auth: Optional[_SecurityArtifact] = None, /, *, filename: str) -> None: resp = _httpx_client(auth).get("/exports", params={"filename": filename}) resp.raise_for_status() @@ -912,3 +965,29 @@ def _decode_token_ids( jwt.decode(v, options={"verify_signature": False})["jti"] for v in _extract_tokens(headers, key).values() ] + + +def _extract_password_reset_token(msg: Message) -> _PasswordResetToken: + assert (soup := _extract_html(msg)) + assert isinstance((link := soup.find(id="reset-url")), bs4.Tag) + assert isinstance((url := link.get("href")), str) + assert url + params = parse_qs(urlparse(url).query) + assert (tokens := params["token"]) + assert (token := tokens[0]) + decoded = jwt.decode(token, options=dict(verify_signature=False)) + assert (jti := decoded["jti"]) + assert jti.startswith("PasswordResetToken") + return _PasswordResetToken(token) + + +def _extract_html(msg: Message) -> Optional[bs4.BeautifulSoup]: + for part in msg.walk(): + if ( + part.get_content_type() == "text/html" + and (payload := part.get_payload(decode=True)) + and isinstance(payload, bytes) + ): + content = payload.decode(part.get_content_charset() or "utf-8") + return bs4.BeautifulSoup(content, "html.parser") + return None diff --git a/integration_tests/auth/conftest.py b/integration_tests/auth/conftest.py index 3cc1d9b884..4ed2dccce3 100644 --- a/integration_tests/auth/conftest.py +++ b/integration_tests/auth/conftest.py @@ -5,8 +5,25 @@ from unittest import mock import pytest +from faker import Faker from phoenix.auth import DEFAULT_SECRET_LENGTH -from phoenix.config import ENV_PHOENIX_ENABLE_AUTH, ENV_PHOENIX_SECRET +from phoenix.config import ( + ENV_PHOENIX_ENABLE_AUTH, + ENV_PHOENIX_SECRET, + ENV_PHOENIX_SMTP_HOSTNAME, + ENV_PHOENIX_SMTP_MAIL_FROM, + ENV_PHOENIX_SMTP_PASSWORD, + ENV_PHOENIX_SMTP_PORT, + ENV_PHOENIX_SMTP_USERNAME, + ENV_PHOENIX_SMTP_VALIDATE_CERTS, + get_env_smtp_hostname, + get_env_smtp_password, + get_env_smtp_port, + get_env_smtp_username, +) +from portpicker import pick_unused_port # type: ignore[import-untyped] +from smtpdfix import AuthController, Config, SMTPDFix +from smtpdfix.certs import _generate_certs from .._helpers import _Secret, _server @@ -20,12 +37,39 @@ def _secret() -> _Secret: def _app( _secret: _Secret, _env_phoenix_sql_database_url: Any, + _fake: Faker, ) -> Iterator[None]: values = ( (ENV_PHOENIX_ENABLE_AUTH, "true"), (ENV_PHOENIX_SECRET, _secret), + (ENV_PHOENIX_SMTP_HOSTNAME, "127.0.0.1"), + (ENV_PHOENIX_SMTP_PORT, str(pick_unused_port())), + (ENV_PHOENIX_SMTP_USERNAME, "test"), + (ENV_PHOENIX_SMTP_PASSWORD, "test"), + (ENV_PHOENIX_SMTP_MAIL_FROM, _fake.email()), + (ENV_PHOENIX_SMTP_VALIDATE_CERTS, "false"), ) with ExitStack() as stack: stack.enter_context(mock.patch.dict(os.environ, values)) stack.enter_context(_server()) yield + + +@pytest.fixture(scope="module") +def _smtpd( + _app: Any, + tmp_path_factory: pytest.TempPathFactory, +) -> Iterator[AuthController]: + path = tmp_path_factory.mktemp("certs") + cert, _ = _generate_certs(path, separate_key=False) + os.environ["SMTPD_SSL_CERTIFICATE_FILE"] = str(cert.resolve()) + config = Config() + config.login_username = get_env_smtp_username() + config.login_password = get_env_smtp_password() + config.use_starttls = True + with SMTPDFix( + hostname=get_env_smtp_hostname(), + port=get_env_smtp_port(), + config=config, + ) as controller: + yield controller diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index f0f4e8f1d3..876b801f68 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -15,8 +15,10 @@ TypeVar, ) +import httpx import jwt import pytest +import smtpdfix from httpx import HTTPStatusError from opentelemetry.sdk.environment_variables import ( OTEL_EXPORTER_OTLP_HEADERS, @@ -43,6 +45,7 @@ _ApiKey, _create_user, _DefaultAdminTokenSequestration, + _Email, _Expectation, _export_embeddings, _GetUser, @@ -50,6 +53,7 @@ _grpc_span_exporter, _Headers, _http_span_exporter, + _initiate_password_reset, _log_in, _LoggedInUser, _Password, @@ -123,6 +127,128 @@ def test_cannot_log_in_with_deleted_user( user.log_in() +class TestPasswordReset: + def test_initiate_password_reset_does_not_reveal_whether_user_exists( + self, + _emails: Iterator[_Email], + _smtpd: smtpdfix.AuthController, + ) -> None: + email = next(_emails) + assert not _initiate_password_reset(email, _smtpd, should_receive_email=False) + + @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN]) + def test_initiate_password_reset_does_not_change_existing_password( + self, + role_or_user: _RoleOrUser, + _get_user: _GetUser, + _smtpd: smtpdfix.AuthController, + ) -> None: + u = _get_user(role_or_user) + assert u.initiate_password_reset(_smtpd) + u.log_in() + + @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN]) + def test_password_reset_cannot_be_initiated_again_while_in_progress( + self, + role_or_user: _RoleOrUser, + _get_user: _GetUser, + _smtpd: smtpdfix.AuthController, + ) -> None: + u = _get_user(role_or_user) + assert u.initiate_password_reset(_smtpd) + with pytest.raises(httpx.HTTPStatusError): + assert u.initiate_password_reset(_smtpd) + + @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN]) + def test_password_reset_can_be_initiated_immediately_after_password_reset( + self, + role_or_user: _RoleOrUser, + _get_user: _GetUser, + _passwords: Iterator[_Password], + _smtpd: smtpdfix.AuthController, + ) -> None: + u = _get_user(role_or_user) + new_password = next(_passwords) + assert new_password != u.password + assert (token := u.initiate_password_reset(_smtpd)) + token.reset(new_password) + assert u.initiate_password_reset(_smtpd) + + @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN]) + def test_password_reset_token_is_single_use( + self, + role_or_user: _RoleOrUser, + _get_user: _GetUser, + _passwords: Iterator[_Password], + _smtpd: smtpdfix.AuthController, + ) -> None: + u = _get_user(role_or_user) + new_password = next(_passwords) + assert new_password != u.password + newer_password = next(_passwords) + assert newer_password != new_password + assert (token := u.initiate_password_reset(_smtpd)) + token.reset(new_password) + with _EXPECTATION_401: + token.reset(newer_password) + + @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN]) + def test_initiate_password_reset_and_then_reset_password_using_token_from_email( + self, + role_or_user: _RoleOrUser, + _get_user: _GetUser, + _passwords: Iterator[_Password], + _smtpd: smtpdfix.AuthController, + ) -> None: + u = _get_user(role_or_user) + logged_in_user = u.log_in() + logged_in_user.create_api_key() + assert (token := u.initiate_password_reset(_smtpd)) + new_password = next(_passwords) + assert new_password != u.password + token.reset(new_password) + with _EXPECTATION_401: + # old password should no longer work + u.log_in() + with _EXPECTATION_401: + # old logged-in tokens should no longer work + logged_in_user.create_api_key() + # new password should work + new_profile = replace(u.profile, password=new_password) + replace(u, profile=new_profile).log_in().create_api_key() + + @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN]) + def test_deleted_user_will_not_receive_email_after_initiating_password_reset( + self, + role_or_user: _RoleOrUser, + _get_user: _GetUser, + _smtpd: smtpdfix.AuthController, + ) -> None: + u = _get_user(role_or_user) + logged_in_user = u.log_in() + logged_in_user.create_api_key() + _DEFAULT_ADMIN.delete_users(u) + assert not u.initiate_password_reset(_smtpd, should_receive_email=False) + + @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN]) + def test_deleted_user_cannot_reset_password_using_token_from_email( + self, + role_or_user: _RoleOrUser, + _get_user: _GetUser, + _passwords: Iterator[_Password], + _smtpd: smtpdfix.AuthController, + ) -> None: + u = _get_user(role_or_user) + logged_in_user = u.log_in() + logged_in_user.create_api_key() + assert (token := u.initiate_password_reset(_smtpd)) + new_password = next(_passwords) + assert new_password != u.password + _DEFAULT_ADMIN.delete_users(u) + with _EXPECTATION_401: + token.reset(new_password) + + class TestLogOut: def test_default_admin_cannot_log_out_during_testing(self) -> None: """ diff --git a/integration_tests/requirements.txt b/integration_tests/requirements.txt index 681b6a7077..c9b09fe8ce 100644 --- a/integration_tests/requirements.txt +++ b/integration_tests/requirements.txt @@ -1,3 +1,4 @@ +beautifulsoup4 faker httpx openinference-semantic-conventions @@ -5,4 +6,6 @@ opentelemetry-sdk portpicker psutil pytest-randomly +pytest-smtpd +types-beautifulsoup4 types-psutil diff --git a/pyproject.toml b/pyproject.toml index 1eaea37c10..aab7b27b69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dependencies = [ "arize-phoenix-evals>=0.13.1", "arize-phoenix-otel>=0.4.1", "fastapi", + "fastapi-mail", "pydantic>=1.0,!=2.0.*,<3", # exclude 2.0.* since it does not support the `json_encoders` configuration setting "pyjwt", ] diff --git a/src/phoenix/auth.py b/src/phoenix/auth.py index ece7705fab..83713aa606 100644 --- a/src/phoenix/auth.py +++ b/src/phoenix/auth.py @@ -184,6 +184,9 @@ def validate( DEFAULT_ADMIN_EMAIL = "admin@localhost" DEFAULT_ADMIN_PASSWORD = "admin" DEFAULT_SECRET_LENGTH = 32 +DEFAULT_PASSWORD_RESET_TOKEN_EXPIRY_MINUTES = 15 +DEFAULT_ACCESS_TOKEN_EXPIRY_MINUTES = 10 +DEFAULT_REFRESH_TOKEN_EXPIRY_MINUTES = 60 * 24 * 7 """The default length of a secret key in bytes.""" EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+[.][^@\s]+\Z") """The regular expression pattern for a valid email address.""" diff --git a/src/phoenix/config.py b/src/phoenix/config.py index 78f8cc626f..2f3dcf8377 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py @@ -4,9 +4,7 @@ from datetime import timedelta from logging import getLogger from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import pandas as pd +from typing import Dict, List, Optional, Tuple, overload from phoenix.utilities.re import parse_env_headers @@ -74,8 +72,44 @@ ENV_PHOENIX_SECRET = "PHOENIX_SECRET" ENV_PHOENIX_API_KEY = "PHOENIX_API_KEY" ENV_PHOENIX_USE_SECURE_COOKIES = "PHOENIX_USE_SECURE_COOKIES" -ENV_PHOENIX_ACCESS_TOKEN_EXPIRY = "PHOENIX_ACCESS_TOKEN_EXPIRY" -ENV_PHOENIX_REFRESH_TOKEN_EXPIRY = "PHOENIX_REFRESH_TOKEN_EXPIRY" +ENV_PHOENIX_ACCESS_TOKEN_EXPIRY_MINUTES = "PHOENIX_ACCESS_TOKEN_EXPIRY_MINUTES" +""" +The duration, in minutes, before access tokens expire. +""" +ENV_PHOENIX_REFRESH_TOKEN_EXPIRY_MINUTES = "PHOENIX_REFRESH_TOKEN_EXPIRY_MINUTES" +""" +The duration, in minutes, before refresh tokens expire. +""" +ENV_PHOENIX_PASSWORD_RESET_TOKEN_EXPIRY_MINUTES = "PHOENIX_PASSWORD_RESET_TOKEN_EXPIRY_MINUTES" +""" +The duration, in minutes, before password reset tokens expire. +""" + +# SMTP settings +ENV_PHOENIX_SMTP_HOSTNAME = "PHOENIX_SMTP_HOSTNAME" +""" +The SMTP server hostname to use for sending emails. SMTP is disabled if this is not set. +""" +ENV_PHOENIX_SMTP_PORT = "PHOENIX_SMTP_PORT" +""" +The SMTP server port to use for sending emails. Defaults to 587. +""" +ENV_PHOENIX_SMTP_USERNAME = "PHOENIX_SMTP_USERNAME" +""" +The SMTP server username to use for sending emails. Should be set if SMTP is enabled. +""" +ENV_PHOENIX_SMTP_PASSWORD = "PHOENIX_SMTP_PASSWORD" +""" +The SMTP server password to use for sending emails. Should be set if SMTP is enabled. +""" +ENV_PHOENIX_SMTP_MAIL_FROM = "PHOENIX_SMTP_MAIL_FROM" +""" +The email address to use as the sender when sending emails. Should be set if SMTP is enabled. +""" +ENV_PHOENIX_SMTP_VALIDATE_CERTS = "PHOENIX_SMTP_VALIDATE_CERTS" +""" +Whether to validate SMTP server certificates. Defaults to true. +""" def server_instrumentation_is_enabled() -> bool: @@ -117,12 +151,16 @@ def get_working_dir() -> Path: return Path.home().resolve() / ".phoenix" -def get_boolean_env_var(env_var: str) -> Optional[bool]: +@overload +def _bool_val(env_var: str) -> Optional[bool]: ... +@overload +def _bool_val(env_var: str, default: bool) -> bool: ... +def _bool_val(env_var: str, default: Optional[bool] = None) -> Optional[bool]: """ - Parses a boolean environment variable, returning None if the variable is not set. + Parses a boolean environment variable, returning `default` if the variable is not set. """ if (value := os.environ.get(env_var)) is None: - return None + return default assert (lower := value.lower()) in ( "true", "false", @@ -130,11 +168,49 @@ def get_boolean_env_var(env_var: str) -> Optional[bool]: return lower == "true" +@overload +def _float_val(env_var: str) -> Optional[float]: ... +@overload +def _float_val(env_var: str, default: float) -> float: ... +def _float_val(env_var: str, default: Optional[float] = None) -> Optional[float]: + """ + Parses a numeric environment variable, returning `default` if the variable is not set. + """ + if (value := os.environ.get(env_var)) is None: + return default + try: + return float(value) + except ValueError: + raise ValueError( + f"Invalid value for environment variable {env_var}: {value}. " + f"Value must be a number." + ) + + +@overload +def _int_val(env_var: str) -> Optional[int]: ... +@overload +def _int_val(env_var: str, default: int) -> int: ... +def _int_val(env_var: str, default: Optional[int] = None) -> Optional[int]: + """ + Parses a numeric environment variable, returning `default` if the variable is not set. + """ + if (value := os.environ.get(env_var)) is None: + return default + try: + return int(value) + except ValueError: + raise ValueError( + f"Invalid value for environment variable {env_var}: {value}. " + f"Value must be an integer." + ) + + def get_env_enable_auth() -> bool: """ Gets the value of the PHOENIX_ENABLE_AUTH environment variable. """ - return get_boolean_env_var(ENV_PHOENIX_ENABLE_AUTH) is True + return _bool_val(ENV_PHOENIX_ENABLE_AUTH, False) def get_env_phoenix_secret() -> Optional[str]: @@ -152,7 +228,7 @@ def get_env_phoenix_secret() -> Optional[str]: def get_env_phoenix_use_secure_cookies() -> bool: - return bool(get_boolean_env_var(ENV_PHOENIX_USE_SECURE_COOKIES)) + return _bool_val(ENV_PHOENIX_USE_SECURE_COOKIES, False) def get_env_phoenix_api_key() -> Optional[str]: @@ -173,48 +249,72 @@ def get_env_auth_settings() -> Tuple[bool, Optional[str]]: return enable_auth, phoenix_secret +def get_env_password_reset_token_expiry() -> timedelta: + """ + Gets the password reset token expiry. + """ + from phoenix.auth import DEFAULT_PASSWORD_RESET_TOKEN_EXPIRY_MINUTES + + minutes = _float_val( + ENV_PHOENIX_PASSWORD_RESET_TOKEN_EXPIRY_MINUTES, + DEFAULT_PASSWORD_RESET_TOKEN_EXPIRY_MINUTES, + ) + assert minutes > 0 + return timedelta(minutes=minutes) + + def get_env_access_token_expiry() -> timedelta: """ Gets the access token expiry. """ - if (access_token_expiry := os.environ.get(ENV_PHOENIX_ACCESS_TOKEN_EXPIRY)) is None: - return timedelta(minutes=10) - try: - return _parse_duration(access_token_expiry) - except ValueError as error: - raise ValueError( - f"Error reading {ENV_PHOENIX_ACCESS_TOKEN_EXPIRY} environment variable: {str(error)}" - ) + from phoenix.auth import DEFAULT_ACCESS_TOKEN_EXPIRY_MINUTES + + minutes = _float_val( + ENV_PHOENIX_ACCESS_TOKEN_EXPIRY_MINUTES, + DEFAULT_ACCESS_TOKEN_EXPIRY_MINUTES, + ) + assert minutes > 0 + return timedelta(minutes=minutes) def get_env_refresh_token_expiry() -> timedelta: """ Gets the refresh token expiry. """ - if (refresh_token_expiry := os.environ.get(ENV_PHOENIX_REFRESH_TOKEN_EXPIRY)) is None: - return timedelta(weeks=1) - try: - return _parse_duration(refresh_token_expiry) - except ValueError as error: - raise ValueError( - f"Error reading {ENV_PHOENIX_REFRESH_TOKEN_EXPIRY} environment variable: {str(error)}" - ) + from phoenix.auth import DEFAULT_REFRESH_TOKEN_EXPIRY_MINUTES + minutes = _float_val( + ENV_PHOENIX_REFRESH_TOKEN_EXPIRY_MINUTES, + DEFAULT_REFRESH_TOKEN_EXPIRY_MINUTES, + ) + assert minutes > 0 + return timedelta(minutes=minutes) -def _parse_duration(duration_str: str) -> timedelta: - """ - Parses a duration string into a timedelta object, assuming the duration is - in seconds if no unit is provided. - """ - try: - duration = timedelta(seconds=float(duration_str)) - except ValueError: - duration = pd.Timedelta(duration_str) - if pd.isnull(duration): - raise ValueError("duration cannot be null") - if duration <= timedelta(0): - raise ValueError("duration must be positive") - return duration + +def get_env_smtp_username() -> str: + return os.getenv(ENV_PHOENIX_SMTP_USERNAME) or "" + + +def get_env_smtp_password() -> str: + return os.getenv(ENV_PHOENIX_SMTP_PASSWORD) or "" + + +def get_env_smtp_mail_from() -> str: + return os.getenv(ENV_PHOENIX_SMTP_MAIL_FROM) or "" + + +def get_env_smtp_hostname() -> str: + return os.getenv(ENV_PHOENIX_SMTP_HOSTNAME) or "" + + +def get_env_smtp_port() -> int: + port = _int_val(ENV_PHOENIX_SMTP_PORT, 587) + assert 0 < port <= 65_535 + return port + + +def get_env_smtp_validate_certs() -> bool: + return _bool_val(ENV_PHOENIX_SMTP_VALIDATE_CERTS, True) PHOENIX_DIR = Path(__file__).resolve().parent diff --git a/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py b/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py index 27c8199b98..e351f4957f 100644 --- a/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +++ b/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py @@ -71,6 +71,25 @@ def upgrade() -> None: sa.CheckConstraint("password_hash is null or password_salt is not null", name="salt"), sqlite_autoincrement=True, ) + op.create_table( + "password_reset_tokens", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column( + "user_id", + sa.Integer, + sa.ForeignKey("users.id", ondelete="CASCADE"), + unique=True, + index=True, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False, index=True), + sqlite_autoincrement=True, + ) op.create_table( "refresh_tokens", sa.Column("id", sa.Integer, primary_key=True), @@ -125,5 +144,6 @@ def downgrade() -> None: op.drop_table("api_keys") op.drop_table("access_tokens") op.drop_table("refresh_tokens") + op.drop_table("password_reset_tokens") op.drop_table("users") op.drop_table("user_roles") diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index fa33c472c8..5a81d57820 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -648,6 +648,11 @@ class User(Base): UtcTimeStamp, server_default=func.now(), onupdate=func.now() ) deleted_at: Mapped[Optional[datetime]] = mapped_column(UtcTimeStamp) + password_reset_token: Mapped["PasswordResetToken"] = relationship( + "PasswordResetToken", + back_populates="user", + uselist=False, + ) access_tokens: Mapped[List["AccessToken"]] = relationship("AccessToken", back_populates="user") refresh_tokens: Mapped[List["RefreshToken"]] = relationship( "RefreshToken", back_populates="user" @@ -659,6 +664,20 @@ class User(Base): ) +class PasswordResetToken(Base): + __tablename__ = "password_reset_tokens" + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column( + ForeignKey("users.id", ondelete="CASCADE"), + unique=True, + index=True, + ) + user: Mapped["User"] = relationship("User", back_populates="password_reset_token") + created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now()) + expires_at: Mapped[Optional[datetime]] = mapped_column(UtcTimeStamp, nullable=False, index=True) + __table_args__ = (dict(sqlite_autoincrement=True),) + + class RefreshToken(Base): __tablename__ = "refresh_tokens" id: Mapped[int] = mapped_column(primary_key=True) diff --git a/src/phoenix/server/api/mutations/user_mutations.py b/src/phoenix/server/api/mutations/user_mutations.py index 889403c7ed..cfae3afba3 100644 --- a/src/phoenix/server/api/mutations/user_mutations.py +++ b/src/phoenix/server/api/mutations/user_mutations.py @@ -29,7 +29,7 @@ from phoenix.server.api.types.node import from_global_id_with_expected_type from phoenix.server.api.types.User import User, to_gql_user from phoenix.server.bearer_auth import PhoenixUser -from phoenix.server.types import AccessTokenId, ApiKeyId, RefreshTokenId +from phoenix.server.types import AccessTokenId, ApiKeyId, PasswordResetTokenId, RefreshTokenId @strawberry.input @@ -269,6 +269,14 @@ async def delete_users( raise Conflict("Cannot delete the default admin user") if num_resolved_user_ids < len(user_ids): raise NotFound("Some user IDs could not be found") + password_reset_token_ids = [ + PasswordResetTokenId(id_) + async for id_ in await session.stream_scalars( + delete(models.PasswordResetToken) + .where(models.PasswordResetToken.user_id.in_(user_ids)) + .returning(models.PasswordResetToken.id) + ) + ] access_token_ids = [ AccessTokenId(id_) async for id_ in await session.stream_scalars( @@ -298,7 +306,12 @@ async def delete_users( .where(models.User.id.in_(user_ids)) .values(deleted_at=func.now()) ) - await token_store.revoke(*access_token_ids, *refresh_token_ids, *api_key_ids) + await token_store.revoke( + *password_reset_token_ids, + *access_token_ids, + *refresh_token_ids, + *api_key_ids, + ) def _select_role_id_by_name(role_name: str) -> Select[Tuple[int]]: diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py index 037ef3c38e..ed77080948 100644 --- a/src/phoenix/server/api/routers/auth.py +++ b/src/phoenix/server/api/routers/auth.py @@ -1,30 +1,44 @@ import asyncio +import secrets from datetime import datetime, timedelta, timezone from functools import partial +from typing import Tuple from fastapi import APIRouter, Depends, HTTPException, Request, Response -from sqlalchemy import and_, select +from sqlalchemy import Select, select from sqlalchemy.orm import joinedload -from starlette.status import HTTP_204_NO_CONTENT, HTTP_401_UNAUTHORIZED, HTTP_404_NOT_FOUND +from starlette.status import ( + HTTP_204_NO_CONTENT, + HTTP_401_UNAUTHORIZED, + HTTP_404_NOT_FOUND, + HTTP_422_UNPROCESSABLE_ENTITY, + HTTP_503_SERVICE_UNAVAILABLE, +) from phoenix.auth import ( + DEFAULT_SECRET_LENGTH, PHOENIX_ACCESS_TOKEN_COOKIE_NAME, PHOENIX_REFRESH_TOKEN_COOKIE_NAME, Token, + compute_password_hash, delete_access_token_cookie, delete_refresh_token_cookie, is_valid_password, set_access_token_cookie, set_refresh_token_cookie, + validate_password_format, ) +from phoenix.config import get_base_url +from phoenix.db import enums, models from phoenix.db.enums import UserRole -from phoenix.db.models import User as OrmUser from phoenix.server.bearer_auth import PhoenixUser -from phoenix.server.jwt_store import JwtStore +from phoenix.server.email.templates.types import PasswordResetTemplateBody +from phoenix.server.email.types import EmailSender from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_rate_limiter from phoenix.server.types import ( AccessTokenAttributes, AccessTokenClaims, + PasswordResetTokenClaims, RefreshTokenAttributes, RefreshTokenClaims, TokenStore, @@ -37,7 +51,16 @@ partition_seconds=60, active_partitions=2, ) -login_rate_limiter = fastapi_rate_limiter(rate_limiter, paths=["/login"]) +login_rate_limiter = fastapi_rate_limiter( + rate_limiter, + paths=[ + "/login", + "/logout", + "/refresh", + "/password-reset-email", + "/password-reset", + ], +) router = APIRouter( prefix="/auth", include_in_schema=False, dependencies=[Depends(login_rate_limiter)] ) @@ -56,9 +79,7 @@ async def login(request: Request) -> Response: async with request.app.state.db() as session: user = await session.scalar( - select(OrmUser) - .where(and_(OrmUser.email == email, OrmUser.deleted_at.is_(None))) - .options(joinedload(OrmUser.role)) + _select_active_user().filter_by(email=email).options(joinedload(models.User.role)) ) if ( user is None @@ -83,7 +104,7 @@ async def login(request: Request) -> Response: user_role=UserRole(user.role.name), ), ) - token_store: JwtStore = request.app.state.get_token_store() + token_store: TokenStore = request.app.state.get_token_store() refresh_token, refresh_token_id = await token_store.create_refresh_token(refresh_token_claims) access_token_claims = AccessTokenClaims( subject=UserId(user.id), @@ -110,7 +131,8 @@ async def logout( request: Request, ) -> Response: token_store: TokenStore = request.app.state.get_token_store() - assert isinstance(user := request.user, PhoenixUser) + if not isinstance(user := request.user, PhoenixUser): + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED) await token_store.log_out(user.identity) response = Response(status_code=HTTP_204_NO_CONTENT) response = delete_access_token_cookie(response) @@ -124,7 +146,7 @@ async def refresh_tokens(request: Request) -> Response: assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta) if (refresh_token := request.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) is None: raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Missing refresh token") - token_store: JwtStore = request.app.state.get_token_store() + token_store: TokenStore = request.app.state.get_token_store() refresh_token_claims = await token_store.read(Token(refresh_token)) if ( not isinstance(refresh_token_claims, RefreshTokenClaims) @@ -150,9 +172,7 @@ async def refresh_tokens(request: Request) -> Response: async with request.app.state.db() as session: if ( user := await session.scalar( - select(OrmUser) - .where(and_(OrmUser.id == user_id, OrmUser.deleted_at.is_(None))) - .options(joinedload(OrmUser.role)) + _select_active_user().filter_by(id=user_id).options(joinedload(models.User.role)) ) ) is None: raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="User not found") @@ -187,4 +207,96 @@ async def refresh_tokens(request: Request) -> Response: return response +@router.post("/password-reset-email") +async def initiate_password_reset(request: Request) -> Response: + data = await request.json() + if not (email := data.get("email")): + raise MISSING_EMAIL + sender: EmailSender = request.app.state.email_sender + if sender is None: + raise UNAVAILABLE + assert isinstance(token_expiry := request.app.state.password_reset_token_expiry, timedelta) + async with request.app.state.db() as session: + user = await session.scalar( + _select_active_user() + .filter_by(email=email) + .options( + joinedload(models.User.password_reset_token).load_only(models.PasswordResetToken.id) + ) + ) + if user is None or user.auth_method != enums.AuthMethod.LOCAL.value: + # Withold privileged information + return Response(status_code=HTTP_204_NO_CONTENT) + if user.password_reset_token: + raise RESET_IN_PROGRESS + password_reset_token_claims = PasswordResetTokenClaims( + subject=UserId(user.id), + issued_at=datetime.now(timezone.utc), + expiration_time=datetime.now(timezone.utc) + token_expiry, + ) + token_store: TokenStore = request.app.state.get_token_store() + token, _ = await token_store.create_password_reset_token(password_reset_token_claims) + await sender.send_password_reset_email(email, PasswordResetTemplateBody(token, get_base_url())) + return Response(status_code=HTTP_204_NO_CONTENT) + + +@router.post("/password-reset") +async def reset_password(request: Request) -> Response: + data = await request.json() + if not (password := data.get("password")): + raise MISSING_PASSWORD + token_store: TokenStore = request.app.state.get_token_store() + if ( + not (token := data.get("token")) + or not isinstance((claims := await token_store.read(token)), PasswordResetTokenClaims) + or not claims.expiration_time + or claims.expiration_time < datetime.now(timezone.utc) + ): + raise INVALID_TOKEN + assert (user_id := claims.subject) + async with request.app.state.db() as session: + user = await session.scalar(_select_active_user().filter_by(id=int(user_id))) + if user is None or user.auth_method != enums.AuthMethod.LOCAL.value: + # Withold privileged information + return Response(status_code=HTTP_204_NO_CONTENT) + validate_password_format(password) + user.password_salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH) + loop = asyncio.get_running_loop() + user.password_hash = await loop.run_in_executor( + None, partial(compute_password_hash, password=password, salt=user.password_salt) + ) + async with request.app.state.db() as session: + session.add(user) + await session.flush() + response = Response(status_code=HTTP_204_NO_CONTENT) + assert (token_id := claims.token_id) + await token_store.revoke(token_id) + await token_store.log_out(UserId(user.id)) + return response + + +def _select_active_user() -> Select[Tuple[models.User]]: + return select(models.User).where(models.User.deleted_at.is_(None)) + + LOGIN_FAILED_MESSAGE = "Invalid email and/or password" + +MISSING_EMAIL = HTTPException( + status_code=HTTP_422_UNPROCESSABLE_ENTITY, + detail="Email required", +) +MISSING_PASSWORD = HTTPException( + status_code=HTTP_422_UNPROCESSABLE_ENTITY, + detail="Password required", +) +UNAVAILABLE = HTTPException( + status_code=HTTP_503_SERVICE_UNAVAILABLE, +) +RESET_IN_PROGRESS = HTTPException( + status_code=HTTP_503_SERVICE_UNAVAILABLE, + detail="Password reset already in progress", +) +INVALID_TOKEN = HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid token", +) diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 2070e09f73..c58b2ea4ed 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -96,6 +96,7 @@ from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated from phoenix.server.dml_event import DmlEvent from phoenix.server.dml_event_handler import DmlEventHandler +from phoenix.server.email.types import EmailSender from phoenix.server.grpc_server import GrpcServer from phoenix.server.jwt_store import JwtStore from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider @@ -605,9 +606,11 @@ def create_app( startup_callbacks: Iterable[_Callback] = (), shutdown_callbacks: Iterable[_Callback] = (), secret: Optional[str] = None, + password_reset_token_expiry: Optional[timedelta] = None, access_token_expiry: Optional[timedelta] = None, refresh_token_expiry: Optional[timedelta] = None, scaffolder_config: Optional[ScaffolderConfig] = None, + email_sender: Optional[EmailSender] = None, ) -> FastAPI: startup_callbacks_list: List[_Callback] = list(startup_callbacks) shutdown_callbacks_list: List[_Callback] = list(shutdown_callbacks) @@ -742,9 +745,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: ) app.state.read_only = read_only app.state.export_path = export_path + app.state.password_reset_token_expiry = password_reset_token_expiry app.state.access_token_expiry = access_token_expiry app.state.refresh_token_expiry = refresh_token_expiry app.state.db = db + app.state.email_sender = email_sender app = _add_get_secret_method(app=app, secret=secret) app = _add_get_token_store_method(app=app, token_store=token_store) if tracer_provider: diff --git a/src/phoenix/server/email/__init__.py b/src/phoenix/server/email/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/phoenix/server/email/sender.py b/src/phoenix/server/email/sender.py new file mode 100644 index 0000000000..f44219af70 --- /dev/null +++ b/src/phoenix/server/email/sender.py @@ -0,0 +1,29 @@ +from dataclasses import asdict +from pathlib import Path + +from fastapi_mail import ConnectionConfig, FastMail, MessageSchema + +from phoenix.server.email.templates.types import PasswordResetTemplateBody + +EMAIL_TEMPLATE_FOLDER = Path(__file__).parent / "templates" + + +class FastMailSender: + def __init__(self, conf: ConnectionConfig) -> None: + self._fm = FastMail(conf) + + async def send_password_reset_email( + self, + email: str, + values: PasswordResetTemplateBody, + ) -> None: + message = MessageSchema( + subject="Password Reset Request", + recipients=[email], + template_body=asdict(values), + subtype="html", + ) + await self._fm.send_message( + message, + template_name="password_reset.html", + ) diff --git a/src/phoenix/server/email/templates/__init__.py b/src/phoenix/server/email/templates/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/phoenix/server/email/templates/password_reset.html b/src/phoenix/server/email/templates/password_reset.html new file mode 100644 index 0000000000..aab4b3dddc --- /dev/null +++ b/src/phoenix/server/email/templates/password_reset.html @@ -0,0 +1,15 @@ + + + + + Password Reset + + +

Hello.

+

You have requested a password reset. Please click on the link below to reset your password:

+

+ Reset Password +

+

If you did not make this request, please contact your administrator.

+ + diff --git a/src/phoenix/server/email/templates/types.py b/src/phoenix/server/email/templates/types.py new file mode 100644 index 0000000000..a7cf1c0e3b --- /dev/null +++ b/src/phoenix/server/email/templates/types.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from phoenix.server.types import PasswordResetToken + + +@dataclass(frozen=True) +class PasswordResetTemplateBody: + token: PasswordResetToken + base_url: str diff --git a/src/phoenix/server/email/types.py b/src/phoenix/server/email/types.py new file mode 100644 index 0000000000..ad64166285 --- /dev/null +++ b/src/phoenix/server/email/types.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from typing import Protocol + +from phoenix.server.email.templates.types import PasswordResetTemplateBody + + +class EmailSender(Protocol): + async def send_password_reset_email( + self, + email: str, + values: PasswordResetTemplateBody, + ) -> None: ... diff --git a/src/phoenix/server/jwt_store.py b/src/phoenix/server/jwt_store.py index 430d97684d..4b8da3af72 100644 --- a/src/phoenix/server/jwt_store.py +++ b/src/phoenix/server/jwt_store.py @@ -5,7 +5,7 @@ from dataclasses import replace from datetime import datetime, timezone from functools import cached_property, singledispatchmethod -from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar +from typing import Any, Callable, Coroutine, Dict, Generic, List, Optional, Tuple, Type, TypeVar import jwt from sqlalchemy import Select, delete, select @@ -28,6 +28,10 @@ ApiKeyId, DaemonTask, DbSessionFactory, + PasswordResetToken, + PasswordResetTokenAttributes, + PasswordResetTokenClaims, + PasswordResetTokenId, RefreshToken, RefreshTokenAttributes, RefreshTokenClaims, @@ -53,6 +57,7 @@ def __init__( self._db = db self._secret = secret args = (db, secret, algorithm, sleep_seconds) + self._password_reset_token_store = _PasswordResetTokenStore(*args, **kwargs) self._access_token_store = _AccessTokenStore(*args, **kwargs) self._refresh_token_store = _RefreshTokenStore(*args, **kwargs) self._api_key_store = _ApiKeyStore(*args, **kwargs) @@ -87,6 +92,10 @@ async def read(self, token: Token) -> Optional[ClaimSet]: async def _get(self, _: TokenId) -> Optional[ClaimSet]: return None + @_get.register + async def _(self, token_id: PasswordResetTokenId) -> Optional[ClaimSet]: + return await self._password_reset_token_store.get(token_id) + @_get.register async def _(self, token_id: AccessTokenId) -> Optional[ClaimSet]: return await self._access_token_store.get(token_id) @@ -103,6 +112,10 @@ async def _(self, token_id: ApiKeyId) -> Optional[ClaimSet]: async def _evict(self, _: TokenId) -> Optional[ClaimSet]: return None + @_evict.register + async def _(self, token_id: PasswordResetTokenId) -> Optional[ClaimSet]: + return await self._password_reset_token_store.evict(token_id) + @_evict.register async def _(self, token_id: AccessTokenId) -> Optional[ClaimSet]: return await self._access_token_store.evict(token_id) @@ -115,6 +128,12 @@ async def _(self, token_id: RefreshTokenId) -> Optional[ClaimSet]: async def _(self, token_id: ApiKeyId) -> Optional[ClaimSet]: return await self._api_key_store.evict(token_id) + async def create_password_reset_token( + self, + claim: PasswordResetTokenClaims, + ) -> Tuple[PasswordResetToken, PasswordResetTokenId]: + return await self._password_reset_token_store.create(claim) + async def create_access_token( self, claim: AccessTokenClaims, @@ -136,21 +155,29 @@ async def create_api_key( async def revoke(self, *token_ids: TokenId) -> None: if not token_ids: return + password_reset_token_ids: List[PasswordResetTokenId] = [] access_token_ids: List[AccessTokenId] = [] refresh_token_ids: List[RefreshTokenId] = [] api_key_ids: List[ApiKeyId] = [] for token_id in token_ids: + if isinstance(token_id, PasswordResetTokenId): + password_reset_token_ids.append(token_id) if isinstance(token_id, AccessTokenId): access_token_ids.append(token_id) elif isinstance(token_id, RefreshTokenId): refresh_token_ids.append(token_id) elif isinstance(token_id, ApiKeyId): api_key_ids.append(token_id) - await gather( - self._access_token_store.revoke(*access_token_ids), - self._refresh_token_store.revoke(*refresh_token_ids), - self._api_key_store.revoke(*api_key_ids), - ) + coroutines: List[Coroutine[None, None, None]] = [] + if password_reset_token_ids: + coroutines.append(self._password_reset_token_store.revoke(*password_reset_token_ids)) + if access_token_ids: + coroutines.append(self._access_token_store.revoke(*access_token_ids)) + if refresh_token_ids: + coroutines.append(self._refresh_token_store.revoke(*refresh_token_ids)) + if api_key_ids: + coroutines.append(self._api_key_store.revoke(*api_key_ids)) + await gather(*coroutines) async def log_out(self, user_id: UserId) -> None: for cls in (AccessTokenId, RefreshTokenId): @@ -166,6 +193,7 @@ async def log_out(self, user_id: UserId) -> None: _ClaimSetT = TypeVar("_ClaimSetT", bound=ClaimSet) _RecordT = TypeVar( "_RecordT", + models.PasswordResetToken, models.AccessToken, models.RefreshToken, models.ApiKey, @@ -286,6 +314,45 @@ async def _run(self) -> None: self._tasks.pop() +class _PasswordResetTokenStore( + _Store[ + PasswordResetTokenClaims, + PasswordResetToken, + PasswordResetTokenId, + models.PasswordResetToken, + ] +): + _table = models.PasswordResetToken + _token_id = PasswordResetTokenId + _token = PasswordResetToken + + def _from_db( + self, + record: models.PasswordResetToken, + user_role: UserRole, + ) -> Tuple[PasswordResetTokenId, PasswordResetTokenClaims]: + token_id = PasswordResetTokenId(record.id) + return token_id, PasswordResetTokenClaims( + token_id=token_id, + subject=UserId(record.user_id), + issued_at=record.created_at, + expiration_time=record.expires_at, + attributes=PasswordResetTokenAttributes( + user_role=user_role, + ), + ) + + def _to_db(self, claim: PasswordResetTokenClaims) -> models.PasswordResetToken: + assert claim.expiration_time + assert claim.subject + user_id = int(claim.subject) + return models.PasswordResetToken( + user_id=user_id, + created_at=claim.issued_at, + expires_at=claim.expiration_time, + ) + + class _AccessTokenStore( _Store[ AccessTokenClaims, diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 32889cd23b..8e788b2dfe 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -11,6 +11,7 @@ from typing import List, Optional from urllib.parse import urljoin +from fastapi_mail import ConnectionConfig from jinja2 import BaseLoader, Environment from uvicorn import Config, Server @@ -25,8 +26,15 @@ get_env_grpc_port, get_env_host, get_env_host_root_path, + get_env_password_reset_token_expiry, get_env_port, get_env_refresh_token_expiry, + get_env_smtp_hostname, + get_env_smtp_mail_from, + get_env_smtp_password, + get_env_smtp_port, + get_env_smtp_username, + get_env_smtp_validate_certs, get_pids_path, get_working_dir, ) @@ -47,6 +55,7 @@ create_engine_and_run_migrations, instrument_engine_if_enabled, ) +from phoenix.server.email.sender import EMAIL_TEMPLATE_FOLDER, FastMailSender from phoenix.server.types import DbSessionFactory from phoenix.settings import Settings from phoenix.trace.fixtures import ( @@ -367,6 +376,25 @@ def _get_pid_file() -> Path: scaffold_datasets=scaffold_datasets, phoenix_url=root_path, ) + email_sender = None + if mail_sever := get_env_smtp_hostname(): + assert (mail_username := get_env_smtp_username()), "SMTP username is required" + assert (mail_password := get_env_smtp_password()), "SMTP password is required" + assert (mail_from := get_env_smtp_mail_from()), "SMTP mail_from is required" + email_sender = FastMailSender( + ConnectionConfig( + MAIL_USERNAME=mail_username, + MAIL_PASSWORD=mail_password, + MAIL_FROM=mail_from, + MAIL_SERVER=mail_sever, + MAIL_PORT=get_env_smtp_port(), + VALIDATE_CERTS=get_env_smtp_validate_certs(), + USE_CREDENTIALS=True, + MAIL_STARTTLS=True, + MAIL_SSL_TLS=False, + TEMPLATE_FOLDER=EMAIL_TEMPLATE_FOLDER, + ) + ) app = create_app( db=factory, export_path=export_path, @@ -384,9 +412,11 @@ def _get_pid_file() -> Path: startup_callbacks=[lambda: print(msg)], shutdown_callbacks=instrumentation_cleanups, secret=secret, + password_reset_token_expiry=get_env_password_reset_token_expiry(), access_token_expiry=get_env_access_token_expiry(), refresh_token_expiry=get_env_refresh_token_expiry(), scaffolder_config=scaffolder_config, + email_sender=email_sender, ) server = Server(config=Config(app, host=host, port=port, root_path=host_root_path)) # type: ignore Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start() diff --git a/src/phoenix/server/types.py b/src/phoenix/server/types.py index fbbd4c6d2b..aaaf32fa5e 100644 --- a/src/phoenix/server/types.py +++ b/src/phoenix/server/types.py @@ -143,6 +143,9 @@ def set(self, table: Type[models.Base], id_: int) -> None: self._cache[table][id_] = datetime.now(timezone.utc) +class PasswordResetToken(Token): ... + + class AccessToken(Token): ... @@ -161,6 +164,10 @@ class UserTokenAttributes(TokenAttributes): class RefreshTokenAttributes(UserTokenAttributes): ... +@dataclass(frozen=True) +class PasswordResetTokenAttributes(UserTokenAttributes): ... + + @dataclass(frozen=True) class AccessTokenAttributes(UserTokenAttributes): refresh_token_id: RefreshTokenId @@ -198,6 +205,11 @@ def parse(cls, value: str) -> Optional[TokenId]: return None +@final +class PasswordResetTokenId(TokenId): + table = models.PasswordResetToken + + @final class AccessTokenId(TokenId): table = models.AccessToken @@ -224,6 +236,12 @@ class UserClaimSet(ClaimSet): attributes: Optional[UserTokenAttributes] = None +@dataclass(frozen=True) +class PasswordResetTokenClaims(UserClaimSet): + token_id: Optional[PasswordResetTokenId] = None + attributes: Optional[PasswordResetTokenAttributes] = None + + @dataclass(frozen=True) class AccessTokenClaims(UserClaimSet): token_id: Optional[AccessTokenId] = None @@ -251,6 +269,10 @@ async def log_out(self, user_id: UserId) -> None: ... class TokenStore(CanReadToken, CanRevokeTokens, CanLogOutUser, Protocol): + async def create_password_reset_token( + self, + claims: PasswordResetTokenClaims, + ) -> Tuple[PasswordResetToken, PasswordResetTokenId]: ... async def create_access_token( self, claims: AccessTokenClaims, diff --git a/tests/test_config.py b/tests/test_config.py deleted file mode 100644 index abae5def99..0000000000 --- a/tests/test_config.py +++ /dev/null @@ -1,86 +0,0 @@ -from datetime import timedelta -from typing import Callable - -import pytest - -from phoenix.config import ( - ENV_PHOENIX_ACCESS_TOKEN_EXPIRY, - ENV_PHOENIX_REFRESH_TOKEN_EXPIRY, - get_env_access_token_expiry, - get_env_refresh_token_expiry, -) - - -@pytest.mark.parametrize( - "env_var_value, expected_value", - ( - pytest.param("3600", timedelta(seconds=3600), id="with-positive-unitless-integer"), - pytest.param("3600.10", timedelta(seconds=3600.10), id="with-positive-decimal"), - pytest.param("36 days", timedelta(days=36), id="with-positive-day-unit"), - pytest.param("36d", timedelta(days=36), id="with-positive-d-unit"), - pytest.param( - "P4M6DT3H12M45S", - timedelta(days=(4 * 30 + 6), hours=3, minutes=12, seconds=45), - id="with-iso-8601-duration", - ), - ), -) -@pytest.mark.parametrize( - "env_var_name, env_var_getter", - ( - pytest.param( - ENV_PHOENIX_ACCESS_TOKEN_EXPIRY, get_env_access_token_expiry, id="access-token-expiry" - ), - pytest.param( - ENV_PHOENIX_REFRESH_TOKEN_EXPIRY, - get_env_refresh_token_expiry, - id="refresh-token-expiry", - ), - ), -) -def test_get_env_token_expiry_parses_valid_values( - env_var_name: str, - env_var_getter: Callable[[], timedelta], - env_var_value: str, - expected_value: timedelta, - monkeypatch: pytest.MonkeyPatch, -) -> None: - monkeypatch.setenv(env_var_name, env_var_value) - env_var_getter() == expected_value - - -@pytest.mark.parametrize( - "env_var_value, error_message", - ( - pytest.param("-3600", "duration must be positive", id="with-negative-integer"), - pytest.param("-3600.10", "duration must be positive", id="with-negative-decimal"), - pytest.param("-36d", "duration must be positive", id="with-negative-d-unit"), - pytest.param("0", "duration must be positive", id="with-zero-duration"), - pytest.param("nan", "duration cannot be null", id="with-null"), - ), -) -@pytest.mark.parametrize( - "env_var_name, env_var_getter", - ( - pytest.param( - ENV_PHOENIX_ACCESS_TOKEN_EXPIRY, get_env_access_token_expiry, id="access-token-expiry" - ), - pytest.param( - ENV_PHOENIX_REFRESH_TOKEN_EXPIRY, - get_env_refresh_token_expiry, - id="refresh-token-expiry", - ), - ), -) -def test_get_env_token_expiry_raises_expected_errors_for_invalid_values( - env_var_name: str, - env_var_getter: Callable[[], timedelta], - env_var_value: str, - error_message: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: - monkeypatch.setenv(env_var_name, env_var_value) - with pytest.raises( - ValueError, match=f"Error reading {env_var_name} environment variable: {error_message}" - ): - env_var_getter()