diff --git a/xinference/api/oauth2/auth_service.py b/xinference/api/oauth2/auth_service.py new file mode 100644 index 0000000000..1b793059a9 --- /dev/null +++ b/xinference/api/oauth2/auth_service.py @@ -0,0 +1,132 @@ +# Copyright 2022-2024 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import timedelta +from typing import List, Optional + +import pydantic +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer, SecurityScopes +from jose import JWTError, jwt +from pydantic import BaseModel, ValidationError +from typing_extensions import Annotated + +from .types import AuthStartupConfig, User +from .utils import create_access_token, get_password_hash, verify_password + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +class TokenData(BaseModel): + username: str + scopes: List[str] = [] + + +class AuthService: + def __init__(self, auth_config_file: Optional[str]): + self._auth_config_file = auth_config_file + self._config = self.init_auth_config() + + @property + def config(self): + return self._config + + def init_auth_config(self): + if self._auth_config_file: + config: AuthStartupConfig = pydantic.parse_file_as( + path=self._auth_config_file, type_=AuthStartupConfig + ) + for user in config.user_config: + user.password = get_password_hash(user.password) + return config + + def __call__( + self, + security_scopes: SecurityScopes, + token: Annotated[str, Depends(oauth2_scheme)], + ): + """ + Advanced dependencies. See: https://fastapi.tiangolo.com/advanced/advanced-dependencies/ + """ + if security_scopes.scopes: + authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' + else: + authenticate_value = "Bearer" + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": authenticate_value}, + ) + + try: + assert self._config is not None + payload = jwt.decode( + token, + self._config.auth_config.secret_key, + algorithms=[self._config.auth_config.algorithm], + options={"verify_exp": False}, # TODO: supports token expiration + ) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + token_scopes = payload.get("scopes", []) + token_data = TokenData(scopes=token_scopes, username=username) + except (JWTError, ValidationError): + raise credentials_exception + user = self.get_user(token_data.username) + if user is None: + raise credentials_exception + if "admin" in token_data.scopes: + return user + for scope in security_scopes.scopes: + if scope not in token_data.scopes: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not enough permissions", + headers={"WWW-Authenticate": authenticate_value}, + ) + return user + + def get_user(self, username: str) -> Optional[User]: + for user in self._config.user_config: + if user.username == username: + return user + return None + + def authenticate_user(self, username: str, password: str): + user = self.get_user(username) + if not user: + return False + if not verify_password(password, user.password): + return False + return user + + def generate_token_for_user(self, username: str, password: str): + user = self.authenticate_user(username, password) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + assert user is not None and isinstance(user, User) + access_token_expires = timedelta( + minutes=self._config.auth_config.token_expire_in_minutes + ) + access_token = create_access_token( + data={"sub": user.username, "scopes": user.permissions}, + secret_key=self._config.auth_config.secret_key, + algorithm=self._config.auth_config.algorithm, + expires_delta=access_token_expires, + ) + return {"access_token": access_token, "token_type": "bearer"} diff --git a/xinference/api/oauth2/common.py b/xinference/api/oauth2/common.py deleted file mode 100644 index 3d74b66482..0000000000 --- a/xinference/api/oauth2/common.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2022-2023 XProbe Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -XINFERENCE_OAUTH2_CONFIG = None diff --git a/xinference/api/oauth2/core.py b/xinference/api/oauth2/core.py deleted file mode 100644 index e1a6724de0..0000000000 --- a/xinference/api/oauth2/core.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2022-2023 XProbe Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from typing import List, Optional, Union - -from fastapi import Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer, SecurityScopes -from jose import JWTError, jwt -from pydantic import BaseModel, ValidationError -from typing_extensions import Annotated - -from .types import AuthStartupConfig, User - -logger = logging.getLogger(__name__) - - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") - - -def get_db(): - from .common import XINFERENCE_OAUTH2_CONFIG - - # In a real enterprise-level environment, this should be the database - yield XINFERENCE_OAUTH2_CONFIG - - -def get_user(db_users: List[User], username: str) -> Optional[User]: - for user in db_users: - if user.username == username: - return user - return None - - -class TokenData(BaseModel): - username: Union[str, None] = None - scopes: List[str] = [] - - -def verify_token( - security_scopes: SecurityScopes, - token: Annotated[str, Depends(oauth2_scheme)], - config: Optional[AuthStartupConfig] = Depends(get_db), -): - if security_scopes.scopes: - authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' - else: - authenticate_value = "Bearer" - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": authenticate_value}, - ) - - try: - assert config is not None - payload = jwt.decode( - token, - config.auth_config.secret_key, - algorithms=[config.auth_config.algorithm], - options={"verify_exp": False}, # TODO: supports token expiration - ) - username: str = payload.get("sub") - if username is None: - raise credentials_exception - token_scopes = payload.get("scopes", []) - # TODO: check expire - token_data = TokenData(scopes=token_scopes, username=username) - except (JWTError, ValidationError): - raise credentials_exception - user = get_user(config.user_config, username=token_data.username) # type: ignore - if user is None: - raise credentials_exception - if "admin" in token_data.scopes: - return user - for scope in security_scopes.scopes: - if scope not in token_data.scopes: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Not enough permissions", - headers={"WWW-Authenticate": authenticate_value}, - ) - return user diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index 4383d6816e..1d765d8233 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -22,11 +22,9 @@ import sys import time import warnings -from datetime import timedelta from typing import Any, List, Optional, Union import gradio as gr -import pydantic import xoscar as xo from aioprometheus import REGISTRY, MetricsMiddleware from aioprometheus.asgi.starlette import metrics @@ -41,7 +39,6 @@ Response, Security, UploadFile, - status, ) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse @@ -66,9 +63,8 @@ CreateCompletion, ImageList, ) -from .oauth2.core import get_user, verify_token -from .oauth2.types import AuthStartupConfig, LoginUserForm, User -from .oauth2.utils import create_access_token, get_password_hash, verify_password +from .oauth2.auth_service import AuthService +from .oauth2.types import LoginUserForm logger = logging.getLogger(__name__) @@ -137,15 +133,6 @@ class BuildGradioInterfaceRequest(BaseModel): model_lang: List[str] -def authenticate_user(db_users: List[User], username: str, password: str): - user = get_user(db_users, username) - if not user: - return False - if not verify_password(password, user.password): - return False - return user - - class RESTfulAPI: def __init__( self, @@ -160,25 +147,12 @@ def __init__( self._port = port self._supervisor_ref = None self._event_collector_ref = None - self._auth_config: AuthStartupConfig = self.init_auth_config(auth_config_file) + self._auth_service = AuthService(auth_config_file) self._router = APIRouter() self._app = FastAPI() - @staticmethod - def init_auth_config(auth_config_file: Optional[str]): - from .oauth2 import common - - if auth_config_file: - config: AuthStartupConfig = pydantic.parse_file_as( - path=auth_config_file, type_=AuthStartupConfig - ) - for user in config.user_config: - user.password = get_password_hash(user.password) - common.XINFERENCE_OAUTH2_CONFIG = config # type: ignore - return config - def is_authenticated(self): - return False if self._auth_config is None else True + return False if self._auth_service.config is None else True @staticmethod def handle_request_limit_error(e: Exception): @@ -216,28 +190,10 @@ async def _report_error_event(self, model_uid: str, content: str): ) async def login_for_access_token(self, form_data: LoginUserForm) -> JSONResponse: - user = authenticate_user( - self._auth_config.user_config, form_data.username, form_data.password - ) - if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect username or password", - headers={"WWW-Authenticate": "Bearer"}, - ) - assert user is not None and isinstance(user, User) - access_token_expires = timedelta( - minutes=self._auth_config.auth_config.token_expire_in_minutes - ) - access_token = create_access_token( - data={"sub": user.username, "scopes": user.permissions}, - secret_key=self._auth_config.auth_config.secret_key, - algorithm=self._auth_config.auth_config.algorithm, - expires_delta=access_token_expires, - ) - return JSONResponse( - content={"access_token": access_token, "token_type": "bearer"} + result = self._auth_service.generate_token_for_user( + form_data.username, form_data.password ) + return JSONResponse(content=result) async def is_cluster_authenticated(self) -> JSONResponse: return JSONResponse(content={"auth": self.is_authenticated()}) @@ -270,7 +226,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/ui/{model_uid}", self.build_gradio_interface, methods=["POST"], - dependencies=[Security(verify_token, scopes=["models:read"])] + dependencies=[Security(self._auth_service, scopes=["models:read"])] if self.is_authenticated() else None, ) @@ -285,7 +241,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/models/instances", self.get_instance_info, methods=["GET"], - dependencies=[Security(verify_token, scopes=["models:list"])] + dependencies=[Security(self._auth_service, scopes=["models:list"])] if self.is_authenticated() else None, ) @@ -293,7 +249,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/models/{model_type}/{model_name}/versions", self.get_model_versions, methods=["GET"], - dependencies=[Security(verify_token, scopes=["models:list"])] + dependencies=[Security(self._auth_service, scopes=["models:list"])] if self.is_authenticated() else None, ) @@ -301,7 +257,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/models", self.list_models, methods=["GET"], - dependencies=[Security(verify_token, scopes=["models:list"])] + dependencies=[Security(self._auth_service, scopes=["models:list"])] if self.is_authenticated() else None, ) @@ -310,7 +266,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/models/{model_uid}", self.describe_model, methods=["GET"], - dependencies=[Security(verify_token, scopes=["models:list"])] + dependencies=[Security(self._auth_service, scopes=["models:list"])] if self.is_authenticated() else None, ) @@ -318,7 +274,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/models/{model_uid}/events", self.get_model_events, methods=["GET"], - dependencies=[Security(verify_token, scopes=["models:read"])] + dependencies=[Security(self._auth_service, scopes=["models:read"])] if self.is_authenticated() else None, ) @@ -326,7 +282,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/models/instance", self.launch_model_by_version, methods=["POST"], - dependencies=[Security(verify_token, scopes=["models:start"])] + dependencies=[Security(self._auth_service, scopes=["models:start"])] if self.is_authenticated() else None, ) @@ -334,7 +290,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/models", self.launch_model, methods=["POST"], - dependencies=[Security(verify_token, scopes=["models:start"])] + dependencies=[Security(self._auth_service, scopes=["models:start"])] if self.is_authenticated() else None, ) @@ -342,7 +298,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/experimental/speculative_llms", self.launch_speculative_llm, methods=["POST"], - dependencies=[Security(verify_token, scopes=["models:start"])] + dependencies=[Security(self._auth_service, scopes=["models:start"])] if self.is_authenticated() else None, ) @@ -350,7 +306,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/models/{model_uid}", self.terminate_model, methods=["DELETE"], - dependencies=[Security(verify_token, scopes=["models:stop"])] + dependencies=[Security(self._auth_service, scopes=["models:stop"])] if self.is_authenticated() else None, ) @@ -359,7 +315,7 @@ def serve(self, logging_conf: Optional[dict] = None): self.create_completion, methods=["POST"], response_model=Completion, - dependencies=[Security(verify_token, scopes=["models:read"])] + dependencies=[Security(self._auth_service, scopes=["models:read"])] if self.is_authenticated() else None, ) @@ -367,7 +323,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/embeddings", self.create_embedding, methods=["POST"], - dependencies=[Security(verify_token, scopes=["models:read"])] + dependencies=[Security(self._auth_service, scopes=["models:read"])] if self.is_authenticated() else None, ) @@ -375,7 +331,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/rerank", self.rerank, methods=["POST"], - dependencies=[Security(verify_token, scopes=["models:read"])] + dependencies=[Security(self._auth_service, scopes=["models:read"])] if self.is_authenticated() else None, ) @@ -384,7 +340,7 @@ def serve(self, logging_conf: Optional[dict] = None): self.create_images, methods=["POST"], response_model=ImageList, - dependencies=[Security(verify_token, scopes=["models:read"])] + dependencies=[Security(self._auth_service, scopes=["models:read"])] if self.is_authenticated() else None, ) @@ -393,7 +349,7 @@ def serve(self, logging_conf: Optional[dict] = None): self.create_variations, methods=["POST"], response_model=ImageList, - dependencies=[Security(verify_token, scopes=["models:read"])] + dependencies=[Security(self._auth_service, scopes=["models:read"])] if self.is_authenticated() else None, ) @@ -402,7 +358,7 @@ def serve(self, logging_conf: Optional[dict] = None): self.create_chat_completion, methods=["POST"], response_model=ChatCompletion, - dependencies=[Security(verify_token, scopes=["models:read"])] + dependencies=[Security(self._auth_service, scopes=["models:read"])] if self.is_authenticated() else None, ) @@ -412,7 +368,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/model_registrations/{model_type}", self.register_model, methods=["POST"], - dependencies=[Security(verify_token, scopes=["models:register"])] + dependencies=[Security(self._auth_service, scopes=["models:register"])] if self.is_authenticated() else None, ) @@ -420,7 +376,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/model_registrations/{model_type}/{model_name}", self.unregister_model, methods=["DELETE"], - dependencies=[Security(verify_token, scopes=["models:unregister"])] + dependencies=[Security(self._auth_service, scopes=["models:unregister"])] if self.is_authenticated() else None, ) @@ -428,7 +384,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/model_registrations/{model_type}", self.list_model_registrations, methods=["GET"], - dependencies=[Security(verify_token, scopes=["models:list"])] + dependencies=[Security(self._auth_service, scopes=["models:list"])] if self.is_authenticated() else None, ) @@ -436,7 +392,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/model_registrations/{model_type}/{model_name}", self.get_model_registrations, methods=["GET"], - dependencies=[Security(verify_token, scopes=["models:list"])] + dependencies=[Security(self._auth_service, scopes=["models:list"])] if self.is_authenticated() else None, )