({
+ defaultValues: {
+ resetToken: resetToken,
+ newPassword: "",
+ confirmPassword: "",
+ },
+ });
+ return (
+ <>
+ {message ? (
+
+ {message}
+
+ ) : null}
+ {error ? (
+
+ {error}
+
+ ) : null}
+
+ >
+ );
+}
diff --git a/app/src/pages/auth/ResetPasswordWithTokenPage.tsx b/app/src/pages/auth/ResetPasswordWithTokenPage.tsx
new file mode 100644
index 0000000000..6ea2075d18
--- /dev/null
+++ b/app/src/pages/auth/ResetPasswordWithTokenPage.tsx
@@ -0,0 +1,29 @@
+import React from "react";
+import { useNavigate } from "react-router";
+import { useSearchParams } from "react-router-dom";
+
+import { Flex, View } from "@arizeai/components";
+
+import { AuthLayout } from "./AuthLayout";
+import { PhoenixLogo } from "./PhoenixLogo";
+import { ResetPasswordWithTokenForm } from "./ResetPasswordWithTokenForm";
+
+export function ResetPasswordWithTokenPage() {
+ const navigate = useNavigate();
+ const [searchParams] = useSearchParams();
+ const token = searchParams.get("token");
+ if (!token) {
+ navigate("/login");
+ return null;
+ }
+ return (
+
+
+
+
+
+
+
+
+ );
+}
diff --git a/app/src/pages/auth/index.tsx b/app/src/pages/auth/index.tsx
index da907c439a..42edb98745 100644
--- a/app/src/pages/auth/index.tsx
+++ b/app/src/pages/auth/index.tsx
@@ -1,3 +1,5 @@
export * from "./LoginPage";
export * from "./ResetPasswordPage";
+export * from "./ResetPasswordWithTokenPage";
export * from "./resetPasswordLoader";
+export * from "./ForgotPasswordPage";
diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py
index c92ee184ac..f85b0c7982 100644
--- a/integration_tests/_helpers.py
+++ b/integration_tests/_helpers.py
@@ -9,6 +9,7 @@
from contextvars import ContextVar
from dataclasses import dataclass, replace
from datetime import datetime, timezone
+from email.message import Message
from functools import cached_property
from io import BytesIO
from subprocess import PIPE, STDOUT
@@ -32,12 +33,14 @@
Union,
cast,
)
-from urllib.parse import urljoin
+from urllib.parse import parse_qs, urljoin, urlparse
from urllib.request import urlopen
+import bs4
import httpx
import jwt
import pytest
+import smtpdfix
from httpx import Headers, HTTPStatusError
from openinference.semconv.resource import ResourceAttributes
from opentelemetry.sdk.resources import Resource
@@ -58,6 +61,7 @@
get_env_database_schema,
get_env_grpc_port,
get_env_host,
+ get_env_smtp_mail_from,
)
from phoenix.server.api.auth import IsAdmin
from phoenix.server.api.exceptions import Unauthorized
@@ -217,6 +221,19 @@ def delete_api_key(self, api_key: _ApiKey, /) -> None:
def export_embeddings(self, filename: str) -> None:
_export_embeddings(self, filename=filename)
+ def initiate_password_reset(
+ self,
+ smtpd: smtpdfix.AuthController,
+ /,
+ *,
+ should_receive_email: bool = True,
+ ) -> Optional[_PasswordResetToken]:
+ return _initiate_password_reset(
+ self.email,
+ smtpd,
+ should_receive_email=should_receive_email,
+ )
+
_SYSTEM_USER_GID = _GqlId(GlobalID(type_name="User", node_id="1"))
_DEFAULT_ADMIN = _User(
@@ -257,6 +274,11 @@ def kind(self) -> _ApiKeyKind:
class _Token(_String, ABC): ...
+class _PasswordResetToken(_Token):
+ def reset(self, password: _Password, /) -> None:
+ return _reset_password(self, password=password)
+
+
class _AccessToken(_Token, _CanLogOut[None]):
def log_out(self) -> None:
_log_out(self)
@@ -856,6 +878,37 @@ def _log_out(
resp.raise_for_status()
+def _initiate_password_reset(
+ email: _Email,
+ smtpd: smtpdfix.AuthController,
+ /,
+ *,
+ should_receive_email: bool = True,
+) -> Optional[_PasswordResetToken]:
+ old_msg_count = len(smtpd.messages)
+ json_ = dict(email=email)
+ resp = _httpx_client().post("auth/password-reset-email", json=json_)
+ resp.raise_for_status()
+ new_msg_count = len(smtpd.messages) - old_msg_count
+ assert new_msg_count == int(should_receive_email)
+ if not should_receive_email:
+ return None
+ msg = smtpd.messages[-1]
+ assert msg["to"] == email
+ assert msg["from"] == get_env_smtp_mail_from()
+ return _extract_password_reset_token(msg)
+
+
+def _reset_password(
+ token: _PasswordResetToken,
+ /,
+ password: _Password,
+) -> None:
+ json_ = dict(token=token, password=password)
+ resp = _httpx_client().post("auth/password-reset", json=json_)
+ resp.raise_for_status()
+
+
def _export_embeddings(auth: Optional[_SecurityArtifact] = None, /, *, filename: str) -> None:
resp = _httpx_client(auth).get("/exports", params={"filename": filename})
resp.raise_for_status()
@@ -912,3 +965,29 @@ def _decode_token_ids(
jwt.decode(v, options={"verify_signature": False})["jti"]
for v in _extract_tokens(headers, key).values()
]
+
+
+def _extract_password_reset_token(msg: Message) -> _PasswordResetToken:
+ assert (soup := _extract_html(msg))
+ assert isinstance((link := soup.find(id="reset-url")), bs4.Tag)
+ assert isinstance((url := link.get("href")), str)
+ assert url
+ params = parse_qs(urlparse(url).query)
+ assert (tokens := params["token"])
+ assert (token := tokens[0])
+ decoded = jwt.decode(token, options=dict(verify_signature=False))
+ assert (jti := decoded["jti"])
+ assert jti.startswith("PasswordResetToken")
+ return _PasswordResetToken(token)
+
+
+def _extract_html(msg: Message) -> Optional[bs4.BeautifulSoup]:
+ for part in msg.walk():
+ if (
+ part.get_content_type() == "text/html"
+ and (payload := part.get_payload(decode=True))
+ and isinstance(payload, bytes)
+ ):
+ content = payload.decode(part.get_content_charset() or "utf-8")
+ return bs4.BeautifulSoup(content, "html.parser")
+ return None
diff --git a/integration_tests/auth/conftest.py b/integration_tests/auth/conftest.py
index 3cc1d9b884..4ed2dccce3 100644
--- a/integration_tests/auth/conftest.py
+++ b/integration_tests/auth/conftest.py
@@ -5,8 +5,25 @@
from unittest import mock
import pytest
+from faker import Faker
from phoenix.auth import DEFAULT_SECRET_LENGTH
-from phoenix.config import ENV_PHOENIX_ENABLE_AUTH, ENV_PHOENIX_SECRET
+from phoenix.config import (
+ ENV_PHOENIX_ENABLE_AUTH,
+ ENV_PHOENIX_SECRET,
+ ENV_PHOENIX_SMTP_HOSTNAME,
+ ENV_PHOENIX_SMTP_MAIL_FROM,
+ ENV_PHOENIX_SMTP_PASSWORD,
+ ENV_PHOENIX_SMTP_PORT,
+ ENV_PHOENIX_SMTP_USERNAME,
+ ENV_PHOENIX_SMTP_VALIDATE_CERTS,
+ get_env_smtp_hostname,
+ get_env_smtp_password,
+ get_env_smtp_port,
+ get_env_smtp_username,
+)
+from portpicker import pick_unused_port # type: ignore[import-untyped]
+from smtpdfix import AuthController, Config, SMTPDFix
+from smtpdfix.certs import _generate_certs
from .._helpers import _Secret, _server
@@ -20,12 +37,39 @@ def _secret() -> _Secret:
def _app(
_secret: _Secret,
_env_phoenix_sql_database_url: Any,
+ _fake: Faker,
) -> Iterator[None]:
values = (
(ENV_PHOENIX_ENABLE_AUTH, "true"),
(ENV_PHOENIX_SECRET, _secret),
+ (ENV_PHOENIX_SMTP_HOSTNAME, "127.0.0.1"),
+ (ENV_PHOENIX_SMTP_PORT, str(pick_unused_port())),
+ (ENV_PHOENIX_SMTP_USERNAME, "test"),
+ (ENV_PHOENIX_SMTP_PASSWORD, "test"),
+ (ENV_PHOENIX_SMTP_MAIL_FROM, _fake.email()),
+ (ENV_PHOENIX_SMTP_VALIDATE_CERTS, "false"),
)
with ExitStack() as stack:
stack.enter_context(mock.patch.dict(os.environ, values))
stack.enter_context(_server())
yield
+
+
+@pytest.fixture(scope="module")
+def _smtpd(
+ _app: Any,
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Iterator[AuthController]:
+ path = tmp_path_factory.mktemp("certs")
+ cert, _ = _generate_certs(path, separate_key=False)
+ os.environ["SMTPD_SSL_CERTIFICATE_FILE"] = str(cert.resolve())
+ config = Config()
+ config.login_username = get_env_smtp_username()
+ config.login_password = get_env_smtp_password()
+ config.use_starttls = True
+ with SMTPDFix(
+ hostname=get_env_smtp_hostname(),
+ port=get_env_smtp_port(),
+ config=config,
+ ) as controller:
+ yield controller
diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py
index f0f4e8f1d3..876b801f68 100644
--- a/integration_tests/auth/test_auth.py
+++ b/integration_tests/auth/test_auth.py
@@ -15,8 +15,10 @@
TypeVar,
)
+import httpx
import jwt
import pytest
+import smtpdfix
from httpx import HTTPStatusError
from opentelemetry.sdk.environment_variables import (
OTEL_EXPORTER_OTLP_HEADERS,
@@ -43,6 +45,7 @@
_ApiKey,
_create_user,
_DefaultAdminTokenSequestration,
+ _Email,
_Expectation,
_export_embeddings,
_GetUser,
@@ -50,6 +53,7 @@
_grpc_span_exporter,
_Headers,
_http_span_exporter,
+ _initiate_password_reset,
_log_in,
_LoggedInUser,
_Password,
@@ -123,6 +127,128 @@ def test_cannot_log_in_with_deleted_user(
user.log_in()
+class TestPasswordReset:
+ def test_initiate_password_reset_does_not_reveal_whether_user_exists(
+ self,
+ _emails: Iterator[_Email],
+ _smtpd: smtpdfix.AuthController,
+ ) -> None:
+ email = next(_emails)
+ assert not _initiate_password_reset(email, _smtpd, should_receive_email=False)
+
+ @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN])
+ def test_initiate_password_reset_does_not_change_existing_password(
+ self,
+ role_or_user: _RoleOrUser,
+ _get_user: _GetUser,
+ _smtpd: smtpdfix.AuthController,
+ ) -> None:
+ u = _get_user(role_or_user)
+ assert u.initiate_password_reset(_smtpd)
+ u.log_in()
+
+ @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN])
+ def test_password_reset_cannot_be_initiated_again_while_in_progress(
+ self,
+ role_or_user: _RoleOrUser,
+ _get_user: _GetUser,
+ _smtpd: smtpdfix.AuthController,
+ ) -> None:
+ u = _get_user(role_or_user)
+ assert u.initiate_password_reset(_smtpd)
+ with pytest.raises(httpx.HTTPStatusError):
+ assert u.initiate_password_reset(_smtpd)
+
+ @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN])
+ def test_password_reset_can_be_initiated_immediately_after_password_reset(
+ self,
+ role_or_user: _RoleOrUser,
+ _get_user: _GetUser,
+ _passwords: Iterator[_Password],
+ _smtpd: smtpdfix.AuthController,
+ ) -> None:
+ u = _get_user(role_or_user)
+ new_password = next(_passwords)
+ assert new_password != u.password
+ assert (token := u.initiate_password_reset(_smtpd))
+ token.reset(new_password)
+ assert u.initiate_password_reset(_smtpd)
+
+ @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN])
+ def test_password_reset_token_is_single_use(
+ self,
+ role_or_user: _RoleOrUser,
+ _get_user: _GetUser,
+ _passwords: Iterator[_Password],
+ _smtpd: smtpdfix.AuthController,
+ ) -> None:
+ u = _get_user(role_or_user)
+ new_password = next(_passwords)
+ assert new_password != u.password
+ newer_password = next(_passwords)
+ assert newer_password != new_password
+ assert (token := u.initiate_password_reset(_smtpd))
+ token.reset(new_password)
+ with _EXPECTATION_401:
+ token.reset(newer_password)
+
+ @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN])
+ def test_initiate_password_reset_and_then_reset_password_using_token_from_email(
+ self,
+ role_or_user: _RoleOrUser,
+ _get_user: _GetUser,
+ _passwords: Iterator[_Password],
+ _smtpd: smtpdfix.AuthController,
+ ) -> None:
+ u = _get_user(role_or_user)
+ logged_in_user = u.log_in()
+ logged_in_user.create_api_key()
+ assert (token := u.initiate_password_reset(_smtpd))
+ new_password = next(_passwords)
+ assert new_password != u.password
+ token.reset(new_password)
+ with _EXPECTATION_401:
+ # old password should no longer work
+ u.log_in()
+ with _EXPECTATION_401:
+ # old logged-in tokens should no longer work
+ logged_in_user.create_api_key()
+ # new password should work
+ new_profile = replace(u.profile, password=new_password)
+ replace(u, profile=new_profile).log_in().create_api_key()
+
+ @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN])
+ def test_deleted_user_will_not_receive_email_after_initiating_password_reset(
+ self,
+ role_or_user: _RoleOrUser,
+ _get_user: _GetUser,
+ _smtpd: smtpdfix.AuthController,
+ ) -> None:
+ u = _get_user(role_or_user)
+ logged_in_user = u.log_in()
+ logged_in_user.create_api_key()
+ _DEFAULT_ADMIN.delete_users(u)
+ assert not u.initiate_password_reset(_smtpd, should_receive_email=False)
+
+ @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN])
+ def test_deleted_user_cannot_reset_password_using_token_from_email(
+ self,
+ role_or_user: _RoleOrUser,
+ _get_user: _GetUser,
+ _passwords: Iterator[_Password],
+ _smtpd: smtpdfix.AuthController,
+ ) -> None:
+ u = _get_user(role_or_user)
+ logged_in_user = u.log_in()
+ logged_in_user.create_api_key()
+ assert (token := u.initiate_password_reset(_smtpd))
+ new_password = next(_passwords)
+ assert new_password != u.password
+ _DEFAULT_ADMIN.delete_users(u)
+ with _EXPECTATION_401:
+ token.reset(new_password)
+
+
class TestLogOut:
def test_default_admin_cannot_log_out_during_testing(self) -> None:
"""
diff --git a/integration_tests/requirements.txt b/integration_tests/requirements.txt
index 681b6a7077..c9b09fe8ce 100644
--- a/integration_tests/requirements.txt
+++ b/integration_tests/requirements.txt
@@ -1,3 +1,4 @@
+beautifulsoup4
faker
httpx
openinference-semantic-conventions
@@ -5,4 +6,6 @@ opentelemetry-sdk
portpicker
psutil
pytest-randomly
+pytest-smtpd
+types-beautifulsoup4
types-psutil
diff --git a/pyproject.toml b/pyproject.toml
index 1eaea37c10..aab7b27b69 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,6 +59,7 @@ dependencies = [
"arize-phoenix-evals>=0.13.1",
"arize-phoenix-otel>=0.4.1",
"fastapi",
+ "fastapi-mail",
"pydantic>=1.0,!=2.0.*,<3", # exclude 2.0.* since it does not support the `json_encoders` configuration setting
"pyjwt",
]
diff --git a/src/phoenix/auth.py b/src/phoenix/auth.py
index ece7705fab..83713aa606 100644
--- a/src/phoenix/auth.py
+++ b/src/phoenix/auth.py
@@ -184,6 +184,9 @@ def validate(
DEFAULT_ADMIN_EMAIL = "admin@localhost"
DEFAULT_ADMIN_PASSWORD = "admin"
DEFAULT_SECRET_LENGTH = 32
+DEFAULT_PASSWORD_RESET_TOKEN_EXPIRY_MINUTES = 15
+DEFAULT_ACCESS_TOKEN_EXPIRY_MINUTES = 10
+DEFAULT_REFRESH_TOKEN_EXPIRY_MINUTES = 60 * 24 * 7
"""The default length of a secret key in bytes."""
EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+[.][^@\s]+\Z")
"""The regular expression pattern for a valid email address."""
diff --git a/src/phoenix/config.py b/src/phoenix/config.py
index 78f8cc626f..2f3dcf8377 100644
--- a/src/phoenix/config.py
+++ b/src/phoenix/config.py
@@ -4,9 +4,7 @@
from datetime import timedelta
from logging import getLogger
from pathlib import Path
-from typing import Dict, List, Optional, Tuple
-
-import pandas as pd
+from typing import Dict, List, Optional, Tuple, overload
from phoenix.utilities.re import parse_env_headers
@@ -74,8 +72,44 @@
ENV_PHOENIX_SECRET = "PHOENIX_SECRET"
ENV_PHOENIX_API_KEY = "PHOENIX_API_KEY"
ENV_PHOENIX_USE_SECURE_COOKIES = "PHOENIX_USE_SECURE_COOKIES"
-ENV_PHOENIX_ACCESS_TOKEN_EXPIRY = "PHOENIX_ACCESS_TOKEN_EXPIRY"
-ENV_PHOENIX_REFRESH_TOKEN_EXPIRY = "PHOENIX_REFRESH_TOKEN_EXPIRY"
+ENV_PHOENIX_ACCESS_TOKEN_EXPIRY_MINUTES = "PHOENIX_ACCESS_TOKEN_EXPIRY_MINUTES"
+"""
+The duration, in minutes, before access tokens expire.
+"""
+ENV_PHOENIX_REFRESH_TOKEN_EXPIRY_MINUTES = "PHOENIX_REFRESH_TOKEN_EXPIRY_MINUTES"
+"""
+The duration, in minutes, before refresh tokens expire.
+"""
+ENV_PHOENIX_PASSWORD_RESET_TOKEN_EXPIRY_MINUTES = "PHOENIX_PASSWORD_RESET_TOKEN_EXPIRY_MINUTES"
+"""
+The duration, in minutes, before password reset tokens expire.
+"""
+
+# SMTP settings
+ENV_PHOENIX_SMTP_HOSTNAME = "PHOENIX_SMTP_HOSTNAME"
+"""
+The SMTP server hostname to use for sending emails. SMTP is disabled if this is not set.
+"""
+ENV_PHOENIX_SMTP_PORT = "PHOENIX_SMTP_PORT"
+"""
+The SMTP server port to use for sending emails. Defaults to 587.
+"""
+ENV_PHOENIX_SMTP_USERNAME = "PHOENIX_SMTP_USERNAME"
+"""
+The SMTP server username to use for sending emails. Should be set if SMTP is enabled.
+"""
+ENV_PHOENIX_SMTP_PASSWORD = "PHOENIX_SMTP_PASSWORD"
+"""
+The SMTP server password to use for sending emails. Should be set if SMTP is enabled.
+"""
+ENV_PHOENIX_SMTP_MAIL_FROM = "PHOENIX_SMTP_MAIL_FROM"
+"""
+The email address to use as the sender when sending emails. Should be set if SMTP is enabled.
+"""
+ENV_PHOENIX_SMTP_VALIDATE_CERTS = "PHOENIX_SMTP_VALIDATE_CERTS"
+"""
+Whether to validate SMTP server certificates. Defaults to true.
+"""
def server_instrumentation_is_enabled() -> bool:
@@ -117,12 +151,16 @@ def get_working_dir() -> Path:
return Path.home().resolve() / ".phoenix"
-def get_boolean_env_var(env_var: str) -> Optional[bool]:
+@overload
+def _bool_val(env_var: str) -> Optional[bool]: ...
+@overload
+def _bool_val(env_var: str, default: bool) -> bool: ...
+def _bool_val(env_var: str, default: Optional[bool] = None) -> Optional[bool]:
"""
- Parses a boolean environment variable, returning None if the variable is not set.
+ Parses a boolean environment variable, returning `default` if the variable is not set.
"""
if (value := os.environ.get(env_var)) is None:
- return None
+ return default
assert (lower := value.lower()) in (
"true",
"false",
@@ -130,11 +168,49 @@ def get_boolean_env_var(env_var: str) -> Optional[bool]:
return lower == "true"
+@overload
+def _float_val(env_var: str) -> Optional[float]: ...
+@overload
+def _float_val(env_var: str, default: float) -> float: ...
+def _float_val(env_var: str, default: Optional[float] = None) -> Optional[float]:
+ """
+ Parses a numeric environment variable, returning `default` if the variable is not set.
+ """
+ if (value := os.environ.get(env_var)) is None:
+ return default
+ try:
+ return float(value)
+ except ValueError:
+ raise ValueError(
+ f"Invalid value for environment variable {env_var}: {value}. "
+ f"Value must be a number."
+ )
+
+
+@overload
+def _int_val(env_var: str) -> Optional[int]: ...
+@overload
+def _int_val(env_var: str, default: int) -> int: ...
+def _int_val(env_var: str, default: Optional[int] = None) -> Optional[int]:
+ """
+ Parses a numeric environment variable, returning `default` if the variable is not set.
+ """
+ if (value := os.environ.get(env_var)) is None:
+ return default
+ try:
+ return int(value)
+ except ValueError:
+ raise ValueError(
+ f"Invalid value for environment variable {env_var}: {value}. "
+ f"Value must be an integer."
+ )
+
+
def get_env_enable_auth() -> bool:
"""
Gets the value of the PHOENIX_ENABLE_AUTH environment variable.
"""
- return get_boolean_env_var(ENV_PHOENIX_ENABLE_AUTH) is True
+ return _bool_val(ENV_PHOENIX_ENABLE_AUTH, False)
def get_env_phoenix_secret() -> Optional[str]:
@@ -152,7 +228,7 @@ def get_env_phoenix_secret() -> Optional[str]:
def get_env_phoenix_use_secure_cookies() -> bool:
- return bool(get_boolean_env_var(ENV_PHOENIX_USE_SECURE_COOKIES))
+ return _bool_val(ENV_PHOENIX_USE_SECURE_COOKIES, False)
def get_env_phoenix_api_key() -> Optional[str]:
@@ -173,48 +249,72 @@ def get_env_auth_settings() -> Tuple[bool, Optional[str]]:
return enable_auth, phoenix_secret
+def get_env_password_reset_token_expiry() -> timedelta:
+ """
+ Gets the password reset token expiry.
+ """
+ from phoenix.auth import DEFAULT_PASSWORD_RESET_TOKEN_EXPIRY_MINUTES
+
+ minutes = _float_val(
+ ENV_PHOENIX_PASSWORD_RESET_TOKEN_EXPIRY_MINUTES,
+ DEFAULT_PASSWORD_RESET_TOKEN_EXPIRY_MINUTES,
+ )
+ assert minutes > 0
+ return timedelta(minutes=minutes)
+
+
def get_env_access_token_expiry() -> timedelta:
"""
Gets the access token expiry.
"""
- if (access_token_expiry := os.environ.get(ENV_PHOENIX_ACCESS_TOKEN_EXPIRY)) is None:
- return timedelta(minutes=10)
- try:
- return _parse_duration(access_token_expiry)
- except ValueError as error:
- raise ValueError(
- f"Error reading {ENV_PHOENIX_ACCESS_TOKEN_EXPIRY} environment variable: {str(error)}"
- )
+ from phoenix.auth import DEFAULT_ACCESS_TOKEN_EXPIRY_MINUTES
+
+ minutes = _float_val(
+ ENV_PHOENIX_ACCESS_TOKEN_EXPIRY_MINUTES,
+ DEFAULT_ACCESS_TOKEN_EXPIRY_MINUTES,
+ )
+ assert minutes > 0
+ return timedelta(minutes=minutes)
def get_env_refresh_token_expiry() -> timedelta:
"""
Gets the refresh token expiry.
"""
- if (refresh_token_expiry := os.environ.get(ENV_PHOENIX_REFRESH_TOKEN_EXPIRY)) is None:
- return timedelta(weeks=1)
- try:
- return _parse_duration(refresh_token_expiry)
- except ValueError as error:
- raise ValueError(
- f"Error reading {ENV_PHOENIX_REFRESH_TOKEN_EXPIRY} environment variable: {str(error)}"
- )
+ from phoenix.auth import DEFAULT_REFRESH_TOKEN_EXPIRY_MINUTES
+ minutes = _float_val(
+ ENV_PHOENIX_REFRESH_TOKEN_EXPIRY_MINUTES,
+ DEFAULT_REFRESH_TOKEN_EXPIRY_MINUTES,
+ )
+ assert minutes > 0
+ return timedelta(minutes=minutes)
-def _parse_duration(duration_str: str) -> timedelta:
- """
- Parses a duration string into a timedelta object, assuming the duration is
- in seconds if no unit is provided.
- """
- try:
- duration = timedelta(seconds=float(duration_str))
- except ValueError:
- duration = pd.Timedelta(duration_str)
- if pd.isnull(duration):
- raise ValueError("duration cannot be null")
- if duration <= timedelta(0):
- raise ValueError("duration must be positive")
- return duration
+
+def get_env_smtp_username() -> str:
+ return os.getenv(ENV_PHOENIX_SMTP_USERNAME) or ""
+
+
+def get_env_smtp_password() -> str:
+ return os.getenv(ENV_PHOENIX_SMTP_PASSWORD) or ""
+
+
+def get_env_smtp_mail_from() -> str:
+ return os.getenv(ENV_PHOENIX_SMTP_MAIL_FROM) or ""
+
+
+def get_env_smtp_hostname() -> str:
+ return os.getenv(ENV_PHOENIX_SMTP_HOSTNAME) or ""
+
+
+def get_env_smtp_port() -> int:
+ port = _int_val(ENV_PHOENIX_SMTP_PORT, 587)
+ assert 0 < port <= 65_535
+ return port
+
+
+def get_env_smtp_validate_certs() -> bool:
+ return _bool_val(ENV_PHOENIX_SMTP_VALIDATE_CERTS, True)
PHOENIX_DIR = Path(__file__).resolve().parent
diff --git a/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py b/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py
index 27c8199b98..e351f4957f 100644
--- a/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py
+++ b/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py
@@ -71,6 +71,25 @@ def upgrade() -> None:
sa.CheckConstraint("password_hash is null or password_salt is not null", name="salt"),
sqlite_autoincrement=True,
)
+ op.create_table(
+ "password_reset_tokens",
+ sa.Column("id", sa.Integer, primary_key=True),
+ sa.Column(
+ "user_id",
+ sa.Integer,
+ sa.ForeignKey("users.id", ondelete="CASCADE"),
+ unique=True,
+ index=True,
+ ),
+ sa.Column(
+ "created_at",
+ sa.TIMESTAMP(timezone=True),
+ nullable=False,
+ server_default=sa.func.now(),
+ ),
+ sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False, index=True),
+ sqlite_autoincrement=True,
+ )
op.create_table(
"refresh_tokens",
sa.Column("id", sa.Integer, primary_key=True),
@@ -125,5 +144,6 @@ def downgrade() -> None:
op.drop_table("api_keys")
op.drop_table("access_tokens")
op.drop_table("refresh_tokens")
+ op.drop_table("password_reset_tokens")
op.drop_table("users")
op.drop_table("user_roles")
diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py
index fa33c472c8..5a81d57820 100644
--- a/src/phoenix/db/models.py
+++ b/src/phoenix/db/models.py
@@ -648,6 +648,11 @@ class User(Base):
UtcTimeStamp, server_default=func.now(), onupdate=func.now()
)
deleted_at: Mapped[Optional[datetime]] = mapped_column(UtcTimeStamp)
+ password_reset_token: Mapped["PasswordResetToken"] = relationship(
+ "PasswordResetToken",
+ back_populates="user",
+ uselist=False,
+ )
access_tokens: Mapped[List["AccessToken"]] = relationship("AccessToken", back_populates="user")
refresh_tokens: Mapped[List["RefreshToken"]] = relationship(
"RefreshToken", back_populates="user"
@@ -659,6 +664,20 @@ class User(Base):
)
+class PasswordResetToken(Base):
+ __tablename__ = "password_reset_tokens"
+ id: Mapped[int] = mapped_column(primary_key=True)
+ user_id: Mapped[int] = mapped_column(
+ ForeignKey("users.id", ondelete="CASCADE"),
+ unique=True,
+ index=True,
+ )
+ user: Mapped["User"] = relationship("User", back_populates="password_reset_token")
+ created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
+ expires_at: Mapped[Optional[datetime]] = mapped_column(UtcTimeStamp, nullable=False, index=True)
+ __table_args__ = (dict(sqlite_autoincrement=True),)
+
+
class RefreshToken(Base):
__tablename__ = "refresh_tokens"
id: Mapped[int] = mapped_column(primary_key=True)
diff --git a/src/phoenix/server/api/mutations/user_mutations.py b/src/phoenix/server/api/mutations/user_mutations.py
index 889403c7ed..cfae3afba3 100644
--- a/src/phoenix/server/api/mutations/user_mutations.py
+++ b/src/phoenix/server/api/mutations/user_mutations.py
@@ -29,7 +29,7 @@
from phoenix.server.api.types.node import from_global_id_with_expected_type
from phoenix.server.api.types.User import User, to_gql_user
from phoenix.server.bearer_auth import PhoenixUser
-from phoenix.server.types import AccessTokenId, ApiKeyId, RefreshTokenId
+from phoenix.server.types import AccessTokenId, ApiKeyId, PasswordResetTokenId, RefreshTokenId
@strawberry.input
@@ -269,6 +269,14 @@ async def delete_users(
raise Conflict("Cannot delete the default admin user")
if num_resolved_user_ids < len(user_ids):
raise NotFound("Some user IDs could not be found")
+ password_reset_token_ids = [
+ PasswordResetTokenId(id_)
+ async for id_ in await session.stream_scalars(
+ delete(models.PasswordResetToken)
+ .where(models.PasswordResetToken.user_id.in_(user_ids))
+ .returning(models.PasswordResetToken.id)
+ )
+ ]
access_token_ids = [
AccessTokenId(id_)
async for id_ in await session.stream_scalars(
@@ -298,7 +306,12 @@ async def delete_users(
.where(models.User.id.in_(user_ids))
.values(deleted_at=func.now())
)
- await token_store.revoke(*access_token_ids, *refresh_token_ids, *api_key_ids)
+ await token_store.revoke(
+ *password_reset_token_ids,
+ *access_token_ids,
+ *refresh_token_ids,
+ *api_key_ids,
+ )
def _select_role_id_by_name(role_name: str) -> Select[Tuple[int]]:
diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py
index 037ef3c38e..ed77080948 100644
--- a/src/phoenix/server/api/routers/auth.py
+++ b/src/phoenix/server/api/routers/auth.py
@@ -1,30 +1,44 @@
import asyncio
+import secrets
from datetime import datetime, timedelta, timezone
from functools import partial
+from typing import Tuple
from fastapi import APIRouter, Depends, HTTPException, Request, Response
-from sqlalchemy import and_, select
+from sqlalchemy import Select, select
from sqlalchemy.orm import joinedload
-from starlette.status import HTTP_204_NO_CONTENT, HTTP_401_UNAUTHORIZED, HTTP_404_NOT_FOUND
+from starlette.status import (
+ HTTP_204_NO_CONTENT,
+ HTTP_401_UNAUTHORIZED,
+ HTTP_404_NOT_FOUND,
+ HTTP_422_UNPROCESSABLE_ENTITY,
+ HTTP_503_SERVICE_UNAVAILABLE,
+)
from phoenix.auth import (
+ DEFAULT_SECRET_LENGTH,
PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
PHOENIX_REFRESH_TOKEN_COOKIE_NAME,
Token,
+ compute_password_hash,
delete_access_token_cookie,
delete_refresh_token_cookie,
is_valid_password,
set_access_token_cookie,
set_refresh_token_cookie,
+ validate_password_format,
)
+from phoenix.config import get_base_url
+from phoenix.db import enums, models
from phoenix.db.enums import UserRole
-from phoenix.db.models import User as OrmUser
from phoenix.server.bearer_auth import PhoenixUser
-from phoenix.server.jwt_store import JwtStore
+from phoenix.server.email.templates.types import PasswordResetTemplateBody
+from phoenix.server.email.types import EmailSender
from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_rate_limiter
from phoenix.server.types import (
AccessTokenAttributes,
AccessTokenClaims,
+ PasswordResetTokenClaims,
RefreshTokenAttributes,
RefreshTokenClaims,
TokenStore,
@@ -37,7 +51,16 @@
partition_seconds=60,
active_partitions=2,
)
-login_rate_limiter = fastapi_rate_limiter(rate_limiter, paths=["/login"])
+login_rate_limiter = fastapi_rate_limiter(
+ rate_limiter,
+ paths=[
+ "/login",
+ "/logout",
+ "/refresh",
+ "/password-reset-email",
+ "/password-reset",
+ ],
+)
router = APIRouter(
prefix="/auth", include_in_schema=False, dependencies=[Depends(login_rate_limiter)]
)
@@ -56,9 +79,7 @@ async def login(request: Request) -> Response:
async with request.app.state.db() as session:
user = await session.scalar(
- select(OrmUser)
- .where(and_(OrmUser.email == email, OrmUser.deleted_at.is_(None)))
- .options(joinedload(OrmUser.role))
+ _select_active_user().filter_by(email=email).options(joinedload(models.User.role))
)
if (
user is None
@@ -83,7 +104,7 @@ async def login(request: Request) -> Response:
user_role=UserRole(user.role.name),
),
)
- token_store: JwtStore = request.app.state.get_token_store()
+ token_store: TokenStore = request.app.state.get_token_store()
refresh_token, refresh_token_id = await token_store.create_refresh_token(refresh_token_claims)
access_token_claims = AccessTokenClaims(
subject=UserId(user.id),
@@ -110,7 +131,8 @@ async def logout(
request: Request,
) -> Response:
token_store: TokenStore = request.app.state.get_token_store()
- assert isinstance(user := request.user, PhoenixUser)
+ if not isinstance(user := request.user, PhoenixUser):
+ raise HTTPException(status_code=HTTP_401_UNAUTHORIZED)
await token_store.log_out(user.identity)
response = Response(status_code=HTTP_204_NO_CONTENT)
response = delete_access_token_cookie(response)
@@ -124,7 +146,7 @@ async def refresh_tokens(request: Request) -> Response:
assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta)
if (refresh_token := request.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) is None:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Missing refresh token")
- token_store: JwtStore = request.app.state.get_token_store()
+ token_store: TokenStore = request.app.state.get_token_store()
refresh_token_claims = await token_store.read(Token(refresh_token))
if (
not isinstance(refresh_token_claims, RefreshTokenClaims)
@@ -150,9 +172,7 @@ async def refresh_tokens(request: Request) -> Response:
async with request.app.state.db() as session:
if (
user := await session.scalar(
- select(OrmUser)
- .where(and_(OrmUser.id == user_id, OrmUser.deleted_at.is_(None)))
- .options(joinedload(OrmUser.role))
+ _select_active_user().filter_by(id=user_id).options(joinedload(models.User.role))
)
) is None:
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="User not found")
@@ -187,4 +207,96 @@ async def refresh_tokens(request: Request) -> Response:
return response
+@router.post("/password-reset-email")
+async def initiate_password_reset(request: Request) -> Response:
+ data = await request.json()
+ if not (email := data.get("email")):
+ raise MISSING_EMAIL
+ sender: EmailSender = request.app.state.email_sender
+ if sender is None:
+ raise UNAVAILABLE
+ assert isinstance(token_expiry := request.app.state.password_reset_token_expiry, timedelta)
+ async with request.app.state.db() as session:
+ user = await session.scalar(
+ _select_active_user()
+ .filter_by(email=email)
+ .options(
+ joinedload(models.User.password_reset_token).load_only(models.PasswordResetToken.id)
+ )
+ )
+ if user is None or user.auth_method != enums.AuthMethod.LOCAL.value:
+ # Withold privileged information
+ return Response(status_code=HTTP_204_NO_CONTENT)
+ if user.password_reset_token:
+ raise RESET_IN_PROGRESS
+ password_reset_token_claims = PasswordResetTokenClaims(
+ subject=UserId(user.id),
+ issued_at=datetime.now(timezone.utc),
+ expiration_time=datetime.now(timezone.utc) + token_expiry,
+ )
+ token_store: TokenStore = request.app.state.get_token_store()
+ token, _ = await token_store.create_password_reset_token(password_reset_token_claims)
+ await sender.send_password_reset_email(email, PasswordResetTemplateBody(token, get_base_url()))
+ return Response(status_code=HTTP_204_NO_CONTENT)
+
+
+@router.post("/password-reset")
+async def reset_password(request: Request) -> Response:
+ data = await request.json()
+ if not (password := data.get("password")):
+ raise MISSING_PASSWORD
+ token_store: TokenStore = request.app.state.get_token_store()
+ if (
+ not (token := data.get("token"))
+ or not isinstance((claims := await token_store.read(token)), PasswordResetTokenClaims)
+ or not claims.expiration_time
+ or claims.expiration_time < datetime.now(timezone.utc)
+ ):
+ raise INVALID_TOKEN
+ assert (user_id := claims.subject)
+ async with request.app.state.db() as session:
+ user = await session.scalar(_select_active_user().filter_by(id=int(user_id)))
+ if user is None or user.auth_method != enums.AuthMethod.LOCAL.value:
+ # Withold privileged information
+ return Response(status_code=HTTP_204_NO_CONTENT)
+ validate_password_format(password)
+ user.password_salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH)
+ loop = asyncio.get_running_loop()
+ user.password_hash = await loop.run_in_executor(
+ None, partial(compute_password_hash, password=password, salt=user.password_salt)
+ )
+ async with request.app.state.db() as session:
+ session.add(user)
+ await session.flush()
+ response = Response(status_code=HTTP_204_NO_CONTENT)
+ assert (token_id := claims.token_id)
+ await token_store.revoke(token_id)
+ await token_store.log_out(UserId(user.id))
+ return response
+
+
+def _select_active_user() -> Select[Tuple[models.User]]:
+ return select(models.User).where(models.User.deleted_at.is_(None))
+
+
LOGIN_FAILED_MESSAGE = "Invalid email and/or password"
+
+MISSING_EMAIL = HTTPException(
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+ detail="Email required",
+)
+MISSING_PASSWORD = HTTPException(
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+ detail="Password required",
+)
+UNAVAILABLE = HTTPException(
+ status_code=HTTP_503_SERVICE_UNAVAILABLE,
+)
+RESET_IN_PROGRESS = HTTPException(
+ status_code=HTTP_503_SERVICE_UNAVAILABLE,
+ detail="Password reset already in progress",
+)
+INVALID_TOKEN = HTTPException(
+ status_code=HTTP_401_UNAUTHORIZED,
+ detail="Invalid token",
+)
diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py
index 2070e09f73..c58b2ea4ed 100644
--- a/src/phoenix/server/app.py
+++ b/src/phoenix/server/app.py
@@ -96,6 +96,7 @@
from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated
from phoenix.server.dml_event import DmlEvent
from phoenix.server.dml_event_handler import DmlEventHandler
+from phoenix.server.email.types import EmailSender
from phoenix.server.grpc_server import GrpcServer
from phoenix.server.jwt_store import JwtStore
from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
@@ -605,9 +606,11 @@ def create_app(
startup_callbacks: Iterable[_Callback] = (),
shutdown_callbacks: Iterable[_Callback] = (),
secret: Optional[str] = None,
+ password_reset_token_expiry: Optional[timedelta] = None,
access_token_expiry: Optional[timedelta] = None,
refresh_token_expiry: Optional[timedelta] = None,
scaffolder_config: Optional[ScaffolderConfig] = None,
+ email_sender: Optional[EmailSender] = None,
) -> FastAPI:
startup_callbacks_list: List[_Callback] = list(startup_callbacks)
shutdown_callbacks_list: List[_Callback] = list(shutdown_callbacks)
@@ -742,9 +745,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
)
app.state.read_only = read_only
app.state.export_path = export_path
+ app.state.password_reset_token_expiry = password_reset_token_expiry
app.state.access_token_expiry = access_token_expiry
app.state.refresh_token_expiry = refresh_token_expiry
app.state.db = db
+ app.state.email_sender = email_sender
app = _add_get_secret_method(app=app, secret=secret)
app = _add_get_token_store_method(app=app, token_store=token_store)
if tracer_provider:
diff --git a/src/phoenix/server/email/__init__.py b/src/phoenix/server/email/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/phoenix/server/email/sender.py b/src/phoenix/server/email/sender.py
new file mode 100644
index 0000000000..f44219af70
--- /dev/null
+++ b/src/phoenix/server/email/sender.py
@@ -0,0 +1,29 @@
+from dataclasses import asdict
+from pathlib import Path
+
+from fastapi_mail import ConnectionConfig, FastMail, MessageSchema
+
+from phoenix.server.email.templates.types import PasswordResetTemplateBody
+
+EMAIL_TEMPLATE_FOLDER = Path(__file__).parent / "templates"
+
+
+class FastMailSender:
+ def __init__(self, conf: ConnectionConfig) -> None:
+ self._fm = FastMail(conf)
+
+ async def send_password_reset_email(
+ self,
+ email: str,
+ values: PasswordResetTemplateBody,
+ ) -> None:
+ message = MessageSchema(
+ subject="Password Reset Request",
+ recipients=[email],
+ template_body=asdict(values),
+ subtype="html",
+ )
+ await self._fm.send_message(
+ message,
+ template_name="password_reset.html",
+ )
diff --git a/src/phoenix/server/email/templates/__init__.py b/src/phoenix/server/email/templates/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/phoenix/server/email/templates/password_reset.html b/src/phoenix/server/email/templates/password_reset.html
new file mode 100644
index 0000000000..aab4b3dddc
--- /dev/null
+++ b/src/phoenix/server/email/templates/password_reset.html
@@ -0,0 +1,15 @@
+
+
+
+
+ Password Reset
+
+
+Hello.
+You have requested a password reset. Please click on the link below to reset your password:
+
+ Reset Password
+
+If you did not make this request, please contact your administrator.
+
+
diff --git a/src/phoenix/server/email/templates/types.py b/src/phoenix/server/email/templates/types.py
new file mode 100644
index 0000000000..a7cf1c0e3b
--- /dev/null
+++ b/src/phoenix/server/email/templates/types.py
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from phoenix.server.types import PasswordResetToken
+
+
+@dataclass(frozen=True)
+class PasswordResetTemplateBody:
+ token: PasswordResetToken
+ base_url: str
diff --git a/src/phoenix/server/email/types.py b/src/phoenix/server/email/types.py
new file mode 100644
index 0000000000..ad64166285
--- /dev/null
+++ b/src/phoenix/server/email/types.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+from typing import Protocol
+
+from phoenix.server.email.templates.types import PasswordResetTemplateBody
+
+
+class EmailSender(Protocol):
+ async def send_password_reset_email(
+ self,
+ email: str,
+ values: PasswordResetTemplateBody,
+ ) -> None: ...
diff --git a/src/phoenix/server/jwt_store.py b/src/phoenix/server/jwt_store.py
index 430d97684d..4b8da3af72 100644
--- a/src/phoenix/server/jwt_store.py
+++ b/src/phoenix/server/jwt_store.py
@@ -5,7 +5,7 @@
from dataclasses import replace
from datetime import datetime, timezone
from functools import cached_property, singledispatchmethod
-from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
+from typing import Any, Callable, Coroutine, Dict, Generic, List, Optional, Tuple, Type, TypeVar
import jwt
from sqlalchemy import Select, delete, select
@@ -28,6 +28,10 @@
ApiKeyId,
DaemonTask,
DbSessionFactory,
+ PasswordResetToken,
+ PasswordResetTokenAttributes,
+ PasswordResetTokenClaims,
+ PasswordResetTokenId,
RefreshToken,
RefreshTokenAttributes,
RefreshTokenClaims,
@@ -53,6 +57,7 @@ def __init__(
self._db = db
self._secret = secret
args = (db, secret, algorithm, sleep_seconds)
+ self._password_reset_token_store = _PasswordResetTokenStore(*args, **kwargs)
self._access_token_store = _AccessTokenStore(*args, **kwargs)
self._refresh_token_store = _RefreshTokenStore(*args, **kwargs)
self._api_key_store = _ApiKeyStore(*args, **kwargs)
@@ -87,6 +92,10 @@ async def read(self, token: Token) -> Optional[ClaimSet]:
async def _get(self, _: TokenId) -> Optional[ClaimSet]:
return None
+ @_get.register
+ async def _(self, token_id: PasswordResetTokenId) -> Optional[ClaimSet]:
+ return await self._password_reset_token_store.get(token_id)
+
@_get.register
async def _(self, token_id: AccessTokenId) -> Optional[ClaimSet]:
return await self._access_token_store.get(token_id)
@@ -103,6 +112,10 @@ async def _(self, token_id: ApiKeyId) -> Optional[ClaimSet]:
async def _evict(self, _: TokenId) -> Optional[ClaimSet]:
return None
+ @_evict.register
+ async def _(self, token_id: PasswordResetTokenId) -> Optional[ClaimSet]:
+ return await self._password_reset_token_store.evict(token_id)
+
@_evict.register
async def _(self, token_id: AccessTokenId) -> Optional[ClaimSet]:
return await self._access_token_store.evict(token_id)
@@ -115,6 +128,12 @@ async def _(self, token_id: RefreshTokenId) -> Optional[ClaimSet]:
async def _(self, token_id: ApiKeyId) -> Optional[ClaimSet]:
return await self._api_key_store.evict(token_id)
+ async def create_password_reset_token(
+ self,
+ claim: PasswordResetTokenClaims,
+ ) -> Tuple[PasswordResetToken, PasswordResetTokenId]:
+ return await self._password_reset_token_store.create(claim)
+
async def create_access_token(
self,
claim: AccessTokenClaims,
@@ -136,21 +155,29 @@ async def create_api_key(
async def revoke(self, *token_ids: TokenId) -> None:
if not token_ids:
return
+ password_reset_token_ids: List[PasswordResetTokenId] = []
access_token_ids: List[AccessTokenId] = []
refresh_token_ids: List[RefreshTokenId] = []
api_key_ids: List[ApiKeyId] = []
for token_id in token_ids:
+ if isinstance(token_id, PasswordResetTokenId):
+ password_reset_token_ids.append(token_id)
if isinstance(token_id, AccessTokenId):
access_token_ids.append(token_id)
elif isinstance(token_id, RefreshTokenId):
refresh_token_ids.append(token_id)
elif isinstance(token_id, ApiKeyId):
api_key_ids.append(token_id)
- await gather(
- self._access_token_store.revoke(*access_token_ids),
- self._refresh_token_store.revoke(*refresh_token_ids),
- self._api_key_store.revoke(*api_key_ids),
- )
+ coroutines: List[Coroutine[None, None, None]] = []
+ if password_reset_token_ids:
+ coroutines.append(self._password_reset_token_store.revoke(*password_reset_token_ids))
+ if access_token_ids:
+ coroutines.append(self._access_token_store.revoke(*access_token_ids))
+ if refresh_token_ids:
+ coroutines.append(self._refresh_token_store.revoke(*refresh_token_ids))
+ if api_key_ids:
+ coroutines.append(self._api_key_store.revoke(*api_key_ids))
+ await gather(*coroutines)
async def log_out(self, user_id: UserId) -> None:
for cls in (AccessTokenId, RefreshTokenId):
@@ -166,6 +193,7 @@ async def log_out(self, user_id: UserId) -> None:
_ClaimSetT = TypeVar("_ClaimSetT", bound=ClaimSet)
_RecordT = TypeVar(
"_RecordT",
+ models.PasswordResetToken,
models.AccessToken,
models.RefreshToken,
models.ApiKey,
@@ -286,6 +314,45 @@ async def _run(self) -> None:
self._tasks.pop()
+class _PasswordResetTokenStore(
+ _Store[
+ PasswordResetTokenClaims,
+ PasswordResetToken,
+ PasswordResetTokenId,
+ models.PasswordResetToken,
+ ]
+):
+ _table = models.PasswordResetToken
+ _token_id = PasswordResetTokenId
+ _token = PasswordResetToken
+
+ def _from_db(
+ self,
+ record: models.PasswordResetToken,
+ user_role: UserRole,
+ ) -> Tuple[PasswordResetTokenId, PasswordResetTokenClaims]:
+ token_id = PasswordResetTokenId(record.id)
+ return token_id, PasswordResetTokenClaims(
+ token_id=token_id,
+ subject=UserId(record.user_id),
+ issued_at=record.created_at,
+ expiration_time=record.expires_at,
+ attributes=PasswordResetTokenAttributes(
+ user_role=user_role,
+ ),
+ )
+
+ def _to_db(self, claim: PasswordResetTokenClaims) -> models.PasswordResetToken:
+ assert claim.expiration_time
+ assert claim.subject
+ user_id = int(claim.subject)
+ return models.PasswordResetToken(
+ user_id=user_id,
+ created_at=claim.issued_at,
+ expires_at=claim.expiration_time,
+ )
+
+
class _AccessTokenStore(
_Store[
AccessTokenClaims,
diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py
index 32889cd23b..8e788b2dfe 100644
--- a/src/phoenix/server/main.py
+++ b/src/phoenix/server/main.py
@@ -11,6 +11,7 @@
from typing import List, Optional
from urllib.parse import urljoin
+from fastapi_mail import ConnectionConfig
from jinja2 import BaseLoader, Environment
from uvicorn import Config, Server
@@ -25,8 +26,15 @@
get_env_grpc_port,
get_env_host,
get_env_host_root_path,
+ get_env_password_reset_token_expiry,
get_env_port,
get_env_refresh_token_expiry,
+ get_env_smtp_hostname,
+ get_env_smtp_mail_from,
+ get_env_smtp_password,
+ get_env_smtp_port,
+ get_env_smtp_username,
+ get_env_smtp_validate_certs,
get_pids_path,
get_working_dir,
)
@@ -47,6 +55,7 @@
create_engine_and_run_migrations,
instrument_engine_if_enabled,
)
+from phoenix.server.email.sender import EMAIL_TEMPLATE_FOLDER, FastMailSender
from phoenix.server.types import DbSessionFactory
from phoenix.settings import Settings
from phoenix.trace.fixtures import (
@@ -367,6 +376,25 @@ def _get_pid_file() -> Path:
scaffold_datasets=scaffold_datasets,
phoenix_url=root_path,
)
+ email_sender = None
+ if mail_sever := get_env_smtp_hostname():
+ assert (mail_username := get_env_smtp_username()), "SMTP username is required"
+ assert (mail_password := get_env_smtp_password()), "SMTP password is required"
+ assert (mail_from := get_env_smtp_mail_from()), "SMTP mail_from is required"
+ email_sender = FastMailSender(
+ ConnectionConfig(
+ MAIL_USERNAME=mail_username,
+ MAIL_PASSWORD=mail_password,
+ MAIL_FROM=mail_from,
+ MAIL_SERVER=mail_sever,
+ MAIL_PORT=get_env_smtp_port(),
+ VALIDATE_CERTS=get_env_smtp_validate_certs(),
+ USE_CREDENTIALS=True,
+ MAIL_STARTTLS=True,
+ MAIL_SSL_TLS=False,
+ TEMPLATE_FOLDER=EMAIL_TEMPLATE_FOLDER,
+ )
+ )
app = create_app(
db=factory,
export_path=export_path,
@@ -384,9 +412,11 @@ def _get_pid_file() -> Path:
startup_callbacks=[lambda: print(msg)],
shutdown_callbacks=instrumentation_cleanups,
secret=secret,
+ password_reset_token_expiry=get_env_password_reset_token_expiry(),
access_token_expiry=get_env_access_token_expiry(),
refresh_token_expiry=get_env_refresh_token_expiry(),
scaffolder_config=scaffolder_config,
+ email_sender=email_sender,
)
server = Server(config=Config(app, host=host, port=port, root_path=host_root_path)) # type: ignore
Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start()
diff --git a/src/phoenix/server/types.py b/src/phoenix/server/types.py
index fbbd4c6d2b..aaaf32fa5e 100644
--- a/src/phoenix/server/types.py
+++ b/src/phoenix/server/types.py
@@ -143,6 +143,9 @@ def set(self, table: Type[models.Base], id_: int) -> None:
self._cache[table][id_] = datetime.now(timezone.utc)
+class PasswordResetToken(Token): ...
+
+
class AccessToken(Token): ...
@@ -161,6 +164,10 @@ class UserTokenAttributes(TokenAttributes):
class RefreshTokenAttributes(UserTokenAttributes): ...
+@dataclass(frozen=True)
+class PasswordResetTokenAttributes(UserTokenAttributes): ...
+
+
@dataclass(frozen=True)
class AccessTokenAttributes(UserTokenAttributes):
refresh_token_id: RefreshTokenId
@@ -198,6 +205,11 @@ def parse(cls, value: str) -> Optional[TokenId]:
return None
+@final
+class PasswordResetTokenId(TokenId):
+ table = models.PasswordResetToken
+
+
@final
class AccessTokenId(TokenId):
table = models.AccessToken
@@ -224,6 +236,12 @@ class UserClaimSet(ClaimSet):
attributes: Optional[UserTokenAttributes] = None
+@dataclass(frozen=True)
+class PasswordResetTokenClaims(UserClaimSet):
+ token_id: Optional[PasswordResetTokenId] = None
+ attributes: Optional[PasswordResetTokenAttributes] = None
+
+
@dataclass(frozen=True)
class AccessTokenClaims(UserClaimSet):
token_id: Optional[AccessTokenId] = None
@@ -251,6 +269,10 @@ async def log_out(self, user_id: UserId) -> None: ...
class TokenStore(CanReadToken, CanRevokeTokens, CanLogOutUser, Protocol):
+ async def create_password_reset_token(
+ self,
+ claims: PasswordResetTokenClaims,
+ ) -> Tuple[PasswordResetToken, PasswordResetTokenId]: ...
async def create_access_token(
self,
claims: AccessTokenClaims,
diff --git a/tests/test_config.py b/tests/test_config.py
deleted file mode 100644
index abae5def99..0000000000
--- a/tests/test_config.py
+++ /dev/null
@@ -1,86 +0,0 @@
-from datetime import timedelta
-from typing import Callable
-
-import pytest
-
-from phoenix.config import (
- ENV_PHOENIX_ACCESS_TOKEN_EXPIRY,
- ENV_PHOENIX_REFRESH_TOKEN_EXPIRY,
- get_env_access_token_expiry,
- get_env_refresh_token_expiry,
-)
-
-
-@pytest.mark.parametrize(
- "env_var_value, expected_value",
- (
- pytest.param("3600", timedelta(seconds=3600), id="with-positive-unitless-integer"),
- pytest.param("3600.10", timedelta(seconds=3600.10), id="with-positive-decimal"),
- pytest.param("36 days", timedelta(days=36), id="with-positive-day-unit"),
- pytest.param("36d", timedelta(days=36), id="with-positive-d-unit"),
- pytest.param(
- "P4M6DT3H12M45S",
- timedelta(days=(4 * 30 + 6), hours=3, minutes=12, seconds=45),
- id="with-iso-8601-duration",
- ),
- ),
-)
-@pytest.mark.parametrize(
- "env_var_name, env_var_getter",
- (
- pytest.param(
- ENV_PHOENIX_ACCESS_TOKEN_EXPIRY, get_env_access_token_expiry, id="access-token-expiry"
- ),
- pytest.param(
- ENV_PHOENIX_REFRESH_TOKEN_EXPIRY,
- get_env_refresh_token_expiry,
- id="refresh-token-expiry",
- ),
- ),
-)
-def test_get_env_token_expiry_parses_valid_values(
- env_var_name: str,
- env_var_getter: Callable[[], timedelta],
- env_var_value: str,
- expected_value: timedelta,
- monkeypatch: pytest.MonkeyPatch,
-) -> None:
- monkeypatch.setenv(env_var_name, env_var_value)
- env_var_getter() == expected_value
-
-
-@pytest.mark.parametrize(
- "env_var_value, error_message",
- (
- pytest.param("-3600", "duration must be positive", id="with-negative-integer"),
- pytest.param("-3600.10", "duration must be positive", id="with-negative-decimal"),
- pytest.param("-36d", "duration must be positive", id="with-negative-d-unit"),
- pytest.param("0", "duration must be positive", id="with-zero-duration"),
- pytest.param("nan", "duration cannot be null", id="with-null"),
- ),
-)
-@pytest.mark.parametrize(
- "env_var_name, env_var_getter",
- (
- pytest.param(
- ENV_PHOENIX_ACCESS_TOKEN_EXPIRY, get_env_access_token_expiry, id="access-token-expiry"
- ),
- pytest.param(
- ENV_PHOENIX_REFRESH_TOKEN_EXPIRY,
- get_env_refresh_token_expiry,
- id="refresh-token-expiry",
- ),
- ),
-)
-def test_get_env_token_expiry_raises_expected_errors_for_invalid_values(
- env_var_name: str,
- env_var_getter: Callable[[], timedelta],
- env_var_value: str,
- error_message: str,
- monkeypatch: pytest.MonkeyPatch,
-) -> None:
- monkeypatch.setenv(env_var_name, env_var_value)
- with pytest.raises(
- ValueError, match=f"Error reading {env_var_name} environment variable: {error_message}"
- ):
- env_var_getter()