From 8846be1d73d406672ba8d4458ff791dc5419bcb9 Mon Sep 17 00:00:00 2001 From: Ago327 Date: Fri, 1 Mar 2024 16:54:21 +0800 Subject: [PATCH 01/18] Setup for Contributing Doc --- .../development/Setup for Contributing.md | 95 +++++++++++++++++++ ...57\345\242\203\346\220\255\345\273\272.md" | 91 ++++++++++++++++++ 2 files changed, 186 insertions(+) create mode 100644 doc/source/development/Setup for Contributing.md create mode 100644 "doc/source/development/\345\274\200\345\217\221\347\216\257\345\242\203\346\220\255\345\273\272.md" diff --git a/doc/source/development/Setup for Contributing.md b/doc/source/development/Setup for Contributing.md new file mode 100644 index 0000000000..cfae9e3fce --- /dev/null +++ b/doc/source/development/Setup for Contributing.md @@ -0,0 +1,95 @@ +# Setup for Contributing + + + +Table of contents: + +- [Getting startted with Git](#Getting-startted-with-Git) +- [Setting up Conda environment](#Setting-up-Conda-environment) +- [Xinference Installation](#Xinference-Installation) +- [Frontend Compilation](#Frontend-Compilation) + + + +## [Getting startted with Git](#Setup-for-Contributing) + + + +For more details, refer to [Working with the code](https://github.com/xorbitsai/xorbits/blob/main/doc/source/development/contributing.rst#working-with-the-code). + +If the speed of `git clone` is slow, you can use the following command to add a proxy: + +``` +export https_proxy=YourProxyAddress +``` + + + +## [Setting up Conda environment](#Setup-for-Contributing) + + + +Before formally installing Xinference, it's recommended to create a new Conda environment for ease of subsequent operations. + +If you're using a campus-level public computing cloud platform, the setup command is as follows: + +``` +mkdir /fs/fast/ustudentID/envs +conda create --prefix=/fs/fast/ustudentID/envs/xinf +conda activate /fs/fast/ustudentID/envs/xinf +``` + +The `studentID` needs to be replaced with the corresponding student ID of your server account, and `xinf` can be replaced with a custom Conda environment name. + +Afterward, you'll need to install Python and npm in the newly created Conda environment. Here are the commands: + +``` +conda install python=3.10 +conda install nodejs +``` + + + +## [Xinference Installation](#Setup-for-Contributing) + + + +For more details, refer to [installation](https://inference.readthedocs.io/zh-cn/latest/getting_started/installation.html)。 + +After the initial installation of Xinference, you need to run the following commands in the `/inference/` directory to check if it can run properly: + +``` +pip install -e . +xinference-local +``` + +If errors occur or the process freezes during execution, the next step is to compile the frontend, refer to [Frontend-Compilation](#Frontend-Compilation). + + +If errors occur or the process freezes during execution, the next step is to compile the frontend. + +If the commands run successfully, you can use Xinference normally. For detailed usage instructions, refer to [using_xinference](https://inference.readthedocs.io/zh-cn/latest/getting_started/using_xinference.html). + + + +## [Frontend Compilation](#Setup-for-Contributing) + + + +Firstly, navigate to the `/inference/xinference/web/ui` directory. If the `/node_modules/` folder already exists in this directory, it's recommended to manually delete it. Then, execute the following command to clear the cache: + +``` +npm cache clean +``` + +Next, execute the following command in this directory to compile the frontend: + +``` +npm install +npm run build +``` + + +After compiling the frontend, you can retry running Xinference. + +At this point, all the necessary environment setup for development has been completed. diff --git "a/doc/source/development/\345\274\200\345\217\221\347\216\257\345\242\203\346\220\255\345\273\272.md" "b/doc/source/development/\345\274\200\345\217\221\347\216\257\345\242\203\346\220\255\345\273\272.md" new file mode 100644 index 0000000000..c27a533678 --- /dev/null +++ "b/doc/source/development/\345\274\200\345\217\221\347\216\257\345\242\203\346\220\255\345\273\272.md" @@ -0,0 +1,91 @@ +# 开发环境搭建 + + + +目录: + +- [Git的准备](#Git的准备) +- [conda环境搭建](#conda环境搭建) +- [Xinference安装](#Xinference安装) +- [前端编译](#前端编译) + + + +## [Git的准备](#开发环境搭建) + + + +详情见[Working with the code](https://github.com/xorbitsai/xorbits/blob/main/doc/source/development/contributing.rst#working-with-the-code)。 + +如果在`git clone`代码的时候速度较慢,可以通过如下命令添加代理: + +``` +export https_proxy=代理地址 +``` + + + +## [conda环境搭建](#开发环境搭建) + + + +在正式安装Xinference之前,最好新建一个conda环境方便后续操作。 + +如果是在校级公共计算云平台,搭建命令如下: + +``` +mkdir /fs/fast/u学号/envs +conda create --prefix=/fs/fast/u学号/envs/xinf +conda activate /fs/fast/u学号/envs/xinf +``` + +其中`学号`需要替换成对应服务器账号的学号,`xinf`可替换为自定义的conda环境名。 + +随后需要在新建的conda环境中安装python以及npm。命令如下: + +``` +conda install python=3.10 +conda install nodejs +``` + + + +## [Xinference安装](#开发环境搭建) + + + +详情见[入门教程](https://inference.readthedocs.io/zh-cn/latest/getting_started/installation.html)。 + +Xinference初步安装完成后需要在`/inference/`目录下运行以下命令检查是否能正常运行。 + +``` +pip install -e . +xinference-local +``` + +如果出现报错或者在运行过程中卡死,那就需要进行下一步[前端编译](#前端编译)。 + +如果能够成功运行命令就能正常使用Xinference了,使用教程详情见[入门教程](https://inference.readthedocs.io/zh-cn/latest/getting_started/using_xinference.html)。 + + + +## [前端编译](#开发环境搭建) + + + +首先需要进入`/inference/xinference/web/ui`目录下,如果该目录下已经存在`/node_modules/`文件夹的话建议先手动删除该文件夹,随后执行如下命令清除缓存: + +``` +npm cache clean +``` + +接着在该目录下执行以下命令进行前端编译: + +``` +npm install +npm run build +``` + +前端编译完以后就可以重新尝试运行Xinference。 + +至此,开发的环境准备工作已经全部完成。 From 147e819a566b2e37fde10ef157bc96d0b2b2e1d5 Mon Sep 17 00:00:00 2001 From: Ago327 Date: Mon, 11 Mar 2024 09:20:05 +0800 Subject: [PATCH 02/18] doc_development --- .../development/Setup for Contributing.md | 95 ------------------- ...57\345\242\203\346\220\255\345\273\272.md" | 91 ------------------ 2 files changed, 186 deletions(-) delete mode 100644 doc/source/development/Setup for Contributing.md delete mode 100644 "doc/source/development/\345\274\200\345\217\221\347\216\257\345\242\203\346\220\255\345\273\272.md" diff --git a/doc/source/development/Setup for Contributing.md b/doc/source/development/Setup for Contributing.md deleted file mode 100644 index cfae9e3fce..0000000000 --- a/doc/source/development/Setup for Contributing.md +++ /dev/null @@ -1,95 +0,0 @@ -# Setup for Contributing - - - -Table of contents: - -- [Getting startted with Git](#Getting-startted-with-Git) -- [Setting up Conda environment](#Setting-up-Conda-environment) -- [Xinference Installation](#Xinference-Installation) -- [Frontend Compilation](#Frontend-Compilation) - - - -## [Getting startted with Git](#Setup-for-Contributing) - - - -For more details, refer to [Working with the code](https://github.com/xorbitsai/xorbits/blob/main/doc/source/development/contributing.rst#working-with-the-code). - -If the speed of `git clone` is slow, you can use the following command to add a proxy: - -``` -export https_proxy=YourProxyAddress -``` - - - -## [Setting up Conda environment](#Setup-for-Contributing) - - - -Before formally installing Xinference, it's recommended to create a new Conda environment for ease of subsequent operations. - -If you're using a campus-level public computing cloud platform, the setup command is as follows: - -``` -mkdir /fs/fast/ustudentID/envs -conda create --prefix=/fs/fast/ustudentID/envs/xinf -conda activate /fs/fast/ustudentID/envs/xinf -``` - -The `studentID` needs to be replaced with the corresponding student ID of your server account, and `xinf` can be replaced with a custom Conda environment name. - -Afterward, you'll need to install Python and npm in the newly created Conda environment. Here are the commands: - -``` -conda install python=3.10 -conda install nodejs -``` - - - -## [Xinference Installation](#Setup-for-Contributing) - - - -For more details, refer to [installation](https://inference.readthedocs.io/zh-cn/latest/getting_started/installation.html)。 - -After the initial installation of Xinference, you need to run the following commands in the `/inference/` directory to check if it can run properly: - -``` -pip install -e . -xinference-local -``` - -If errors occur or the process freezes during execution, the next step is to compile the frontend, refer to [Frontend-Compilation](#Frontend-Compilation). - - -If errors occur or the process freezes during execution, the next step is to compile the frontend. - -If the commands run successfully, you can use Xinference normally. For detailed usage instructions, refer to [using_xinference](https://inference.readthedocs.io/zh-cn/latest/getting_started/using_xinference.html). - - - -## [Frontend Compilation](#Setup-for-Contributing) - - - -Firstly, navigate to the `/inference/xinference/web/ui` directory. If the `/node_modules/` folder already exists in this directory, it's recommended to manually delete it. Then, execute the following command to clear the cache: - -``` -npm cache clean -``` - -Next, execute the following command in this directory to compile the frontend: - -``` -npm install -npm run build -``` - - -After compiling the frontend, you can retry running Xinference. - -At this point, all the necessary environment setup for development has been completed. diff --git "a/doc/source/development/\345\274\200\345\217\221\347\216\257\345\242\203\346\220\255\345\273\272.md" "b/doc/source/development/\345\274\200\345\217\221\347\216\257\345\242\203\346\220\255\345\273\272.md" deleted file mode 100644 index c27a533678..0000000000 --- "a/doc/source/development/\345\274\200\345\217\221\347\216\257\345\242\203\346\220\255\345\273\272.md" +++ /dev/null @@ -1,91 +0,0 @@ -# 开发环境搭建 - - - -目录: - -- [Git的准备](#Git的准备) -- [conda环境搭建](#conda环境搭建) -- [Xinference安装](#Xinference安装) -- [前端编译](#前端编译) - - - -## [Git的准备](#开发环境搭建) - - - -详情见[Working with the code](https://github.com/xorbitsai/xorbits/blob/main/doc/source/development/contributing.rst#working-with-the-code)。 - -如果在`git clone`代码的时候速度较慢,可以通过如下命令添加代理: - -``` -export https_proxy=代理地址 -``` - - - -## [conda环境搭建](#开发环境搭建) - - - -在正式安装Xinference之前,最好新建一个conda环境方便后续操作。 - -如果是在校级公共计算云平台,搭建命令如下: - -``` -mkdir /fs/fast/u学号/envs -conda create --prefix=/fs/fast/u学号/envs/xinf -conda activate /fs/fast/u学号/envs/xinf -``` - -其中`学号`需要替换成对应服务器账号的学号,`xinf`可替换为自定义的conda环境名。 - -随后需要在新建的conda环境中安装python以及npm。命令如下: - -``` -conda install python=3.10 -conda install nodejs -``` - - - -## [Xinference安装](#开发环境搭建) - - - -详情见[入门教程](https://inference.readthedocs.io/zh-cn/latest/getting_started/installation.html)。 - -Xinference初步安装完成后需要在`/inference/`目录下运行以下命令检查是否能正常运行。 - -``` -pip install -e . -xinference-local -``` - -如果出现报错或者在运行过程中卡死,那就需要进行下一步[前端编译](#前端编译)。 - -如果能够成功运行命令就能正常使用Xinference了,使用教程详情见[入门教程](https://inference.readthedocs.io/zh-cn/latest/getting_started/using_xinference.html)。 - - - -## [前端编译](#开发环境搭建) - - - -首先需要进入`/inference/xinference/web/ui`目录下,如果该目录下已经存在`/node_modules/`文件夹的话建议先手动删除该文件夹,随后执行如下命令清除缓存: - -``` -npm cache clean -``` - -接着在该目录下执行以下命令进行前端编译: - -``` -npm install -npm run build -``` - -前端编译完以后就可以重新尝试运行Xinference。 - -至此,开发的环境准备工作已经全部完成。 From 829850406f3025a63d6fc2149f475b42fc5b9b1d Mon Sep 17 00:00:00 2001 From: Ago327 Date: Tue, 19 Mar 2024 16:48:12 +0800 Subject: [PATCH 03/18] init api-key check --- xinference/api/oauth2/auth_service.py | 21 +++++++++++++++++++++ xinference/api/oauth2/types.py | 1 + xinference/client/restful/restful_client.py | 4 ++-- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/xinference/api/oauth2/auth_service.py b/xinference/api/oauth2/auth_service.py index 2f11c26f18..124344b406 100644 --- a/xinference/api/oauth2/auth_service.py +++ b/xinference/api/oauth2/auth_service.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re from datetime import timedelta from typing import List, Optional @@ -40,13 +41,33 @@ def __init__(self, auth_config_file: Optional[str]): def config(self): return self._config + @staticmethod + def is_legal_apikey(key: str): + pattern = re.compile("^[sk]{2}-[a-zA-Z0-9]{48}$") + if re.match(pattern, key): + return True + else: + return False + def init_auth_config(self): if self._auth_config_file: config: AuthStartupConfig = parse_file_as( path=self._auth_config_file, type_=AuthStartupConfig ) + total_keys = set() for user in config.user_config: user.password = get_password_hash(user.password) + if len(set(user.api_keys)) != len(user.api_keys): + raise ValueError("User has duplicate Api-Keys") + for api_key in user.api_keys: + if not self.is_legal_apikey(api_key): + raise ValueError( + "Api-Key should be a string started with 'sk-' with a total length of 51" + ) + if api_key in total_keys: + raise ValueError("Api-Keys of different users have conflict") + else: + total_keys.add(api_key) return config def __call__( diff --git a/xinference/api/oauth2/types.py b/xinference/api/oauth2/types.py index 106680deac..deb5740a19 100644 --- a/xinference/api/oauth2/types.py +++ b/xinference/api/oauth2/types.py @@ -23,6 +23,7 @@ class LoginUserForm(BaseModel): class User(LoginUserForm): permissions: List[str] + api_keys: List[str] class AuthConfig(BaseModel): diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index ca5d8ef0a3..270765faeb 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -651,9 +651,9 @@ def translations( class Client: - def __init__(self, base_url): + def __init__(self, base_url, api_key: Optional[str]): self.base_url = base_url - self._headers = {} + self._headers: Dict[str, str] = {} self._cluster_authed = False self._check_cluster_authenticated() From b299cb5c12a65e0d68ad7a1fe7eaf825ab8526b3 Mon Sep 17 00:00:00 2001 From: Ago327 Date: Tue, 19 Mar 2024 17:02:18 +0800 Subject: [PATCH 04/18] set api-key non-positional --- xinference/client/restful/restful_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index 270765faeb..8f4124e742 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -651,7 +651,7 @@ def translations( class Client: - def __init__(self, base_url, api_key: Optional[str]): + def __init__(self, base_url, api_key: Optional[str] = None): self.base_url = base_url self._headers: Dict[str, str] = {} self._cluster_authed = False From 9c35307c6ce63859bf56360c440bc3ef65f98a2b Mon Sep 17 00:00:00 2001 From: Ago327 Date: Thu, 21 Mar 2024 10:18:42 +0800 Subject: [PATCH 05/18] client with api-key --- .../development/contributing_environment.rst | 2 +- .../development/contributing_environment.po | 12 ++++--- xinference/api/oauth2/auth_service.py | 31 +++++++++++++++++-- xinference/api/restful_api.py | 12 +++++++ xinference/client/restful/restful_client.py | 18 +++++++++++ 5 files changed, 68 insertions(+), 7 deletions(-) diff --git a/doc/source/development/contributing_environment.rst b/doc/source/development/contributing_environment.rst index 80aa266dfb..25ee11cc71 100644 --- a/doc/source/development/contributing_environment.rst +++ b/doc/source/development/contributing_environment.rst @@ -8,7 +8,7 @@ Creating a development environment Before proceeding with any code modifications, it's essential to set up the necessary environment for Xinference development, which includes familiarizing yourself with Git usage, establishing an isolated environment, installing Xinference, and compiling the frontend. -Getting startted with Git +Getting started with Git ------------------------- Now that you have identified an issue you wish to resolve, an enhancement to incorporate, or documentation to enhance, diff --git a/doc/source/locale/zh_CN/LC_MESSAGES/development/contributing_environment.po b/doc/source/locale/zh_CN/LC_MESSAGES/development/contributing_environment.po index 8ced444523..ac2f0f4917 100644 --- a/doc/source/locale/zh_CN/LC_MESSAGES/development/contributing_environment.po +++ b/doc/source/locale/zh_CN/LC_MESSAGES/development/contributing_environment.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: Xinference \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2024-03-06 16:29+0800\n" +"POT-Creation-Date: 2024-03-21 09:59+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language: zh_CN\n" @@ -38,7 +38,8 @@ msgstr "" "Xinference 以及前端部分的编译。" #: ../../source/development/contributing_environment.rst:12 -msgid "Getting startted with Git" +#, fuzzy +msgid "Getting started with Git" msgstr "Git 的使用" #: ../../source/development/contributing_environment.rst:14 @@ -151,7 +152,8 @@ msgid "" "`__." msgstr "" "如果命令能够成功运行,接下来就能正常使用 Xinference 了,使用教程详情见 `使用 " -"`__。" +"`__。" #: ../../source/development/contributing_environment.rst:83 msgid "" @@ -198,5 +200,7 @@ msgid "" "After compiling the frontend, you can ``cd`` back to the directory where " "the ``setup.cfg`` and ``setup.py`` files are located, and install " "Xinference via ``pip install -e .``." -msgstr "编译完前端后,您可以返回到包含 ``setup.cfg`` 和 ``setup.py`` 文件的目录,然后通过 ``pip install -e .`` 安装 Xinference。" +msgstr "" +"编译完前端后,您可以返回到包含 ``setup.cfg`` 和 ``setup.py`` 文件的目录,然后通过 ``pip install -e " +".`` 安装 Xinference。" diff --git a/xinference/api/oauth2/auth_service.py b/xinference/api/oauth2/auth_service.py index 124344b406..c92ff31aca 100644 --- a/xinference/api/oauth2/auth_service.py +++ b/xinference/api/oauth2/auth_service.py @@ -42,7 +42,7 @@ def config(self): return self._config @staticmethod - def is_legal_apikey(key: str): + def is_legal_api_key(key: str): pattern = re.compile("^[sk]{2}-[a-zA-Z0-9]{48}$") if re.match(pattern, key): return True @@ -60,7 +60,7 @@ def init_auth_config(self): if len(set(user.api_keys)) != len(user.api_keys): raise ValueError("User has duplicate Api-Keys") for api_key in user.api_keys: - if not self.is_legal_apikey(api_key): + if not self.is_legal_api_key(api_key): raise ValueError( "Api-Key should be a string started with 'sk-' with a total length of 51" ) @@ -123,6 +123,13 @@ def get_user(self, username: str) -> Optional[User]: return user return None + def get_user_with_api_key(self, api_key: str) -> Optional[User]: + for user in self._config.user_config: + for key in user.api_keys: + if api_key == key: + return user + return None + def authenticate_user(self, username: str, password: str): user = self.get_user(username) if not user: @@ -150,3 +157,23 @@ def generate_token_for_user(self, username: str, password: str): expires_delta=access_token_expires, ) return {"access_token": access_token, "token_type": "bearer"} + + def generate_token_with_api_key_for_user(self, api_key: str): + user = self.get_user_with_api_key(api_key) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect api-key", + headers={"WWW-Authenticate": "Bearer"}, + ) + assert user is not None and isinstance(user, User) + access_token_expires = timedelta( + minutes=self._config.auth_config.token_expire_in_minutes + ) + access_token = create_access_token( + data={"sub": user.username, "scopes": user.permissions}, + secret_key=self._config.auth_config.secret_key, + algorithm=self._config.auth_config.algorithm, + expires_delta=access_token_expires, + ) + return {"access_token": access_token, "token_type": "bearer"} diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index 4c7098a95b..de093becd6 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -208,6 +208,15 @@ async def login_for_access_token(self, request: Request) -> JSONResponse: ) return JSONResponse(content=result) + async def login_with_api_key_for_access_token( + self, request: Request + ) -> JSONResponse: + form_data = await request.json() + result = self._auth_service.generate_token_with_api_key_for_user( + form_data["api_key"] + ) + return JSONResponse(content=result) + async def is_cluster_authenticated(self) -> JSONResponse: return JSONResponse(content={"auth": self.is_authenticated()}) @@ -269,6 +278,9 @@ def serve(self, logging_conf: Optional[dict] = None): self._router.add_api_route( "/token", self.login_for_access_token, methods=["POST"] ) + self._router.add_api_route( + "/token/api_key", self.login_with_api_key_for_access_token, methods=["POST"] + ) self._router.add_api_route( "/v1/cluster/auth", self.is_cluster_authenticated, methods=["GET"] ) diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index 8f4124e742..f11db0fd17 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -656,6 +656,8 @@ def __init__(self, base_url, api_key: Optional[str] = None): self._headers: Dict[str, str] = {} self._cluster_authed = False self._check_cluster_authenticated() + if api_key is not None: + self.login_with_api_key(api_key) def _set_token(self, token: Optional[str]): if not self._cluster_authed or token is None: @@ -712,6 +714,22 @@ def login(self, username: str, password: str): access_token = response_data["access_token"] self._headers["Authorization"] = f"Bearer {access_token}" + def login_with_api_key(self, api_key: str): + if not self._cluster_authed: + return + url = f"{self.base_url}/token/api_key" + + payload = {"api_key": api_key} + + response = requests.post(url, json=payload) + if response.status_code != 200: + raise RuntimeError(f"Failed to login, detail: {response.json()['detail']}") + + response_data = response.json() + # Only bearer token for now + access_token = response_data["access_token"] + self._headers["Authorization"] = f"Bearer {access_token}" + def list_models(self) -> Dict[str, Dict[str, Any]]: """ Retrieve the model specifications from the Server. From 09f6f3189d445a1ce4b419fd60002ea1e967fdd8 Mon Sep 17 00:00:00 2001 From: Ago327 Date: Thu, 21 Mar 2024 10:26:36 +0800 Subject: [PATCH 06/18] fix doc --- .../LC_MESSAGES/development/contributing_environment.po | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/doc/source/locale/zh_CN/LC_MESSAGES/development/contributing_environment.po b/doc/source/locale/zh_CN/LC_MESSAGES/development/contributing_environment.po index ac2f0f4917..7aa3765a32 100644 --- a/doc/source/locale/zh_CN/LC_MESSAGES/development/contributing_environment.po +++ b/doc/source/locale/zh_CN/LC_MESSAGES/development/contributing_environment.po @@ -38,7 +38,6 @@ msgstr "" "Xinference 以及前端部分的编译。" #: ../../source/development/contributing_environment.rst:12 -#, fuzzy msgid "Getting started with Git" msgstr "Git 的使用" @@ -152,8 +151,7 @@ msgid "" "`__." msgstr "" "如果命令能够成功运行,接下来就能正常使用 Xinference 了,使用教程详情见 `使用 " -"`__。" +"`__。" #: ../../source/development/contributing_environment.rst:83 msgid "" @@ -200,7 +198,5 @@ msgid "" "After compiling the frontend, you can ``cd`` back to the directory where " "the ``setup.cfg`` and ``setup.py`` files are located, and install " "Xinference via ``pip install -e .``." -msgstr "" -"编译完前端后,您可以返回到包含 ``setup.cfg`` 和 ``setup.py`` 文件的目录,然后通过 ``pip install -e " -".`` 安装 Xinference。" +msgstr "编译完前端后,您可以返回到包含 ``setup.cfg`` 和 ``setup.py`` 文件的目录,然后通过 ``pip install -e .`` 安装 Xinference。" From 6dd83acdcacc955d727229baa98df813c4f60e4b Mon Sep 17 00:00:00 2001 From: Ago327 Date: Mon, 25 Mar 2024 16:27:58 +0800 Subject: [PATCH 07/18] compatible with both client and curl --- xinference/api/oauth2/auth_service.py | 56 +++++++++------------ xinference/api/restful_api.py | 12 ----- xinference/client/restful/restful_client.py | 18 +------ 3 files changed, 25 insertions(+), 61 deletions(-) diff --git a/xinference/api/oauth2/auth_service.py b/xinference/api/oauth2/auth_service.py index c92ff31aca..117c55faac 100644 --- a/xinference/api/oauth2/auth_service.py +++ b/xinference/api/oauth2/auth_service.py @@ -78,6 +78,7 @@ def __call__( """ Advanced dependencies. See: https://fastapi.tiangolo.com/advanced/advanced-dependencies/ """ + print(f"DEBUG: Enter __call__ with token {token}") if security_scopes.scopes: authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' else: @@ -88,22 +89,33 @@ def __call__( headers={"WWW-Authenticate": authenticate_value}, ) + through_api_key = False + try: assert self._config is not None - payload = jwt.decode( - token, - self._config.auth_config.secret_key, - algorithms=[self._config.auth_config.algorithm], - options={"verify_exp": False}, # TODO: supports token expiration - ) - username: str = payload.get("sub") - if username is None: - raise credentials_exception - token_scopes = payload.get("scopes", []) - token_data = TokenData(scopes=token_scopes, username=username) + if self.is_legal_api_key(token): + through_api_key = True + else: + payload = jwt.decode( + token, + self._config.auth_config.secret_key, + algorithms=[self._config.auth_config.algorithm], + options={"verify_exp": False}, # TODO: supports token expiration + ) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + token_scopes = payload.get("scopes", []) + token_data = TokenData(scopes=token_scopes, username=username) except (JWTError, ValidationError): raise credentials_exception - user = self.get_user(token_data.username) + if not through_api_key: + user = self.get_user(token_data.username) + else: + user = self.get_user_with_api_key(token) + if user is None: + raise credentials_exception + token_data = TokenData(scopes=user.permissions, username=user.username) if user is None: raise credentials_exception if "admin" in token_data.scopes: @@ -157,23 +169,3 @@ def generate_token_for_user(self, username: str, password: str): expires_delta=access_token_expires, ) return {"access_token": access_token, "token_type": "bearer"} - - def generate_token_with_api_key_for_user(self, api_key: str): - user = self.get_user_with_api_key(api_key) - if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect api-key", - headers={"WWW-Authenticate": "Bearer"}, - ) - assert user is not None and isinstance(user, User) - access_token_expires = timedelta( - minutes=self._config.auth_config.token_expire_in_minutes - ) - access_token = create_access_token( - data={"sub": user.username, "scopes": user.permissions}, - secret_key=self._config.auth_config.secret_key, - algorithm=self._config.auth_config.algorithm, - expires_delta=access_token_expires, - ) - return {"access_token": access_token, "token_type": "bearer"} diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index de093becd6..4c7098a95b 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -208,15 +208,6 @@ async def login_for_access_token(self, request: Request) -> JSONResponse: ) return JSONResponse(content=result) - async def login_with_api_key_for_access_token( - self, request: Request - ) -> JSONResponse: - form_data = await request.json() - result = self._auth_service.generate_token_with_api_key_for_user( - form_data["api_key"] - ) - return JSONResponse(content=result) - async def is_cluster_authenticated(self) -> JSONResponse: return JSONResponse(content={"auth": self.is_authenticated()}) @@ -278,9 +269,6 @@ def serve(self, logging_conf: Optional[dict] = None): self._router.add_api_route( "/token", self.login_for_access_token, methods=["POST"] ) - self._router.add_api_route( - "/token/api_key", self.login_with_api_key_for_access_token, methods=["POST"] - ) self._router.add_api_route( "/v1/cluster/auth", self.is_cluster_authenticated, methods=["GET"] ) diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index f11db0fd17..0e6c00f23a 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -657,7 +657,7 @@ def __init__(self, base_url, api_key: Optional[str] = None): self._cluster_authed = False self._check_cluster_authenticated() if api_key is not None: - self.login_with_api_key(api_key) + self._headers["Authorization"] = f"Bearer {api_key}" def _set_token(self, token: Optional[str]): if not self._cluster_authed or token is None: @@ -714,22 +714,6 @@ def login(self, username: str, password: str): access_token = response_data["access_token"] self._headers["Authorization"] = f"Bearer {access_token}" - def login_with_api_key(self, api_key: str): - if not self._cluster_authed: - return - url = f"{self.base_url}/token/api_key" - - payload = {"api_key": api_key} - - response = requests.post(url, json=payload) - if response.status_code != 200: - raise RuntimeError(f"Failed to login, detail: {response.json()['detail']}") - - response_data = response.json() - # Only bearer token for now - access_token = response_data["access_token"] - self._headers["Authorization"] = f"Bearer {access_token}" - def list_models(self) -> Dict[str, Dict[str, Any]]: """ Retrieve the model specifications from the Server. From d6ee4997d8cdcdfb215878cdcdbe53b3bb5bcccb Mon Sep 17 00:00:00 2001 From: Ago327 Date: Mon, 25 Mar 2024 17:07:46 +0800 Subject: [PATCH 08/18] compatible with cmdline --- xinference/deploy/cmdline.py | 116 +++++++++++++++++++++++++++++------ 1 file changed, 97 insertions(+), 19 deletions(-) diff --git a/xinference/deploy/cmdline.py b/xinference/deploy/cmdline.py index df620023e5..b622349786 100644 --- a/xinference/deploy/cmdline.py +++ b/xinference/deploy/cmdline.py @@ -376,17 +376,25 @@ def worker( is_flag=True, help="Persist the model configuration to the filesystem, retains the model registration after server restarts.", ) +@click.option( + "--api-key", + "-ak", + default=None, + type=str, + help="Api-Key for access xinference api with authorization.", +) def register_model( endpoint: Optional[str], model_type: str, file: str, persist: bool, + api_key: Optional[str], ): endpoint = get_endpoint(endpoint) with open(file) as fd: model = fd.read() - client = RESTfulClient(base_url=endpoint) + client = RESTfulClient(base_url=endpoint, api_key=api_key) client._set_token(get_stored_token(endpoint, client)) client.register_model( model_type=model_type, @@ -408,15 +416,24 @@ def register_model( help="Type of model to unregister (default is 'LLM').", ) @click.option("--model-name", "-n", type=str, help="Name of the model to unregister.") +@click.option( + "--api-key", + "-ak", + default=None, + type=str, + help="Api-Key for access xinference api with authorization.", +) def unregister_model( endpoint: Optional[str], model_type: str, model_name: str, + api_key: Optional[str], ): endpoint = get_endpoint(endpoint) - client = RESTfulClient(base_url=endpoint) - client._set_token(get_stored_token(endpoint, client)) + client = RESTfulClient(base_url=endpoint, api_key=api_key) + if client._get_token() is None: + client._set_token(get_stored_token(endpoint, client)) client.unregister_model( model_type=model_type, model_name=model_name, @@ -437,15 +454,24 @@ def unregister_model( type=str, help="Filter by model type (default is 'LLM').", ) +@click.option( + "--api-key", + "-ak", + default=None, + type=str, + help="Api-Key for access xinference api with authorization.", +) def list_model_registrations( endpoint: Optional[str], model_type: str, + api_key: Optional[str], ): from tabulate import tabulate endpoint = get_endpoint(endpoint) - client = RESTfulClient(base_url=endpoint) - client._set_token(get_stored_token(endpoint, client)) + client = RESTfulClient(base_url=endpoint, api_key=api_key) + if client._get_token() is None: + client._set_token(get_stored_token(endpoint, client)) registrations = client.list_model_registrations(model_type=model_type) @@ -638,6 +664,13 @@ def list_model_registrations( type=bool, help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", ) +@click.option( + "--api-key", + "-ak", + default=None, + type=str, + help="Api-Key for access xinference api with authorization.", +) @click.pass_context def model_launch( ctx, @@ -654,6 +687,7 @@ def model_launch( image_lora_load_kwargs: Optional[Tuple], image_lora_fuse_kwargs: Optional[Tuple], trust_remote_code: bool, + api_key: Optional[str], ): kwargs = {} for i in range(0, len(ctx.args), 2): @@ -686,8 +720,9 @@ def model_launch( if size_in_billions is None or "_" in size_in_billions else int(size_in_billions) ) - client = RESTfulClient(base_url=endpoint) - client._set_token(get_stored_token(endpoint, client)) + client = RESTfulClient(base_url=endpoint, api_key=api_key) + if client._get_token() is None: + client._set_token(get_stored_token(endpoint, client)) model_uid = client.launch_model( model_name=model_name, @@ -718,12 +753,20 @@ def model_launch( type=str, help="Xinference endpoint.", ) -def model_list(endpoint: Optional[str]): +@click.option( + "--api-key", + "-ak", + default=None, + type=str, + help="Api-Key for access xinference api with authorization.", +) +def model_list(endpoint: Optional[str], api_key: Optional[str]): from tabulate import tabulate endpoint = get_endpoint(endpoint) - client = RESTfulClient(base_url=endpoint) - client._set_token(get_stored_token(endpoint, client)) + client = RESTfulClient(base_url=endpoint, api_key=api_key) + if client._get_token() is None: + client._set_token(get_stored_token(endpoint, client)) llm_table = [] embedding_table = [] @@ -844,13 +887,22 @@ def model_list(endpoint: Optional[str]): required=True, help="The unique identifier (UID) of the model.", ) +@click.option( + "--api-key", + "-ak", + default=None, + type=str, + help="Api-Key for access xinference api with authorization.", +) def model_terminate( endpoint: Optional[str], model_uid: str, + api_key: Optional[str], ): endpoint = get_endpoint(endpoint) - client = RESTfulClient(base_url=endpoint) - client._set_token(get_stored_token(endpoint, client)) + client = RESTfulClient(base_url=endpoint, api_key=api_key) + if client._get_token() is None: + client._set_token(get_stored_token(endpoint, client)) client.terminate_model(model_uid=model_uid) @@ -873,15 +925,24 @@ def model_terminate( type=bool, help="Whether to stream the generated text. Use 'True' for streaming (default is True).", ) +@click.option( + "--api-key", + "-ak", + default=None, + type=str, + help="Api-Key for access xinference api with authorization.", +) def model_generate( endpoint: Optional[str], model_uid: str, max_tokens: int, stream: bool, + api_key: Optional[str], ): endpoint = get_endpoint(endpoint) - client = RESTfulClient(base_url=endpoint) - client._set_token(get_stored_token(endpoint, client)) + client = RESTfulClient(base_url=endpoint, api_key=api_key) + if client._get_token() is None: + client._set_token(get_stored_token(endpoint, client)) if stream: # TODO: when stream=True, RestfulClient cannot generate words one by one. # So use Client in temporary. The implementation needs to be changed to @@ -959,16 +1020,25 @@ async def generate_internal(): type=bool, help="Whether to stream the chat messages. Use 'True' for streaming (default is True).", ) +@click.option( + "--api-key", + "-ak", + default=None, + type=str, + help="Api-Key for access xinference api with authorization.", +) def model_chat( endpoint: Optional[str], model_uid: str, max_tokens: int, stream: bool, + api_key: Optional[str], ): # TODO: chat model roles may not be user and assistant. endpoint = get_endpoint(endpoint) - client = RESTfulClient(base_url=endpoint) - client._set_token(get_stored_token(endpoint, client)) + client = RESTfulClient(base_url=endpoint, api_key=api_key) + if client._get_token() is None: + client._set_token(get_stored_token(endpoint, client)) chat_history: "List[ChatCompletionMessage]" = [] if stream: @@ -1048,10 +1118,18 @@ async def chat_internal(): @cli.command("vllm-models", help="Query and display models compatible with vLLM.") @click.option("--endpoint", "-e", type=str, help="Xinference endpoint.") -def vllm_models(endpoint: Optional[str]): +@click.option( + "--api-key", + "-ak", + default=None, + type=str, + help="Api-Key for access xinference api with authorization.", +) +def vllm_models(endpoint: Optional[str], api_key: Optional[str]): endpoint = get_endpoint(endpoint) - client = RESTfulClient(base_url=endpoint) - client._set_token(get_stored_token(endpoint, client)) + client = RESTfulClient(base_url=endpoint, api_key=api_key) + if client._get_token() is None: + client._set_token(get_stored_token(endpoint, client)) vllm_models_dict = client.vllm_models() print("VLLM supported model families:") chat_models = vllm_models_dict["chat"] From b4cb69f91bd46126725a97fe52807650905af38a Mon Sep 17 00:00:00 2001 From: Ago327 Date: Mon, 25 Mar 2024 17:11:51 +0800 Subject: [PATCH 09/18] fix debug output --- xinference/api/oauth2/auth_service.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xinference/api/oauth2/auth_service.py b/xinference/api/oauth2/auth_service.py index 117c55faac..2dfa5a474b 100644 --- a/xinference/api/oauth2/auth_service.py +++ b/xinference/api/oauth2/auth_service.py @@ -78,7 +78,6 @@ def __call__( """ Advanced dependencies. See: https://fastapi.tiangolo.com/advanced/advanced-dependencies/ """ - print(f"DEBUG: Enter __call__ with token {token}") if security_scopes.scopes: authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' else: From f3ada96298b9d5f98f25547857e3b35452b5480a Mon Sep 17 00:00:00 2001 From: Ago327 Date: Mon, 25 Mar 2024 17:27:08 +0800 Subject: [PATCH 10/18] bug fix --- xinference/deploy/cmdline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xinference/deploy/cmdline.py b/xinference/deploy/cmdline.py index b622349786..5e27fe15c1 100644 --- a/xinference/deploy/cmdline.py +++ b/xinference/deploy/cmdline.py @@ -395,7 +395,8 @@ def register_model( model = fd.read() client = RESTfulClient(base_url=endpoint, api_key=api_key) - client._set_token(get_stored_token(endpoint, client)) + if client._get_token() is None: + client._set_token(get_stored_token(endpoint, client)) client.register_model( model_type=model_type, model=model, From 750ae2b7368c7e31f43c3c28056fe77d4083300f Mon Sep 17 00:00:00 2001 From: Ago327 Date: Mon, 25 Mar 2024 19:17:13 +0800 Subject: [PATCH 11/18] fix --- xinference/api/oauth2/auth_service.py | 53 +++++++++------------ xinference/client/restful/restful_client.py | 2 +- xinference/deploy/cmdline.py | 18 +++---- 3 files changed, 33 insertions(+), 40 deletions(-) diff --git a/xinference/api/oauth2/auth_service.py b/xinference/api/oauth2/auth_service.py index 2dfa5a474b..798e6019cd 100644 --- a/xinference/api/oauth2/auth_service.py +++ b/xinference/api/oauth2/auth_service.py @@ -13,7 +13,7 @@ # limitations under the License. import re from datetime import timedelta -from typing import List, Optional +from typing import List, Optional, Tuple from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, SecurityScopes @@ -43,7 +43,7 @@ def config(self): @staticmethod def is_legal_api_key(key: str): - pattern = re.compile("^[sk]{2}-[a-zA-Z0-9]{48}$") + pattern = re.compile("^[sk]{2}-[a-zA-Z0-9]{13}$") if re.match(pattern, key): return True else: @@ -54,20 +54,20 @@ def init_auth_config(self): config: AuthStartupConfig = parse_file_as( path=self._auth_config_file, type_=AuthStartupConfig ) - total_keys = set() + all_api_keys = set() for user in config.user_config: user.password = get_password_hash(user.password) - if len(set(user.api_keys)) != len(user.api_keys): - raise ValueError("User has duplicate Api-Keys") for api_key in user.api_keys: if not self.is_legal_api_key(api_key): raise ValueError( - "Api-Key should be a string started with 'sk-' with a total length of 51" + "Api-Key should be a string started with 'sk-' with a total length of 16" + ) + if api_key in all_api_keys: + raise ValueError( + "Duplicate api-keys exists, please check your configuration" ) - if api_key in total_keys: - raise ValueError("Api-Keys of different users have conflict") else: - total_keys.add(api_key) + all_api_keys.add(api_key) return config def __call__( @@ -88,13 +88,11 @@ def __call__( headers={"WWW-Authenticate": authenticate_value}, ) - through_api_key = False - - try: - assert self._config is not None - if self.is_legal_api_key(token): - through_api_key = True - else: + if self.is_legal_api_key(token): + user, token_scopes = self.get_user_and_scopes_with_api_key(token) + else: + try: + assert self._config is not None payload = jwt.decode( token, self._config.auth_config.secret_key, @@ -105,22 +103,15 @@ def __call__( if username is None: raise credentials_exception token_scopes = payload.get("scopes", []) - token_data = TokenData(scopes=token_scopes, username=username) - except (JWTError, ValidationError): - raise credentials_exception - if not through_api_key: - user = self.get_user(token_data.username) - else: - user = self.get_user_with_api_key(token) - if user is None: + user = self.get_user(username) + except (JWTError, ValidationError): raise credentials_exception - token_data = TokenData(scopes=user.permissions, username=user.username) if user is None: raise credentials_exception - if "admin" in token_data.scopes: + if "admin" in token_scopes: return user for scope in security_scopes.scopes: - if scope not in token_data.scopes: + if scope not in token_scopes: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions", @@ -134,12 +125,14 @@ def get_user(self, username: str) -> Optional[User]: return user return None - def get_user_with_api_key(self, api_key: str) -> Optional[User]: + def get_user_and_scopes_with_api_key( + self, api_key: str + ) -> Tuple[Optional[User], List]: for user in self._config.user_config: for key in user.api_keys: if api_key == key: - return user - return None + return user, user.permissions + return None, [] def authenticate_user(self, username: str, password: str): user = self.get_user(username) diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index 0e6c00f23a..f712ff83a2 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -656,7 +656,7 @@ def __init__(self, base_url, api_key: Optional[str] = None): self._headers: Dict[str, str] = {} self._cluster_authed = False self._check_cluster_authenticated() - if api_key is not None: + if api_key is not None and self._cluster_authed: self._headers["Authorization"] = f"Bearer {api_key}" def _set_token(self, token: Optional[str]): diff --git a/xinference/deploy/cmdline.py b/xinference/deploy/cmdline.py index 5e27fe15c1..ca1633598d 100644 --- a/xinference/deploy/cmdline.py +++ b/xinference/deploy/cmdline.py @@ -395,7 +395,7 @@ def register_model( model = fd.read() client = RESTfulClient(base_url=endpoint, api_key=api_key) - if client._get_token() is None: + if api_key is None: client._set_token(get_stored_token(endpoint, client)) client.register_model( model_type=model_type, @@ -433,7 +433,7 @@ def unregister_model( endpoint = get_endpoint(endpoint) client = RESTfulClient(base_url=endpoint, api_key=api_key) - if client._get_token() is None: + if api_key is None: client._set_token(get_stored_token(endpoint, client)) client.unregister_model( model_type=model_type, @@ -471,7 +471,7 @@ def list_model_registrations( endpoint = get_endpoint(endpoint) client = RESTfulClient(base_url=endpoint, api_key=api_key) - if client._get_token() is None: + if api_key is None: client._set_token(get_stored_token(endpoint, client)) registrations = client.list_model_registrations(model_type=model_type) @@ -722,7 +722,7 @@ def model_launch( else int(size_in_billions) ) client = RESTfulClient(base_url=endpoint, api_key=api_key) - if client._get_token() is None: + if api_key is None: client._set_token(get_stored_token(endpoint, client)) model_uid = client.launch_model( @@ -766,7 +766,7 @@ def model_list(endpoint: Optional[str], api_key: Optional[str]): endpoint = get_endpoint(endpoint) client = RESTfulClient(base_url=endpoint, api_key=api_key) - if client._get_token() is None: + if api_key is None: client._set_token(get_stored_token(endpoint, client)) llm_table = [] @@ -902,7 +902,7 @@ def model_terminate( ): endpoint = get_endpoint(endpoint) client = RESTfulClient(base_url=endpoint, api_key=api_key) - if client._get_token() is None: + if api_key is None: client._set_token(get_stored_token(endpoint, client)) client.terminate_model(model_uid=model_uid) @@ -942,7 +942,7 @@ def model_generate( ): endpoint = get_endpoint(endpoint) client = RESTfulClient(base_url=endpoint, api_key=api_key) - if client._get_token() is None: + if api_key is None: client._set_token(get_stored_token(endpoint, client)) if stream: # TODO: when stream=True, RestfulClient cannot generate words one by one. @@ -1038,7 +1038,7 @@ def model_chat( # TODO: chat model roles may not be user and assistant. endpoint = get_endpoint(endpoint) client = RESTfulClient(base_url=endpoint, api_key=api_key) - if client._get_token() is None: + if api_key is None: client._set_token(get_stored_token(endpoint, client)) chat_history: "List[ChatCompletionMessage]" = [] @@ -1129,7 +1129,7 @@ async def chat_internal(): def vllm_models(endpoint: Optional[str], api_key: Optional[str]): endpoint = get_endpoint(endpoint) client = RESTfulClient(base_url=endpoint, api_key=api_key) - if client._get_token() is None: + if api_key is None: client._set_token(get_stored_token(endpoint, client)) vllm_models_dict = client.vllm_models() print("VLLM supported model families:") From 2e0d5942d59ed2a0c40e7f8eed6807334ac784b8 Mon Sep 17 00:00:00 2001 From: Ago327 Date: Tue, 26 Mar 2024 11:19:56 +0800 Subject: [PATCH 12/18] fix test --- xinference/api/oauth2/auth_service.py | 2 +- .../client/tests/test_client_with_auth.py | 29 +++++++++++++++++++ xinference/conftest.py | 15 ++++++++-- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/xinference/api/oauth2/auth_service.py b/xinference/api/oauth2/auth_service.py index 798e6019cd..aca562305e 100644 --- a/xinference/api/oauth2/auth_service.py +++ b/xinference/api/oauth2/auth_service.py @@ -43,7 +43,7 @@ def config(self): @staticmethod def is_legal_api_key(key: str): - pattern = re.compile("^[sk]{2}-[a-zA-Z0-9]{13}$") + pattern = re.compile("^sk-[a-zA-Z0-9]{13}$") if re.match(pattern, key): return True else: diff --git a/xinference/client/tests/test_client_with_auth.py b/xinference/client/tests/test_client_with_auth.py index d64033dfa6..28857b800e 100644 --- a/xinference/client/tests/test_client_with_auth.py +++ b/xinference/client/tests/test_client_with_auth.py @@ -47,3 +47,32 @@ def test_client_auth(setup_with_auth): assert len(client.list_models()) == 1 client.terminate_model(model_uid=model_uid) assert len(client.list_models()) == 0 + + # test with api-key + client = RESTfulClient(endpoint, api_key="sk-wrongapikey12") + with pytest.raises(RuntimeError): + client.list_models() + + client = RESTfulClient(endpoint, api_key="sk-72tkvudyGLPMi") + assert len(client.list_models()) == 0 + + with pytest.raises(RuntimeError): + client.launch_model(model_name="bge-small-en-v1.5", model_type="embedding") + + client = RESTfulClient(endpoint, api_key="sk-ZOTLIY4gt9w11") + model_uid = client.launch_model( + model_name="bge-small-en-v1.5", model_type="embedding" + ) + model = client.get_model(model_uid=model_uid) + assert isinstance(model, RESTfulEmbeddingModelHandle) + + completion = model.create_embedding("write a poem.") + assert len(completion["data"][0]["embedding"]) == 384 + + with pytest.raises(RuntimeError): + client.terminate_model(model_uid=model_uid) + + client = RESTfulClient(endpoint, api_key="sk-3sjLbdwqAhhAF") + assert len(client.list_models()) == 1 + client.terminate_model(model_uid=model_uid) + assert len(client.list_models()) == 0 diff --git a/xinference/conftest.py b/xinference/conftest.py index 0d2822969d..1dfeae0f0e 100644 --- a/xinference/conftest.py +++ b/xinference/conftest.py @@ -261,12 +261,23 @@ def setup_with_auth(): if not cluster_health_check(supervisor_addr, max_attempts=10, sleep_interval=3): raise RuntimeError("Cluster is not available after multiple attempts") - user1 = User(username="user1", password="pass1", permissions=["admin"]) - user2 = User(username="user2", password="pass2", permissions=["models:list"]) + user1 = User( + username="user1", + password="pass1", + permissions=["admin"], + api_keys=["sk-3sjLbdwqAhhAF", "sk-0HCRO1rauFQDL"], + ) + user2 = User( + username="user2", + password="pass2", + permissions=["models:list"], + api_keys=["sk-72tkvudyGLPMi"], + ) user3 = User( username="user3", password="pass3", permissions=["models:list", "models:read", "models:start"], + api_keys=["sk-m6jEzEwmCc4iQ", "sk-ZOTLIY4gt9w11"], ) auth_config = AuthConfig( algorithm="HS256", From aedb26cc51f3f83d49b3d7cd96acd778bae0a4cf Mon Sep 17 00:00:00 2001 From: Ago327 Date: Thu, 28 Mar 2024 10:42:05 +0800 Subject: [PATCH 13/18] test with openaiSDK --- xinference/api/oauth2/auth_service.py | 7 ++--- .../client/tests/test_client_with_auth.py | 27 ++++++++++++++++++- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/xinference/api/oauth2/auth_service.py b/xinference/api/oauth2/auth_service.py index aca562305e..7de97c1020 100644 --- a/xinference/api/oauth2/auth_service.py +++ b/xinference/api/oauth2/auth_service.py @@ -42,12 +42,9 @@ def config(self): return self._config @staticmethod - def is_legal_api_key(key: str): + def is_legal_api_key(key: str) -> bool: pattern = re.compile("^sk-[a-zA-Z0-9]{13}$") - if re.match(pattern, key): - return True - else: - return False + return re.match(pattern, key) is not None def init_auth_config(self): if self._auth_config_file: diff --git a/xinference/client/tests/test_client_with_auth.py b/xinference/client/tests/test_client_with_auth.py index 28857b800e..b2ce2a658d 100644 --- a/xinference/client/tests/test_client_with_auth.py +++ b/xinference/client/tests/test_client_with_auth.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest +from openai import AuthenticationError, OpenAI, PermissionDeniedError from ..restful.restful_client import Client as RESTfulClient from ..restful.restful_client import RESTfulEmbeddingModelHandle @@ -74,5 +75,29 @@ def test_client_auth(setup_with_auth): client = RESTfulClient(endpoint, api_key="sk-3sjLbdwqAhhAF") assert len(client.list_models()) == 1 - client.terminate_model(model_uid=model_uid) + + # test with openai SDK + client_ai = OpenAI(endpoint + "/v1", api_key="sk-wrongapikey12") + with pytest.raises(AuthenticationError): + client_ai.models.list() + + client_ai = OpenAI(endpoint + "/v1", api_key="sk-72tkvudyGLPMi") + assert len(client_ai.models.list().data) == 1 + + with pytest.raises(PermissionDeniedError): + chat_completion = client_ai.chat.completions.create( + model="bge-small-en-v1.5", + messages=[{"role": "user", "content": "write a poem."}], + ) + + client_ai = OpenAI(endpoint + "/v1", api_key="sk-72tkvudyGLPMi") + chat_completion = client_ai.chat.completions.create( + model="bge-small-en-v1.5", + messages=[{"role": "user", "content": "write a poem."}], + ) + + assert len(chat_completion["data"][0]["embedding"]) == 384 + + client_ai.terminate_model(model_uid) assert len(client.list_models()) == 0 + assert len(client_ai.models.list().data) == 0 From edf6ffbb64bd5a62ef36d44dbc639e20d974916f Mon Sep 17 00:00:00 2001 From: Ago327 Date: Thu, 28 Mar 2024 11:30:59 +0800 Subject: [PATCH 14/18] fix SDK --- xinference/client/tests/test_client_with_auth.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xinference/client/tests/test_client_with_auth.py b/xinference/client/tests/test_client_with_auth.py index b2ce2a658d..890acf12b4 100644 --- a/xinference/client/tests/test_client_with_auth.py +++ b/xinference/client/tests/test_client_with_auth.py @@ -77,11 +77,11 @@ def test_client_auth(setup_with_auth): assert len(client.list_models()) == 1 # test with openai SDK - client_ai = OpenAI(endpoint + "/v1", api_key="sk-wrongapikey12") + client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-wrongapikey12") with pytest.raises(AuthenticationError): client_ai.models.list() - client_ai = OpenAI(endpoint + "/v1", api_key="sk-72tkvudyGLPMi") + client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-72tkvudyGLPMi") assert len(client_ai.models.list().data) == 1 with pytest.raises(PermissionDeniedError): @@ -90,7 +90,7 @@ def test_client_auth(setup_with_auth): messages=[{"role": "user", "content": "write a poem."}], ) - client_ai = OpenAI(endpoint + "/v1", api_key="sk-72tkvudyGLPMi") + client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-72tkvudyGLPMi") chat_completion = client_ai.chat.completions.create( model="bge-small-en-v1.5", messages=[{"role": "user", "content": "write a poem."}], From aea8dccdae98c3961c9c636450eecd8dfa4faf67 Mon Sep 17 00:00:00 2001 From: Ago327 Date: Thu, 28 Mar 2024 12:22:55 +0800 Subject: [PATCH 15/18] fix import and test key --- xinference/client/tests/test_client_with_auth.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/xinference/client/tests/test_client_with_auth.py b/xinference/client/tests/test_client_with_auth.py index 890acf12b4..291bae7fc7 100644 --- a/xinference/client/tests/test_client_with_auth.py +++ b/xinference/client/tests/test_client_with_auth.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest -from openai import AuthenticationError, OpenAI, PermissionDeniedError from ..restful.restful_client import Client as RESTfulClient from ..restful.restful_client import RESTfulEmbeddingModelHandle @@ -77,27 +76,30 @@ def test_client_auth(setup_with_auth): assert len(client.list_models()) == 1 # test with openai SDK + from openai import AuthenticationError, OpenAI, PermissionDeniedError + client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-wrongapikey12") with pytest.raises(AuthenticationError): client_ai.models.list() client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-72tkvudyGLPMi") assert len(client_ai.models.list().data) == 1 - with pytest.raises(PermissionDeniedError): chat_completion = client_ai.chat.completions.create( model="bge-small-en-v1.5", messages=[{"role": "user", "content": "write a poem."}], ) - client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-72tkvudyGLPMi") + client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-ZOTLIY4gt9w11") chat_completion = client_ai.chat.completions.create( model="bge-small-en-v1.5", messages=[{"role": "user", "content": "write a poem."}], ) - assert len(chat_completion["data"][0]["embedding"]) == 384 + with pytest.raises(RuntimeError): + client_ai.terminate_model(model_uid) + client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-3sjLbdwqAhhAF") client_ai.terminate_model(model_uid) assert len(client.list_models()) == 0 assert len(client_ai.models.list().data) == 0 From f50a5b250e6fd0f0fdeae98e62e0e8c80d910362 Mon Sep 17 00:00:00 2001 From: Ago327 Date: Thu, 28 Mar 2024 15:04:35 +0800 Subject: [PATCH 16/18] fix embedding --- xinference/client/tests/test_client_with_auth.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xinference/client/tests/test_client_with_auth.py b/xinference/client/tests/test_client_with_auth.py index 291bae7fc7..9eecffa9e4 100644 --- a/xinference/client/tests/test_client_with_auth.py +++ b/xinference/client/tests/test_client_with_auth.py @@ -85,15 +85,15 @@ def test_client_auth(setup_with_auth): client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-72tkvudyGLPMi") assert len(client_ai.models.list().data) == 1 with pytest.raises(PermissionDeniedError): - chat_completion = client_ai.chat.completions.create( + chat_completion = client_ai.embeddings.create( model="bge-small-en-v1.5", - messages=[{"role": "user", "content": "write a poem."}], + input="write a poem.", ) client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-ZOTLIY4gt9w11") - chat_completion = client_ai.chat.completions.create( + chat_completion = client_ai.embeddings.create( model="bge-small-en-v1.5", - messages=[{"role": "user", "content": "write a poem."}], + input="write a poem.", ) assert len(chat_completion["data"][0]["embedding"]) == 384 with pytest.raises(RuntimeError): From 1daa967dfafce6bb95c3cbf9e1bcf8993e4c9f3c Mon Sep 17 00:00:00 2001 From: Ago327 Date: Thu, 28 Mar 2024 15:38:50 +0800 Subject: [PATCH 17/18] embedding --- xinference/client/tests/test_client_with_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xinference/client/tests/test_client_with_auth.py b/xinference/client/tests/test_client_with_auth.py index 9eecffa9e4..43c0a2a51f 100644 --- a/xinference/client/tests/test_client_with_auth.py +++ b/xinference/client/tests/test_client_with_auth.py @@ -95,7 +95,7 @@ def test_client_auth(setup_with_auth): model="bge-small-en-v1.5", input="write a poem.", ) - assert len(chat_completion["data"][0]["embedding"]) == 384 + assert len(chat_completion.data[0].embedding) == 384 with pytest.raises(RuntimeError): client_ai.terminate_model(model_uid) From 5acf72f6438e18f3c739e436b8456b94eff3ea79 Mon Sep 17 00:00:00 2001 From: Ago327 Date: Thu, 28 Mar 2024 16:18:52 +0800 Subject: [PATCH 18/18] fix terminate --- xinference/client/tests/test_client_with_auth.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xinference/client/tests/test_client_with_auth.py b/xinference/client/tests/test_client_with_auth.py index 43c0a2a51f..68a6bc3221 100644 --- a/xinference/client/tests/test_client_with_auth.py +++ b/xinference/client/tests/test_client_with_auth.py @@ -96,10 +96,8 @@ def test_client_auth(setup_with_auth): input="write a poem.", ) assert len(chat_completion.data[0].embedding) == 384 - with pytest.raises(RuntimeError): - client_ai.terminate_model(model_uid) client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-3sjLbdwqAhhAF") - client_ai.terminate_model(model_uid) + client.terminate_model(model_uid) assert len(client.list_models()) == 0 assert len(client_ai.models.list().data) == 0